aboutsummaryrefslogtreecommitdiffhomepage
path: root/tensorflow/python
diff options
context:
space:
mode:
Diffstat (limited to 'tensorflow/python')
-rw-r--r--tensorflow/python/BUILD62
-rw-r--r--tensorflow/python/autograph/CONTRIBUTING.md2
-rw-r--r--tensorflow/python/autograph/converters/BUILD6
-rw-r--r--tensorflow/python/autograph/converters/function_scopes.py (renamed from tensorflow/python/autograph/converters/name_scopes.py)32
-rw-r--r--tensorflow/python/autograph/converters/function_scopes_test.py (renamed from tensorflow/python/autograph/converters/name_scopes_test.py)40
-rw-r--r--tensorflow/python/autograph/converters/return_statements.py14
-rw-r--r--tensorflow/python/autograph/converters/return_statements_test.py12
-rw-r--r--tensorflow/python/autograph/core/BUILD51
-rw-r--r--tensorflow/python/autograph/core/converter.py53
-rw-r--r--tensorflow/python/autograph/core/converter_test.py124
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py2
-rw-r--r--tensorflow/python/autograph/core/function_wrapping.py30
-rw-r--r--tensorflow/python/autograph/core/function_wrapping_test.py34
-rw-r--r--tensorflow/python/autograph/impl/conversion.py6
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py24
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py37
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py17
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py31
-rw-r--r--tensorflow/python/autograph/operators/py_builtins.py1
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils.py38
-rw-r--r--tensorflow/python/autograph/pyct/inspect_utils_test.py19
-rw-r--r--tensorflow/python/autograph/pyct/static_analysis/live_values.py4
-rw-r--r--tensorflow/python/autograph/pyct/templates.py2
-rw-r--r--tensorflow/python/autograph/pyct/templates_test.py12
-rw-r--r--tensorflow/python/client/session_test.py18
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/__init__.py1
-rw-r--r--tensorflow/python/data/experimental/BUILD16
-rw-r--r--tensorflow/python/data/experimental/__init__.py109
-rw-r--r--tensorflow/python/data/experimental/benchmarks/BUILD25
-rw-r--r--tensorflow/python/data/experimental/benchmarks/map_benchmark.py245
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/BUILD687
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py322
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py533
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/counter_test.py51
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py632
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py124
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py148
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py56
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py76
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py247
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py72
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py199
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py367
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py115
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py79
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py239
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py660
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py243
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py368
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py293
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/BUILD207
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py65
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py106
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py59
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py232
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py87
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py231
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py193
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py60
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py104
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py91
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py811
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py850
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py234
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py353
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py182
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py71
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/scan_test.py172
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/BUILD719
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py83
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py253
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py125
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py49
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py73
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py95
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py692
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py71
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py45
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py122
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py61
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py57
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py46
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py83
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py88
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py140
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py39
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py66
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py101
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py139
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py50
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py39
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py118
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py46
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py40
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py129
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py85
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py39
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py148
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py53
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py106
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py53
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py99
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py51
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py40
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py54
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py115
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py590
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py94
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py322
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py71
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py118
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unbatch_test.py300
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unique_test.py83
-rw-r--r--tensorflow/python/data/experimental/ops/BUILD377
-rw-r--r--tensorflow/python/data/experimental/ops/batching.py669
-rw-r--r--tensorflow/python/data/experimental/ops/counter.py55
-rw-r--r--tensorflow/python/data/experimental/ops/enumerate_ops.py60
-rw-r--r--tensorflow/python/data/experimental/ops/error_ops.py78
-rw-r--r--tensorflow/python/data/experimental/ops/get_single_element.py72
-rw-r--r--tensorflow/python/data/experimental/ops/grouping.py551
-rw-r--r--tensorflow/python/data/experimental/ops/indexed_dataset_ops.py177
-rw-r--r--tensorflow/python/data/experimental/ops/interleave_ops.py262
-rw-r--r--tensorflow/python/data/experimental/ops/iterator_ops.py268
-rw-r--r--tensorflow/python/data/experimental/ops/map_defun.py58
-rw-r--r--tensorflow/python/data/experimental/ops/optimization.py114
-rw-r--r--tensorflow/python/data/experimental/ops/parsing_ops.py152
-rw-r--r--tensorflow/python/data/experimental/ops/prefetching_ops.py531
-rw-r--r--tensorflow/python/data/experimental/ops/random_ops.py54
-rw-r--r--tensorflow/python/data/experimental/ops/readers.py904
-rw-r--r--tensorflow/python/data/experimental/ops/resampling.py296
-rw-r--r--tensorflow/python/data/experimental/ops/scan_ops.py177
-rw-r--r--tensorflow/python/data/experimental/ops/shuffle_ops.py102
-rw-r--r--tensorflow/python/data/experimental/ops/stats_ops.py214
-rw-r--r--tensorflow/python/data/experimental/ops/threadpool.py104
-rw-r--r--tensorflow/python/data/experimental/ops/unique.py79
-rw-r--r--tensorflow/python/data/experimental/ops/writers.py60
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD327
-rw-r--r--tensorflow/python/data/kernel_tests/batch_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/cache_dataset_op_test.py5
-rw-r--r--tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py159
-rw-r--r--tensorflow/python/data/kernel_tests/filter_dataset_op_test.py10
-rw-r--r--tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/inputs_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/map_dataset_op_test.py134
-rw-r--r--tensorflow/python/data/kernel_tests/multi_device_iterator_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/optional_ops_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/range_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py7
-rw-r--r--tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py124
-rw-r--r--tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shard_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py138
-rw-r--r--tensorflow/python/data/kernel_tests/window_dataset_op_test.py8
-rw-r--r--tensorflow/python/data/kernel_tests/zip_dataset_op_test.py3
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py393
-rw-r--r--tensorflow/python/data/ops/optional_ops.py4
-rw-r--r--tensorflow/python/data/ops/readers.py4
-rw-r--r--tensorflow/python/debug/BUILD1
-rw-r--r--tensorflow/python/debug/cli/analyzer_cli_test.py20
-rw-r--r--tensorflow/python/debug/cli/stepper_cli_test.py4
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py14
-rwxr-xr-xtensorflow/python/debug/examples/examples_test.sh2
-rw-r--r--tensorflow/python/debug/lib/debug_utils_test.py4
-rw-r--r--tensorflow/python/debug/lib/dist_session_debug_grpc_test.py4
-rw-r--r--tensorflow/python/debug/lib/grpc_large_data_test.py12
-rw-r--r--tensorflow/python/debug/lib/session_debug_file_test.py4
-rw-r--r--tensorflow/python/debug/lib/session_debug_grpc_test.py46
-rw-r--r--tensorflow/python/debug/lib/session_debug_testlib.py90
-rw-r--r--tensorflow/python/debug/lib/stepper_test.py14
-rw-r--r--tensorflow/python/debug/wrappers/dumping_wrapper_test.py2
-rw-r--r--tensorflow/python/debug/wrappers/local_cli_wrapper_test.py14
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py4
-rw-r--r--tensorflow/python/distribute/estimator_training.py2
-rw-r--r--tensorflow/python/eager/BUILD2
-rw-r--r--tensorflow/python/eager/backprop.py2
-rw-r--r--tensorflow/python/eager/benchmarks_test.py4
-rw-r--r--tensorflow/python/eager/core_test.py28
-rw-r--r--tensorflow/python/eager/function.py294
-rw-r--r--tensorflow/python/eager/function_test.py153
-rw-r--r--tensorflow/python/eager/imperative_grad.py5
-rwxr-xr-xtensorflow/python/eager/pywrap_tfe.h4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc246
-rw-r--r--tensorflow/python/estimator/BUILD3
-rw-r--r--tensorflow/python/estimator/canned/dnn.py183
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined.py7
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py268
-rw-r--r--tensorflow/python/estimator/canned/dnn_test.py161
-rw-r--r--tensorflow/python/estimator/canned/dnn_testing_utils.py116
-rw-r--r--tensorflow/python/estimator/canned/linear.py83
-rw-r--r--tensorflow/python/estimator/canned/linear_test.py138
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py184
-rw-r--r--tensorflow/python/estimator/estimator.py71
-rw-r--r--tensorflow/python/estimator/estimator_test.py150
-rw-r--r--tensorflow/python/estimator/keras.py87
-rw-r--r--tensorflow/python/estimator/keras_test.py34
-rw-r--r--tensorflow/python/estimator/util.py8
-rw-r--r--tensorflow/python/feature_column/BUILD2
-rw-r--r--tensorflow/python/feature_column/feature_column.py90
-rw-r--r--tensorflow/python/feature_column/feature_column_test.py12
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py632
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py1874
-rw-r--r--tensorflow/python/framework/device.py12
-rw-r--r--tensorflow/python/framework/dtypes.py4
-rw-r--r--tensorflow/python/framework/errors_impl.py6
-rw-r--r--tensorflow/python/framework/function.py2
-rw-r--r--tensorflow/python/framework/function_test.py2
-rw-r--r--tensorflow/python/framework/graph_io.py2
-rw-r--r--tensorflow/python/framework/graph_util_test.py8
-rw-r--r--tensorflow/python/framework/importer.py5
-rw-r--r--tensorflow/python/framework/ops.py28
-rw-r--r--tensorflow/python/framework/random_seed.py6
-rw-r--r--tensorflow/python/framework/sparse_tensor.py4
-rw-r--r--tensorflow/python/framework/subscribe_test.py4
-rw-r--r--tensorflow/python/framework/test_util.py105
-rw-r--r--tensorflow/python/grappler/item_test.py2
-rw-r--r--tensorflow/python/grappler/memory_optimizer_test.py10
-rw-r--r--tensorflow/python/grappler/tf_optimizer_test.py2
-rwxr-xr-xtensorflow/python/keras/BUILD157
-rw-r--r--tensorflow/python/keras/activations.py5
-rw-r--r--tensorflow/python/keras/activations_test.py10
-rw-r--r--tensorflow/python/keras/backend.py117
-rw-r--r--tensorflow/python/keras/backend_test.py44
-rw-r--r--tensorflow/python/keras/callbacks.py4
-rw-r--r--tensorflow/python/keras/callbacks_test.py40
-rw-r--r--tensorflow/python/keras/engine/base_layer.py161
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py134
-rw-r--r--tensorflow/python/keras/engine/input_layer.py1
-rw-r--r--tensorflow/python/keras/engine/network.py20
-rw-r--r--tensorflow/python/keras/engine/saving_test.py13
-rw-r--r--tensorflow/python/keras/engine/topology_test.py31
-rw-r--r--tensorflow/python/keras/engine/training.py73
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py384
-rw-r--r--tensorflow/python/keras/engine/training_eager.py3
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py44
-rw-r--r--tensorflow/python/keras/engine/training_generator.py11
-rw-r--r--tensorflow/python/keras/engine/training_test.py16
-rw-r--r--tensorflow/python/keras/engine/training_utils.py15
-rw-r--r--tensorflow/python/keras/layers/convolutional.py177
-rw-r--r--tensorflow/python/keras/layers/convolutional_test.py33
-rw-r--r--tensorflow/python/keras/layers/core.py51
-rw-r--r--tensorflow/python/keras/layers/core_test.py45
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent.py24
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent_test.py27
-rw-r--r--tensorflow/python/keras/layers/embeddings.py10
-rw-r--r--tensorflow/python/keras/layers/pooling.py185
-rw-r--r--tensorflow/python/keras/layers/pooling_test.py30
-rw-r--r--tensorflow/python/keras/layers/recurrent.py65
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py90
-rw-r--r--tensorflow/python/keras/layers/wrappers.py3
-rw-r--r--tensorflow/python/keras/metrics.py30
-rw-r--r--tensorflow/python/keras/metrics_test.py26
-rw-r--r--tensorflow/python/keras/models.py14
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta.py116
-rw-r--r--tensorflow/python/keras/optimizer_v2/adadelta_test.py166
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad.py119
-rw-r--r--tensorflow/python/keras/optimizer_v2/adagrad_test.py276
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam.py203
-rw-r--r--tensorflow/python/keras/optimizer_v2/adam_test.py333
-rw-r--r--tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py761
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2.py1349
-rw-r--r--tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py277
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop.py239
-rw-r--r--tensorflow/python/keras/optimizer_v2/rmsprop_test.py444
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd.py170
-rw-r--r--tensorflow/python/keras/optimizer_v2/sgd_test.py759
-rw-r--r--tensorflow/python/keras/preprocessing/image_test.py37
-rw-r--r--tensorflow/python/keras/testing_utils.py5
-rw-r--r--tensorflow/python/keras/utils/conv_utils.py45
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils.py17
-rw-r--r--tensorflow/python/keras/utils/multi_gpu_utils_test.py26
-rw-r--r--tensorflow/python/keras/utils/np_utils.py5
-rw-r--r--tensorflow/python/kernel_tests/BUILD55
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py12
-rw-r--r--tensorflow/python/kernel_tests/batch_gather_op_test.py13
-rw-r--r--tensorflow/python/kernel_tests/bincount_op_test.py18
-rw-r--r--tensorflow/python/kernel_tests/cond_v2_test.py49
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py267
-rw-r--r--tensorflow/python/kernel_tests/dense_update_ops_test.py6
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/distributions/bernoulli_test.py12
-rw-r--r--tensorflow/python/kernel_tests/distributions/beta_test.py5
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_test.py17
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py23
-rw-r--r--tensorflow/python/kernel_tests/distributions/gamma_test.py8
-rw-r--r--tensorflow/python/kernel_tests/distributions/laplace_test.py4
-rw-r--r--tensorflow/python/kernel_tests/distributions/normal_test.py8
-rw-r--r--tensorflow/python/kernel_tests/identity_op_py_test.py2
-rw-r--r--tensorflow/python/kernel_tests/list_ops_test.py26
-rw-r--r--tensorflow/python/kernel_tests/logging_ops_test.py13
-rw-r--r--tensorflow/python/kernel_tests/relu_op_test.py11
-rw-r--r--tensorflow/python/kernel_tests/resource_variable_ops_test.py11
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py39
-rw-r--r--tensorflow/python/kernel_tests/scatter_nd_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/scatter_ops_test.py4
-rw-r--r--tensorflow/python/kernel_tests/softplus_op_test.py2
-rw-r--r--tensorflow/python/kernel_tests/substr_op_test.py503
-rw-r--r--tensorflow/python/kernel_tests/transpose_op_test.py29
-rw-r--r--tensorflow/python/kernel_tests/unicode_script_op_test.py57
-rw-r--r--tensorflow/python/kernel_tests/variable_scope_test.py4
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py78
-rw-r--r--tensorflow/python/layers/base.py16
-rw-r--r--tensorflow/python/layers/convolutional_test.py36
-rw-r--r--tensorflow/python/layers/core.py16
-rw-r--r--tensorflow/python/layers/core_test.py40
-rw-r--r--tensorflow/python/lib/io/tf_record.py13
-rw-r--r--tensorflow/python/ops/array_ops.py67
-rw-r--r--tensorflow/python/ops/candidate_sampling_ops.py8
-rw-r--r--tensorflow/python/ops/check_ops.py63
-rw-r--r--tensorflow/python/ops/clip_ops.py8
-rw-r--r--tensorflow/python/ops/cond_v2_impl.py64
-rw-r--r--tensorflow/python/ops/confusion_matrix.py4
-rw-r--r--tensorflow/python/ops/control_flow_ops.py22
-rw-r--r--tensorflow/python/ops/control_flow_ops_benchmark.py122
-rw-r--r--tensorflow/python/ops/conv2d_benchmark.py4
-rw-r--r--tensorflow/python/ops/data_flow_ops.py17
-rw-r--r--tensorflow/python/ops/distributions/BUILD7
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py9
-rw-r--r--tensorflow/python/ops/distributions/beta.py18
-rw-r--r--tensorflow/python/ops/distributions/categorical.py9
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py11
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py9
-rw-r--r--tensorflow/python/ops/distributions/distribution.py51
-rw-r--r--tensorflow/python/ops/distributions/exponential.py16
-rw-r--r--tensorflow/python/ops/distributions/gamma.py16
-rw-r--r--tensorflow/python/ops/distributions/identity_bijector.py9
-rw-r--r--tensorflow/python/ops/distributions/kullback_leibler.py25
-rw-r--r--tensorflow/python/ops/distributions/laplace.py14
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py9
-rw-r--r--tensorflow/python/ops/distributions/normal.py14
-rw-r--r--tensorflow/python/ops/distributions/special_math.py61
-rw-r--r--tensorflow/python/ops/distributions/student_t.py14
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py9
-rw-r--r--tensorflow/python/ops/distributions/uniform.py9
-rw-r--r--tensorflow/python/ops/embedding_ops.py8
-rw-r--r--tensorflow/python/ops/gradients.py1
-rw-r--r--tensorflow/python/ops/gradients_impl.py67
-rw-r--r--tensorflow/python/ops/gradients_test.py36
-rw-r--r--tensorflow/python/ops/init_ops.py5
-rw-r--r--tensorflow/python/ops/linalg_ops.py15
-rw-r--r--tensorflow/python/ops/lookup_ops.py2
-rw-r--r--tensorflow/python/ops/manip_ops.py4
-rw-r--r--tensorflow/python/ops/math_grad.py34
-rw-r--r--tensorflow/python/ops/math_grad_test.py88
-rw-r--r--tensorflow/python/ops/math_ops.py145
-rw-r--r--tensorflow/python/ops/math_ops_test.py71
-rw-r--r--tensorflow/python/ops/matmul_benchmark.py8
-rw-r--r--tensorflow/python/ops/nn_impl.py6
-rw-r--r--tensorflow/python/ops/nn_ops.py18
-rw-r--r--tensorflow/python/ops/numerics.py4
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py6
-rw-r--r--tensorflow/python/ops/parsing_ops.py31
-rw-r--r--tensorflow/python/ops/random_ops.py19
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py6
-rw-r--r--tensorflow/python/ops/sparse_ops.py107
-rw-r--r--tensorflow/python/ops/special_math_ops.py4
-rw-r--r--tensorflow/python/ops/string_ops.py48
-rw-r--r--tensorflow/python/ops/variable_scope.py126
-rw-r--r--tensorflow/python/ops/variables.py371
-rw-r--r--tensorflow/python/ops/while_v2.py63
-rw-r--r--tensorflow/python/platform/tf_logging.py58
-rwxr-xr-xtensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/saved_model/builder_impl.py7
-rw-r--r--tensorflow/python/saved_model/loader_impl.py8
-rw-r--r--tensorflow/python/saved_model/loader_test.py14
-rw-r--r--tensorflow/python/saved_model/main_op_impl.py5
-rw-r--r--tensorflow/python/saved_model/saved_model_test.py56
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py27
-rw-r--r--tensorflow/python/saved_model/utils_impl.py10
-rw-r--r--tensorflow/python/tools/BUILD8
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files.bzl2
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files_v1.bzl2
-rw-r--r--tensorflow/python/tools/freeze_graph_test.py6
-rw-r--r--tensorflow/python/tools/saved_model_cli.py19
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py5
-rw-r--r--tensorflow/python/training/checkpointable/util.py58
-rw-r--r--tensorflow/python/training/checkpointable/util_test.py5
-rw-r--r--tensorflow/python/training/distribute.py64
-rw-r--r--tensorflow/python/training/distribution_strategy_context.py2
-rw-r--r--tensorflow/python/training/evaluation.py68
-rw-r--r--tensorflow/python/training/input.py3
-rw-r--r--tensorflow/python/training/learning_rate_decay_test.py4
-rw-r--r--tensorflow/python/training/monitored_session.py8
-rw-r--r--tensorflow/python/training/monitored_session_test.py28
-rw-r--r--tensorflow/python/training/moving_averages.py58
-rw-r--r--tensorflow/python/training/moving_averages_test.py27
-rw-r--r--tensorflow/python/training/optimizer.py17
-rw-r--r--tensorflow/python/training/quantize_training_test.py3
-rw-r--r--tensorflow/python/training/queue_runner_test.py22
-rw-r--r--tensorflow/python/training/saver_test.py217
-rw-r--r--tensorflow/python/training/server_lib_same_variables_no_clear_test.py4
-rw-r--r--tensorflow/python/training/server_lib_test.py18
-rw-r--r--tensorflow/python/training/session_manager.py7
-rw-r--r--tensorflow/python/training/session_manager_test.py98
-rw-r--r--tensorflow/python/training/supervisor.py7
-rw-r--r--tensorflow/python/training/supervisor_test.py52
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer.py6
-rw-r--r--tensorflow/python/training/sync_replicas_optimizer_test.py17
-rw-r--r--tensorflow/python/training/training_ops_test.py32
-rw-r--r--tensorflow/python/training/training_util_test.py4
-rw-r--r--tensorflow/python/util/function_utils.py23
-rw-r--r--tensorflow/python/util/function_utils_test.py95
-rw-r--r--tensorflow/python/util/nest.py20
-rw-r--r--tensorflow/python/util/nest_test.py34
-rw-r--r--tensorflow/python/util/protobuf/compare.py18
-rw-r--r--tensorflow/python/util/tf_inspect.py93
-rw-r--r--tensorflow/python/util/tf_inspect_test.py199
-rw-r--r--tensorflow/python/util/util.cc567
-rw-r--r--tensorflow/python/util/util.h43
-rw-r--r--tensorflow/python/util/util.i22
418 files changed, 36739 insertions, 5005 deletions
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 79f14466e6..822d596995 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -333,6 +333,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/memory",
],
)
@@ -1638,6 +1639,15 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "experimental_dataset_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow:__subpackages__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "image_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
@@ -1730,6 +1740,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "stateless_random_ops_gen",
+ visibility = [
+ "//tensorflow/contrib/stateless:__pkg__",
+ "//tensorflow/python/data/experimental/ops:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "list_ops_gen",
)
@@ -2007,6 +2025,7 @@ py_library(
":array_ops",
":cond_v2_impl",
":constant_op",
+ ":control_flow_ops",
":control_flow_util",
":framework_ops",
":function_def_to_graph",
@@ -2133,6 +2152,7 @@ py_library(
":array_grad",
":array_ops",
":bitwise_ops",
+ ":check_ops",
":cond_v2_impl",
":control_flow_grad",
":control_flow_ops",
@@ -2153,8 +2173,11 @@ py_library(
":random_grad",
":resource_variable_ops",
":spectral_grad",
+ ":tensor_array_ops",
+ ":tensor_util",
":util",
":variable_scope",
+ "//tensorflow/core:protos_all_py",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:tape",
@@ -3291,9 +3314,11 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/basic_session_run_hooks.py",
"training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
+ "training/session_run_hook.py",
"training/training_util.py",
],
),
@@ -3301,6 +3326,7 @@ py_library(
deps = [
":array_ops",
":array_ops_gen",
+ ":basic_session_run_hooks",
":checkpoint_management",
":checkpoint_ops_gen",
":client",
@@ -3325,6 +3351,7 @@ py_library(
":saver",
":sdca_ops",
":session",
+ ":session_run_hook",
":sparse_ops",
":sparse_tensor",
":state_ops",
@@ -3369,6 +3396,28 @@ py_library(
)
py_library(
+ name = "session_run_hook",
+ srcs = ["training/session_run_hook.py"],
+ srcs_version = "PY2AND3",
+ deps = [":util"],
+)
+
+py_library(
+ name = "basic_session_run_hooks",
+ srcs = ["training/basic_session_run_hooks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client",
+ ":framework",
+ ":platform",
+ ":protos_all_py",
+ ":session_run_hook",
+ ":training_util",
+ ":util",
+ ],
+)
+
+py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
@@ -5148,6 +5197,19 @@ cuda_py_test(
)
cuda_py_test(
+ name = "control_flow_ops_benchmark",
+ srcs = ["ops/control_flow_ops_benchmark.py"],
+ additional_deps = [
+ ":client_testlib",
+ ":constant_op",
+ ":control_flow_ops",
+ ":framework_ops",
+ "//tensorflow/python/eager:function",
+ ],
+ main = "ops/control_flow_ops_benchmark.py",
+)
+
+cuda_py_test(
name = "conv2d_benchmark",
size = "large",
srcs = ["ops/conv2d_benchmark.py"],
diff --git a/tensorflow/python/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
index 1ded5ba5f6..f3587a4384 100644
--- a/tensorflow/python/autograph/CONTRIBUTING.md
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -9,8 +9,6 @@ In preparation for TF 2.0, we moved the code base of AutoGraph from
does not impact functionality, and AutoGraph will remain accessible under
`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
-When
-
## TensorFlow Code of Conduct
Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 7b029de8ed..f06dc78f0e 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -27,10 +27,10 @@ py_library(
"decorators.py",
"directives.py",
"error_handlers.py",
+ "function_scopes.py",
"list_comprehensions.py",
"lists.py",
"logical_expressions.py",
- "name_scopes.py",
"return_statements.py",
"side_effect_guards.py",
"slices.py",
@@ -157,8 +157,8 @@ py_test(
)
py_test(
- name = "name_scopes_test",
- srcs = ["name_scopes_test.py"],
+ name = "function_scopes_test",
+ srcs = ["function_scopes_test.py"],
deps = [
":converters",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py
index a9c55ccff0..284b5b3519 100644
--- a/tensorflow/python/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/function_scopes.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Wraps a function body with a `name_scope` of the function name."""
+"""Wraps the body of a converted function with auxiliary constructs."""
from __future__ import absolute_import
from __future__ import division
@@ -24,8 +24,8 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import templates
-class FunctionNameScopeTransformer(converter.Base):
- """Wrap a function body with a `name_scope` of the function name."""
+class FunctionBodyTransformer(converter.Base):
+ """Wraps function bodies around autograph-specific boilerplate."""
def _name_for_current_scope(self):
innermost = self.enclosing_entities[-1]
@@ -49,26 +49,28 @@ class FunctionNameScopeTransformer(converter.Base):
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- unscoped_body = []
- scoped_body = node.body
- if scoped_body:
- first = scoped_body[0]
- if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str):
- # Skip any docstring.
- unscoped_body = scoped_body[:1]
- scoped_body = scoped_body[1:]
+ final_body = []
+ indented_body = node.body
+ if node.body:
+ first_statement = node.body[0]
+ # Skip the docstring, if any.
+ if (isinstance(first_statement, gast.Expr) and
+ isinstance(first_statement.value, gast.Str)):
+ indented_body = indented_body[1:]
+ final_body.append(first_statement)
template = """
- with tf.name_scope(scope_name):
+ with ag__.function_scope(scope_name):
body
"""
scoped_body = templates.replace(
template,
scope_name=gast.Str(self._name_for_current_scope()),
- body=scoped_body)
- node.body = unscoped_body + scoped_body
+ body=indented_body)
+ final_body.extend(scoped_body)
+ node.body = final_body
return node
def transform(node, ctx):
- return FunctionNameScopeTransformer(ctx).visit(node)
+ return FunctionBodyTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py
index 73933c1c4f..e5ce03a109 100644
--- a/tensorflow/python/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/function_scopes_test.py
@@ -12,51 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for for_canonicalization module."""
+"""Tests for function_scopes module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class FunctionNameScopeTransformer(converter_testing.TestCase):
+class FunctionBodyTransformerTest(converter_testing.TestCase):
def test_basic(self):
def test_fn(l):
- """This should stay here."""
+ """Docstring."""
a = 1
l += a
return l
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
- self.assertEqual('This should stay here.', result.test_fn.__doc__)
+ self.assertEqual('Docstring.', result.test_fn.__doc__)
- def test_long_docstring(self):
+ def test_multiline_docstring(self):
- def test_fn(l):
- """Multi-line docstring.
+ tf = None
+
+ def test_fn():
+ """First sentence.
- Args:
- l: A thing.
- Returns:
- l
+ Second sentence.
"""
- return l + 1
+ return tf.constant(1)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant(1))
+ with self.converted(test_fn, function_scopes, {},
+ constant_op.constant) as result:
+ result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
- self.assertIn('Multi-line docstring.', result.test_fn.__doc__)
- self.assertIn('Returns:', result.test_fn.__doc__)
+ self.assertIn('First sentence.', result.test_fn.__doc__)
+ self.assertIn('Second sentence.', result.test_fn.__doc__)
def test_nested_functions(self):
@@ -68,7 +68,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
@@ -88,7 +88,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns, owner_type=TestClass)
- node = name_scopes.transform(node, ctx)
+ node = function_scopes.transform(node, ctx)
with self.compiled(node, {}, ops.name_scope) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
diff --git a/tensorflow/python/autograph/converters/return_statements.py b/tensorflow/python/autograph/converters/return_statements.py
index 62da045d6a..496c99e3b5 100644
--- a/tensorflow/python/autograph/converters/return_statements.py
+++ b/tensorflow/python/autograph/converters/return_statements.py
@@ -212,6 +212,7 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInUnsupportedControlFlow, self).__init__()
def visit_While(self, node):
@@ -229,6 +230,12 @@ class DetectReturnInUnsupportedControlFlow(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
@@ -242,6 +249,7 @@ class DetectReturnInConditional(gast.NodeVisitor):
def __init__(self):
self.cant_return = False
+ self.function_level = 0
super(DetectReturnInConditional, self).__init__()
def visit_If(self, node):
@@ -249,6 +257,12 @@ class DetectReturnInConditional(gast.NodeVisitor):
self.generic_visit(node)
self.cant_return = False
+ def visit_FunctionDef(self, node):
+ if not self.function_level:
+ self.function_level += 1
+ self.generic_visit(node)
+ self.function_level -= 1
+
def visit_Return(self, node):
if self.cant_return:
raise ValueError(
diff --git a/tensorflow/python/autograph/converters/return_statements_test.py b/tensorflow/python/autograph/converters/return_statements_test.py
index 01dd03da0b..762fbc6f60 100644
--- a/tensorflow/python/autograph/converters/return_statements_test.py
+++ b/tensorflow/python/autograph/converters/return_statements_test.py
@@ -151,6 +151,18 @@ class SingleReturnTest(converter_testing.TestCase):
self.assertTransformedEquivalent(test_fn, 2)
self.assertTransformedEquivalent(test_fn, -2)
+ def test_nested_functions_in_control_flow(self):
+
+ def test_fn(x):
+
+ if x:
+ def inner_fn(y):
+ return y
+ inner_fn(x)
+
+ self.assertTransformedEquivalent(test_fn, 2)
+ self.assertTransformedEquivalent(test_fn, -2)
+
def test_loop(self):
def test_fn(x):
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 85fecf084d..3ab2e7b1bc 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -20,17 +20,48 @@ py_library(
"config.py",
"converter.py",
"errors.py",
+ "function_wrapping.py",
"naming.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils",
],
)
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_testing.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "converter_test",
+ srcs = ["converter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":test_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "errors_test",
srcs = ["errors_test.py"],
@@ -47,8 +78,8 @@ py_test(
)
py_test(
- name = "naming_test",
- srcs = ["naming_test.py"],
+ name = "function_wrapping_test",
+ srcs = ["function_wrapping_test.py"],
srcs_version = "PY2AND3",
deps = [
":core",
@@ -56,20 +87,12 @@ py_test(
],
)
-py_library(
- name = "test_lib",
- srcs = [
- "converter_testing.py",
- ],
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
deps = [
":core",
- "//tensorflow/python/autograph/operators",
- "//tensorflow/python/autograph/pyct",
- "//tensorflow/python/autograph/pyct/static_analysis",
- "//tensorflow/python/autograph/utils",
- "@gast_archive//:gast",
- "@six_archive//:six",
+ "//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 80928ae7f4..408a573ad0 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -210,14 +210,22 @@ class Base(transformer.Base):
self._ast_depth = 0
def get_definition_directive(self, node, directive, arg, default):
- """Returns the unique directive for a symbol, or a default if none exist.
+ """Returns the unique directive argument for a symbol.
See lang/directives.py for details on directives.
+ Example:
+ # Given a directive in the code:
+ ag.foo_directive(bar, baz=1)
+
+ # One can write for an AST node Name(id='bar'):
+ get_definition_directive(node, ag.foo_directive, 'baz')
+
Args:
- node: ast.AST
- directive: Callable[..., Any]
- arg: str
+ node: ast.AST, the node representing the symbol for which the directive
+ argument is needed.
+ directive: Callable[..., Any], the directive to search.
+ arg: str, the directive argument to return.
default: Any
Raises:
@@ -227,27 +235,28 @@ class Base(transformer.Base):
if not defs:
return default
- # TODO(mdan): Simplify this.
- arg_values = []
+ arg_values_found = []
for def_ in defs:
- if (directive not in def_.directives or
- arg not in def_.directives[directive]):
- continue
- arg_value = def_.directives[directive][arg]
- for prev_value in arg_values:
- if not ast_util.matches(arg_value, prev_value):
- qn = anno.getanno(node, anno.Basic.QN)
- raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
- (qn, directive.__name__, arg,
- compiler.ast_to_source(arg_value).strip(),
- compiler.ast_to_source(prev_value).strip()))
- arg_values.append(arg_value)
-
- if not arg_values:
+ if (directive in def_.directives and arg in def_.directives[directive]):
+ arg_values_found.append(def_.directives[directive][arg])
+
+ if not arg_values_found:
return default
- arg_value, = arg_values
- return arg_value
+ if len(arg_values_found) == 1:
+ return arg_values_found[0]
+
+ # If multiple annotations reach the symbol, they must all match. If they do,
+ # return any of them.
+ first_value = arg_values_found[0]
+ for other_value in arg_values_found[1:]:
+ if not ast_util.matches(first_value, other_value):
+ qn = anno.getanno(node, anno.Basic.QN)
+ raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
+ (qn, directive.__name__, arg,
+ compiler.ast_to_source(other_value).strip(),
+ compiler.ast_to_source(first_value).strip()))
+ return first_value
def visit(self, node):
if not self._ast_depth:
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py
new file mode 100644
index 0000000000..b73c67e337
--- /dev/null
+++ b/tensorflow/python/autograph/core/converter_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lists module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class TestConverter(converter.Base):
+ pass
+
+
+class ConverterBaseTest(converter_testing.TestCase):
+
+ def test_get_definition_directive_basic(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs.directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_default(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ parser.parse_expression('default'))
+ self.assertEqual(value.id, 'default')
+
+ def test_get_definition_directive_multiple_consistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('baz'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_multiple_inconsistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ with self.assertRaises(ValueError):
+ c.get_definition_directive(symbol_a, directive_key, 'test_arg', None)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 7ce1b7c4c5..dc2d419d34 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -29,6 +29,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
@@ -112,6 +113,7 @@ class TestCase(test.TestCase):
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
errors.rewrite_graph_construction_error)
+ fake_ag.__dict__['function_scope'] = function_wrapping.function_scope
result.__dict__['ag__'] = fake_ag
for k, v in namespace.items():
result.__dict__[k] = v
diff --git a/tensorflow/python/autograph/core/function_wrapping.py b/tensorflow/python/autograph/core/function_wrapping.py
new file mode 100644
index 0000000000..21b66eff02
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for wrapping converted functions bodies with auxiliary logic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.framework import ops
+
+
+@contextlib.contextmanager
+def function_scope(function_name):
+ """Returns a context manager for the converted body of a function."""
+ with ops.name_scope(function_name):
+ yield
diff --git a/tensorflow/python/autograph/core/function_wrapping_test.py b/tensorflow/python/autograph/core/function_wrapping_test.py
new file mode 100644
index 0000000000..5e217055c7
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping_test.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for function_wrapping module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import function_wrapping
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
+class FunctionWrappingTest(test.TestCase):
+
+ def test_function_scope_name(self):
+ with function_wrapping.function_scope('test_name'):
+ t = constant_op.constant(1)
+ self.assertIn('test_name', t.name)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index a0d13c82a8..52abd40626 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -34,15 +34,16 @@ from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import decorators
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
-from tensorflow.python.autograph.converters import name_scopes
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import side_effect_guards
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
@@ -257,6 +258,7 @@ def _add_self_references(namespace, autograph_module):
ag_internal.converted_call = autograph_module.converted_call
ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
+ ag_internal.function_scope = function_wrapping.function_scope
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)
# TODO(mdan): Add safeguards against name clashes.
@@ -346,7 +348,7 @@ def node_to_graph(node, context, rewrite_errors=True):
node = converter.apply_(node, context, conditional_expressions)
node = converter.apply_(node, context, logical_expressions)
node = converter.apply_(node, context, side_effect_guards)
- node = converter.apply_(node, context, name_scopes)
+ node = converter.apply_(node, context, function_scopes)
if rewrite_errors:
node = converter.apply_(node, context, error_handlers)
return node
diff --git a/tensorflow/python/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
index e4838d1b6d..62ac018ac4 100644
--- a/tensorflow/python/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -24,6 +24,26 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.framework import tensor_util
+
+
+def _validate_list_constructor(elements, element_dtype, element_shape):
+ """Validates the inputs of tensor_list."""
+ if element_dtype is not None and element_shape is not None:
+ return
+ if tensor_util.is_tensor(elements):
+ return
+ if isinstance(elements, (list, tuple)):
+ if elements:
+ return
+ else:
+ raise ValueError(
+ 'element_dtype and element_shape are required when elements are'
+ ' empty')
+
+ raise ValueError(
+ 'unknown type for elements: {}; only Tensor, list and tuple are'
+ ' allowed'.format(type(elements)))
def tensor_list(elements,
@@ -52,9 +72,7 @@ def tensor_list(elements,
Raises:
ValueError: for invalid arguments
"""
- if not (elements or (element_dtype and element_shape)):
- raise ValueError(
- 'element_dtype and element_shape are required for empty lists')
+ _validate_list_constructor(elements, element_dtype, element_shape)
if use_tensor_array:
return data_structures.tf_tensor_array_new(elements, element_dtype,
element_shape)
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index 545dd11729..206a32d07c 100644
--- a/tensorflow/python/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,12 +30,43 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
+ def test_tensor_list_empty_list(self):
+ l = special_functions.tensor_list([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ l = special_functions.tensor_list((),
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_tensor(self):
+ l = special_functions.tensor_list(
+ constant_op.constant([], dtype=dtypes.int32))
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_unsupported_initializer(self):
+ with self.assertRaisesRegexp(ValueError, 'unknown type'):
+ special_functions.tensor_list(np.array([1, 2, 3]))
+
+ def test_tensor_list_empty_list_no_type(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'element_dtype and element_shape are required'):
+ special_functions.tensor_list([])
+
def test_tensor_list_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
diff --git a/tensorflow/python/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
index cc0a3c3544..b3a3851333 100644
--- a/tensorflow/python/autograph/operators/data_structures.py
+++ b/tensorflow/python/autograph/operators/data_structures.py
@@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
"""Overload of new_list that stages a Tensor list creation."""
+ if tensor_util.is_tensor(elements):
+ if element_shape is not None:
+ raise ValueError(
+ 'element shape may not be specified when creating list from tensor')
+ element_shape = array_ops.shape(elements)[1:]
+ l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
+ return l
+
elements = tuple(ops.convert_to_tensor(el) for el in elements)
all_dtypes = set(el.dtype for el in elements)
@@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible dtype; specified: {}, inferred from {}: {}'.format(
element_dtype, elements, inferred_dtype))
- else:
+ elif all_dtypes:
# Heterogeneous lists are ok.
if element_dtype is not None:
raise ValueError(
'specified dtype {} is inconsistent with that of elements {}'.format(
element_dtype, elements))
inferred_dtype = dtypes.variant
+ else:
+ inferred_dtype = dtypes.variant
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
if len(all_shapes) == 1:
@@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible shape; specified: {}, inferred from {}: {}'.format(
element_shape, elements, inferred_shape))
- else:
+ elif all_shapes:
# Heterogeneous lists are ok.
if element_shape is not None:
raise ValueError(
'specified shape {} is inconsistent with that of elements {}'.format(
element_shape, elements))
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
+ else:
+ inferred_shape = constant_op.constant(-1) # unknown shape, by convention
if element_dtype is None:
element_dtype = inferred_dtype
if element_shape is None:
element_shape = inferred_shape
+ element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
l = list_ops.empty_tensor_list(
element_shape=element_shape, element_dtype=element_dtype)
for el in elements:
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index 8532dbe466..6039b07982 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -45,6 +45,20 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
+ def test_tf_tensor_list_new_empty(self):
+ l = data_structures.tf_tensor_list_new([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [])
+
+ def test_tf_tensor_list_new_from_tensor(self):
+ l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4.0])
@@ -56,9 +70,8 @@ class ListTest(test.TestCase):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_shape=(2,))
- with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
+ data_structures.tf_tensor_list_new(
+ constant_op.constant([1, 2, 3]), element_shape=[1])
def test_tf_tensor_array_new(self):
l = data_structures.tf_tensor_array_new([3, 4, 5])
@@ -141,6 +154,18 @@ class ListTest(test.TestCase):
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
+ def test_stack_tensor_list_empty(self):
+ l = list_ops.empty_tensor_list(
+ element_shape=-1,
+ element_dtype=dtypes.variant)
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=dtypes.int32, original_call=None)
+
+ # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
+ with self.assertRaises(ValueError):
+ data_structures.list_stack(l, opts)
+
def test_stack_fallback(self):
def dummy_function(l):
diff --git a/tensorflow/python/autograph/operators/py_builtins.py b/tensorflow/python/autograph/operators/py_builtins.py
index 91a2a22cc2..70e59272a9 100644
--- a/tensorflow/python/autograph/operators/py_builtins.py
+++ b/tensorflow/python/autograph/operators/py_builtins.py
@@ -228,5 +228,6 @@ BUILTIN_FUINCTIONS_MAP = {
'len': len_,
'print': print_,
'range': range_,
+ # TODO(mdan): This might make more sense as tf.data.range.
'xrange': range_,
}
diff --git a/tensorflow/python/autograph/pyct/inspect_utils.py b/tensorflow/python/autograph/pyct/inspect_utils.py
index eef74599a7..29c406c248 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils.py
@@ -30,10 +30,14 @@ from tensorflow.python.util import tf_inspect
def isbuiltin(f):
+ """Returns True if the argument is a built-in function."""
# Note these return false for isinstance(f, types.BuiltinFunctionType) so we
# need to specifically check for them.
if f in (range, int, float):
return True
+ if six.PY2:
+ if f in (xrange,):
+ return True
if isinstance(f, types.BuiltinFunctionType):
return True
if tf_inspect.isbuiltin(f):
@@ -63,6 +67,40 @@ def getnamespace(f):
return namespace
+def getqualifiedname(namespace, object_, max_depth=2):
+ """Returns the name by which a value can be referred to in a given namespace.
+
+ This function will recurse inside modules, but it will not search objects for
+ attributes. The recursion depth is controlled by max_depth.
+
+ Args:
+ namespace: Dict[str, Any], the namespace to search into.
+ object_: Any, the value to search.
+ max_depth: Optional[int], a limit to the recursion depth when searching
+ inside modules.
+ Returns: Union[str, None], the fully-qualified name that resolves to the value
+ o, or None if it couldn't be found.
+ """
+ for name, value in namespace.items():
+ # The value may be referenced by more than one symbol, case in which
+ # any symbol will be fine. If the program contains symbol aliases that
+ # change over time, this may capture a symbol that will later point to
+ # something else.
+ # TODO(mdan): Prefer the symbol that matches the value type name.
+ if object_ is value:
+ return name
+
+ # TODO(mdan): Use breadth-first search and avoid visiting modules twice.
+ if max_depth:
+ for name, value in namespace.items():
+ if tf_inspect.ismodule(value):
+ name_in_module = getqualifiedname(value.__dict__, object_,
+ max_depth - 1)
+ if name_in_module is not None:
+ return '{}.{}'.format(name, name_in_module)
+ return None
+
+
def _get_unbound_function(m):
# TODO(mdan): Figure out why six.get_unbound_function fails in some cases.
# The failure case is for tf.keras.Model.
diff --git a/tensorflow/python/autograph/pyct/inspect_utils_test.py b/tensorflow/python/autograph/pyct/inspect_utils_test.py
index f3eb027822..11074debfc 100644
--- a/tensorflow/python/autograph/pyct/inspect_utils_test.py
+++ b/tensorflow/python/autograph/pyct/inspect_utils_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from functools import wraps
+import imp
import six
@@ -127,6 +128,24 @@ class InspectUtilsTest(test.TestCase):
self.assertEqual(ns['closed_over_primitive'], closed_over_primitive)
self.assertTrue('local_var' not in ns)
+ def test_getqualifiedname(self):
+ foo = object()
+ qux = imp.new_module('quxmodule')
+ bar = imp.new_module('barmodule')
+ baz = object()
+ bar.baz = baz
+
+ ns = {
+ 'foo': foo,
+ 'bar': bar,
+ 'qux': qux,
+ }
+
+ self.assertIsNone(inspect_utils.getqualifiedname(ns, inspect_utils))
+ self.assertEqual(inspect_utils.getqualifiedname(ns, foo), 'foo')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, bar), 'bar')
+ self.assertEqual(inspect_utils.getqualifiedname(ns, baz), 'bar.baz')
+
def test_getmethodclass(self):
self.assertEqual(
diff --git a/tensorflow/python/autograph/pyct/static_analysis/live_values.py b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
index 36b9e7074d..4ceddce53b 100644
--- a/tensorflow/python/autograph/pyct/static_analysis/live_values.py
+++ b/tensorflow/python/autograph/pyct/static_analysis/live_values.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import gast
+import six
from tensorflow.python.autograph.pyct import anno
from tensorflow.python.autograph.pyct import transformer
@@ -35,6 +36,9 @@ from tensorflow.python.autograph.pyct.static_analysis.annos import NodeAnno
# These symbols are legal in Python, but don't appear in the namespace.
_SPECIAL_SYMBOLS = {'range': range, 'print': print}
+if six.PY2:
+ _SPECIAL_SYMBOLS['xrange'] = xrange
+
class LiveValueResolver(transformer.Base):
"""Annotates nodes with live values."""
diff --git a/tensorflow/python/autograph/pyct/templates.py b/tensorflow/python/autograph/pyct/templates.py
index 1bf0515745..1af8fca599 100644
--- a/tensorflow/python/autograph/pyct/templates.py
+++ b/tensorflow/python/autograph/pyct/templates.py
@@ -123,6 +123,8 @@ class ReplaceTransformer(gast.NodeTransformer):
self._check_inner_children_have_context(e)
for e in node.values:
self._check_inner_children_have_context(e)
+ elif isinstance(node, gast.Index):
+ self._check_inner_children_have_context(node.value)
elif isinstance(node, gast.Subscript):
self._check_inner_children_have_context(node.value)
self._check_inner_children_have_context(node.slice)
diff --git a/tensorflow/python/autograph/pyct/templates_test.py b/tensorflow/python/autograph/pyct/templates_test.py
index 078d9a149b..3032241846 100644
--- a/tensorflow/python/autograph/pyct/templates_test.py
+++ b/tensorflow/python/autograph/pyct/templates_test.py
@@ -158,6 +158,18 @@ class TemplatesTest(test.TestCase):
self.assertIsInstance(function_call_arg.elts[0].elts[0].ctx, gast.Load)
self.assertIsInstance(function_call_arg.elts[0].elts[1].ctx, gast.Load)
+ def test_replace_index(self):
+ template = """
+ def test_fn(foo):
+ foo = 0
+ """
+
+ node = templates.replace(
+ template, foo=parser.parse_expression('foo(a[b]).bar'))[0]
+ function_call_arg = node.body[0].targets[0].value.args[0]
+ self.assertIsInstance(function_call_arg.ctx, gast.Load)
+ self.assertIsInstance(function_call_arg.slice.value.ctx, gast.Load)
+
def test_replace_call_keyword(self):
template = """
def test_fn():
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index f576435136..347833ce8f 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -120,11 +120,17 @@ class SessionTest(test_util.TensorFlowTestCase):
inp = constant_op.constant(10.0, name='W1')
self.assertAllEqual(inp.eval(), 10.0)
- devices = sess.list_devices()
- self.assertEqual(2, len(devices))
- for device in devices:
- self.assertEqual('CPU', framework_device_lib.DeviceSpec.from_string(
- device.name).device_type)
+ num_cpu_devices = 0
+ num_gpu_devices = 0
+ for device in sess.list_devices():
+ device_type = framework_device_lib.DeviceSpec.from_string(
+ device.name).device_type
+ if device_type == 'CPU':
+ num_cpu_devices += 1
+ elif device_type == 'GPU':
+ num_gpu_devices += 1
+ self.assertEqual(2, num_cpu_devices)
+ self.assertEqual(0, num_gpu_devices)
def testPerSessionThreads(self):
with session.Session(
@@ -1022,7 +1028,7 @@ class SessionTest(test_util.TensorFlowTestCase):
with session.Session():
a = constant_op.constant(1.0, shape=[1, 2])
b = constant_op.constant(2.0, shape=[1, 2], name='b')
- v = variables.Variable(a, a.dtype)
+ v = variables.VariableV1(a, a.dtype)
assign_a_to_v = state_ops.assign(v, a)
assign_a_to_v.eval()
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index 74fe1fe35c..349c84e13c 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 25)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 8)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 138141f4fc..e32eeecbb8 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -10,6 +10,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index f8b561205e..7536ba668a 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.python.data import experimental
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
diff --git a/tensorflow/python/data/experimental/BUILD b/tensorflow/python/data/experimental/BUILD
new file mode 100644
index 0000000000..84e761d376
--- /dev/null
+++ b/tensorflow/python/data/experimental/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "experimental",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
new file mode 100644
index 0000000000..2ac159d38a
--- /dev/null
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -0,0 +1,109 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for building input pipelines.
+
+This module contains experimental `Dataset` sources and transformations that can
+be used in conjunction with the `tf.data.Dataset` API. Note that the
+`tf.data.experimental` API is not subject to the same backwards compatibility
+guarantees as `tf.data`, but we will provide deprecation advice in advance of
+removing existing functionality.
+
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
+
+@@Counter
+@@CheckpointInputPipelineHook
+@@CsvDataset
+@@Optional
+@@RandomDataset
+@@Reducer
+@@SqlDataset
+@@TFRecordWriter
+
+@@bucket_by_sequence_length
+@@choose_from_datasets
+@@copy_to_device
+@@dense_to_sparse_batch
+@@enumerate_dataset
+@@get_next_as_optional
+@@get_single_element
+@@group_by_reducer
+@@group_by_window
+@@ignore_errors
+@@latency_stats
+@@make_batched_features_dataset
+@@make_csv_dataset
+@@make_saveable_from_iterator
+@@map_and_batch
+@@parallel_interleave
+@@parse_example_dataset
+@@prefetch_to_device
+@@rejection_resample
+@@sample_from_datasets
+@@scan
+@@set_stats_aggregator
+@@shuffle_and_repeat
+@@StatsAggregator
+@@unbatch
+@@unique
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+
+from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
+from tensorflow.python.data.experimental.ops.batching import map_and_batch
+from tensorflow.python.data.experimental.ops.batching import unbatch
+from tensorflow.python.data.experimental.ops.counter import Counter
+from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
+from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
+from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
+from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
+from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
+from tensorflow.python.data.experimental.ops.grouping import group_by_window
+from tensorflow.python.data.experimental.ops.grouping import Reducer
+from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
+from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
+from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
+from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
+from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
+from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
+from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
+from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device
+from tensorflow.python.data.experimental.ops.random_ops import RandomDataset
+from tensorflow.python.data.experimental.ops.readers import CsvDataset
+from tensorflow.python.data.experimental.ops.readers import make_batched_features_dataset
+from tensorflow.python.data.experimental.ops.readers import make_csv_dataset
+from tensorflow.python.data.experimental.ops.readers import SqlDataset
+from tensorflow.python.data.experimental.ops.resampling import rejection_resample
+from tensorflow.python.data.experimental.ops.scan_ops import scan
+from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
+from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
+from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator
+from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator
+from tensorflow.python.data.experimental.ops.unique import unique
+from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
+# pylint: enable=unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/python/data/experimental/benchmarks/BUILD b/tensorflow/python/data/experimental/benchmarks/BUILD
new file mode 100644
index 0000000000..b9398aebe7
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/BUILD
@@ -0,0 +1,25 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "map_benchmark",
+ size = "medium",
+ srcs = ["map_benchmark.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/benchmarks/map_benchmark.py b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
new file mode 100644
index 0000000000..ad253cffa5
--- /dev/null
+++ b/tensorflow/python/data/experimental/benchmarks/map_benchmark.py
@@ -0,0 +1,245 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import hashlib
+import itertools
+import time
+
+import numpy as np
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+_NUMPY_RANDOM_SEED = 42
+
+
+class MapDatasetBenchmark(test.Benchmark):
+
+ # The purpose of this benchmark is to compare the performance of chaining vs
+ # fusing of the map and batch transformations across various configurations.
+ #
+ # NOTE: It is recommended to build the benchmark with
+ # `-c opt --copt=-mavx --copt=-mavx2 --copt=-mfma --copt=-gmlt`
+ # and execute it on a machine with at least 32 CPU cores.
+ def benchmarkMapAndBatch(self):
+
+ # Sequential pipeline configurations.
+ seq_elem_size_series = itertools.product([1], [1], [1, 2, 4, 8], [16])
+ seq_batch_size_series = itertools.product([1], [1], [1], [8, 16, 32, 64])
+
+ # Parallel pipeline configuration.
+ par_elem_size_series = itertools.product([32], [32], [1, 2, 4, 8], [256])
+ par_batch_size_series = itertools.product([32], [32], [1],
+ [128, 256, 512, 1024])
+ par_num_calls_series = itertools.product([8, 16, 32, 64], [32], [1], [512])
+ par_inter_op_series = itertools.product([32], [8, 16, 32, 64], [1], [512])
+
+ def name(method, label, num_calls, inter_op, element_size, batch_size):
+ return ("%s_id_%s_num_calls_%d_inter_op_%d_elem_size_%d_batch_size_%d" % (
+ method,
+ hashlib.sha1(label).hexdigest(),
+ num_calls,
+ inter_op,
+ element_size,
+ batch_size,
+ ))
+
+ def benchmark(label, series):
+
+ print("%s:" % label)
+ for num_calls, inter_op, element_size, batch_size in series:
+
+ num_iters = 1024 // (
+ (element_size * batch_size) // min(num_calls, inter_op))
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(
+ element_size, 4 * k), np.random.rand(4 * k, 1))).repeat()
+
+ chained_dataset = dataset.map(
+ math_ops.matmul,
+ num_parallel_calls=num_calls).batch(batch_size=batch_size)
+ chained_iterator = chained_dataset.make_one_shot_iterator()
+ chained_get_next = chained_iterator.get_next()
+
+ chained_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+ for _ in range(5):
+ sess.run(chained_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(chained_get_next.op)
+ end = time.time()
+ chained_deltas.append(end - start)
+
+ fused_dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=num_calls,
+ batch_size=batch_size))
+ fused_iterator = fused_dataset.make_one_shot_iterator()
+ fused_get_next = fused_iterator.get_next()
+
+ fused_deltas = []
+ with session.Session(
+ config=config_pb2.ConfigProto(
+ inter_op_parallelism_threads=inter_op,
+ use_per_session_threads=True)) as sess:
+
+ for _ in range(5):
+ sess.run(fused_get_next.op)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(fused_get_next.op)
+ end = time.time()
+ fused_deltas.append(end - start)
+
+ print(
+ "batch size: %d, num parallel calls: %d, inter-op parallelism: %d, "
+ "element size: %d, num iters: %d\nchained wall time: %f (median), "
+ "%f (mean), %f (stddev), %f (min), %f (max)\n fused wall time: "
+ "%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n "
+ "chained/fused: %.2fx (median), %.2fx (mean)" %
+ (batch_size, num_calls, inter_op, element_size, num_iters,
+ np.median(chained_deltas), np.mean(chained_deltas),
+ np.std(chained_deltas), np.min(chained_deltas),
+ np.max(chained_deltas), np.median(fused_deltas),
+ np.mean(fused_deltas), np.std(fused_deltas), np.min(fused_deltas),
+ np.max(fused_deltas),
+ np.median(chained_deltas) / np.median(fused_deltas),
+ np.mean(chained_deltas) / np.mean(fused_deltas)))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(chained_deltas),
+ name=name("chained", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ self.report_benchmark(
+ iters=num_iters,
+ wall_time=np.median(fused_deltas),
+ name=name("fused", label, num_calls, inter_op, element_size,
+ batch_size))
+
+ print("")
+
+ np.random.seed(_NUMPY_RANDOM_SEED)
+ benchmark("Sequential element size evaluation", seq_elem_size_series)
+ benchmark("Sequential batch size evaluation", seq_batch_size_series)
+ benchmark("Parallel element size evaluation", par_elem_size_series)
+ benchmark("Parallel batch size evaluation", par_batch_size_series)
+ benchmark("Transformation parallelism evaluation", par_num_calls_series)
+ benchmark("Threadpool size evaluation", par_inter_op_series)
+
+ # This benchmark compares the performance of pipeline with multiple chained
+ # maps with and without map fusion.
+ def benchmarkChainOfMaps(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkChainOfMaps(chain_length, False)
+ self._benchmarkChainOfMaps(chain_length, True)
+
+ def _benchmarkChainOfMaps(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x)
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["map_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(5):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Map dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+class MapAndFilterBenchmark(test.Benchmark):
+
+ # This benchmark compares the performance of pipeline with multiple chained
+ # map + filter with and without map fusion.
+ def benchmarkMapAndFilter(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkMapAndFilter(chain_length, False)
+ self._benchmarkMapAndFilter(chain_length, True)
+
+ def _benchmarkMapAndFilter(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.map(lambda x: x + 5).filter(
+ lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(
+ optimization.optimize(["map_and_filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Map and filter dataset {} chain length: {} Median wall time: {}".
+ format(opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_map_and_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
new file mode 100644
index 0000000000..4eef9580ad
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -0,0 +1,687 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "bucket_by_sequence_length_test",
+ size = "medium",
+ srcs = ["bucket_by_sequence_length_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "copy_to_device_test",
+ size = "small",
+ srcs = ["copy_to_device_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
+ name = "counter_test",
+ size = "small",
+ srcs = ["counter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:counter",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_test",
+ size = "medium",
+ srcs = ["csv_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "dense_to_sparse_batch_test",
+ srcs = ["dense_to_sparse_batch_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "enumerate_dataset_test",
+ size = "small",
+ srcs = ["enumerate_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "function_buffering_resource_test",
+ size = "small",
+ srcs = ["function_buffering_resource_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_test(
+ name = "get_single_element_test",
+ size = "small",
+ srcs = ["get_single_element_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "group_by_reducer_test",
+ size = "medium",
+ srcs = ["group_by_reducer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "group_by_window_test",
+ size = "medium",
+ srcs = ["group_by_window_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "ignore_errors_test",
+ srcs = ["ignore_errors_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "make_batched_features_dataset_test",
+ size = "medium",
+ srcs = ["make_batched_features_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "make_csv_dataset_test",
+ size = "medium",
+ srcs = ["make_csv_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "make_tf_record_dataset_test",
+ size = "medium",
+ srcs = ["make_tf_record_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
+ name = "map_and_batch_test",
+ size = "medium",
+ srcs = ["map_and_batch_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_defun_op_test",
+ size = "small",
+ srcs = ["map_defun_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:map_defun",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "override_threadpool_test",
+ size = "small",
+ srcs = ["override_threadpool_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_test",
+ size = "medium",
+ srcs = ["parallel_interleave_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "parse_example_dataset_test",
+ size = "small",
+ srcs = ["parse_example_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetch_to_device_test",
+ size = "small",
+ srcs = ["prefetch_to_device_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+ tags = ["no_windows_gpu"],
+)
+
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "rejection_resample_test",
+ size = "medium",
+ srcs = ["rejection_resample_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "noasan",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:resampling",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "restructured_dataset_test",
+ size = "medium",
+ srcs = ["restructured_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
+ name = "scan_test",
+ size = "small",
+ srcs = ["scan_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "shuffle_and_repeat_test",
+ size = "medium",
+ srcs = ["shuffle_and_repeat_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_pip",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "sql_dataset_test_base",
+ srcs = ["sql_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "@org_sqlite//:python",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_test",
+ size = "small",
+ srcs = ["sql_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":sql_dataset_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_ops_test",
+ size = "medium",
+ srcs = ["stats_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "tf_record_writer_test",
+ size = "small",
+ srcs = ["tf_record_writer_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "unbatch_test",
+ size = "medium",
+ srcs = ["unbatch_test.py"],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "unique_test",
+ size = "small",
+ srcs = ["unique_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
new file mode 100644
index 0000000000..3903ec49b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/bucket_by_sequence_length_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.bucket_by_sequence_length()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import random
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+def _element_length_fn(x, y=None):
+ del y
+ return array_ops.shape(x)[0]
+
+
+def _to_sparse_tensor(record):
+ return sparse_tensor.SparseTensor(**record)
+
+
+def _format_record(array, sparse):
+ if sparse:
+ return {
+ "values": array,
+ "indices": [[i] for i in range(len(array))],
+ "dense_shape": (len(array),)
+ }
+ return array
+
+
+def _get_record_type(sparse):
+ if sparse:
+ return {
+ "values": dtypes.int64,
+ "indices": dtypes.int64,
+ "dense_shape": dtypes.int64
+ }
+ return dtypes.int32
+
+
+def _get_record_shape(sparse):
+ if sparse:
+ return {
+ "values": tensor_shape.TensorShape([None,]),
+ "indices": tensor_shape.TensorShape([None, 1]),
+ "dense_shape": tensor_shape.TensorShape([1,])
+ }
+ return tensor_shape.TensorShape([None])
+
+
+class BucketBySequenceLengthTest(test_base.DatasetTestBase):
+
+ def testBucket(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25, 35]
+
+ def build_dataset(sparse):
+ def _generator():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes, lengths):
+ record_len = length - 1
+ for _ in range(batch_size):
+ elements.append([1] * record_len)
+ record_len = length
+ random.shuffle(elements)
+ for el in elements:
+ yield (_format_record(el, sparse),)
+ dataset = dataset_ops.Dataset.from_generator(
+ _generator,
+ (_get_record_type(sparse),),
+ (_get_record_shape(sparse),))
+ if sparse:
+ dataset = dataset.map(lambda x: (_to_sparse_tensor(x),))
+ return dataset
+
+ def _test_bucket_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(
+ grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ batch_sizes,
+ no_padding=no_padding))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(4):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ shape = batch.dense_shape if no_padding else batch.shape
+ batch_size = shape[0]
+ length = shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ sum_check = batch.values.sum() if no_padding else batch.sum()
+ self.assertEqual(sum_check, batch_size * length - 1)
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual(sorted(lengths), sorted(lengths_val))
+
+ for no_padding in (True, False):
+ _test_bucket_by_padding(no_padding)
+
+ def testPadToBoundary(self):
+
+ boundaries = [10, 20, 30]
+ batch_sizes = [10, 8, 4, 2]
+ lengths = [8, 13, 25]
+
+ def element_gen():
+ # Produce 1 batch for each bucket
+ elements = []
+ for batch_size, length in zip(batch_sizes[:-1], lengths):
+ for _ in range(batch_size):
+ elements.append([1] * length)
+ random.shuffle(elements)
+ for el in elements:
+ yield (el,)
+ for _ in range(batch_sizes[-1]):
+ el = [1] * (boundaries[-1] + 5)
+ yield (el,)
+
+ element_len = lambda el: array_ops.shape(el)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(3):
+ batches.append(sess.run(batch))
+ with self.assertRaisesOpError("bucket_boundaries"):
+ sess.run(batch)
+ batch_sizes_val = []
+ lengths_val = []
+ for batch in batches:
+ batch_size = batch.shape[0]
+ length = batch.shape[1]
+ batch_sizes_val.append(batch_size)
+ lengths_val.append(length)
+ batch_sizes = batch_sizes[:-1]
+ self.assertEqual(sum(batch_sizes_val), sum(batch_sizes))
+ self.assertEqual(sorted(batch_sizes), sorted(batch_sizes_val))
+ self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
+ sorted(lengths_val))
+
+ def testPadToBoundaryNoExtraneousPadding(self):
+
+ boundaries = [3, 7, 11]
+ batch_sizes = [2, 2, 2, 2]
+ lengths = range(1, 11)
+
+ def element_gen():
+ for length in lengths:
+ yield ([1] * length,)
+
+ element_len = lambda element: array_ops.shape(element)[0]
+ dataset = dataset_ops.Dataset.from_generator(
+ element_gen, (dtypes.int64,), ([None],)).apply(
+ grouping.bucket_by_sequence_length(
+ element_len, boundaries, batch_sizes,
+ pad_to_bucket_boundary=True))
+ batch, = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ batches = []
+ for _ in range(5):
+ batches.append(sess.run(batch))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(batch)
+
+ self.assertAllEqual(batches[0], [[1, 0],
+ [1, 1]])
+ self.assertAllEqual(batches[1], [[1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[2], [[1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1]])
+ self.assertAllEqual(batches[3], [[1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 0, 0]])
+ self.assertAllEqual(batches[4], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
+ [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])
+
+ def testTupleElements(self):
+
+ def build_dataset(sparse):
+ def _generator():
+ text = [[1, 2, 3], [3, 4, 5, 6, 7], [1, 2], [8, 9, 0, 2, 3]]
+ label = [1, 2, 1, 2]
+ for x, y in zip(text, label):
+ yield (_format_record(x, sparse), y)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=_generator,
+ output_types=(_get_record_type(sparse), dtypes.int32),
+ output_shapes=(_get_record_shape(sparse),
+ tensor_shape.TensorShape([])))
+ if sparse:
+ dataset = dataset.map(lambda x, y: (_to_sparse_tensor(x), y))
+ return dataset
+
+ def _test_tuple_elements_by_padding(no_padding):
+ dataset = build_dataset(sparse=no_padding)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ element_length_func=_element_length_fn,
+ bucket_batch_sizes=[2, 2, 2],
+ bucket_boundaries=[0, 8],
+ no_padding=no_padding))
+ shapes = dataset.output_shapes
+ self.assertEqual([None, None], shapes[0].as_list())
+ self.assertEqual([None], shapes[1].as_list())
+
+ for no_padding in (True, False):
+ _test_tuple_elements_by_padding(no_padding)
+
+ def testBucketSparse(self):
+ """Tests bucketing of sparse tensors (case where `no_padding` == True).
+
+ Test runs on following dataset:
+ [
+ [0],
+ [0, 1],
+ [0, 1, 2]
+ ...
+ [0, ..., max_len - 1]
+ ]
+ Sequences are bucketed by length and batched with
+ `batch_size` < `bucket_size`.
+ """
+
+ min_len = 0
+ max_len = 100
+ batch_size = 7
+ bucket_size = 10
+
+ def _build_dataset():
+ input_data = [range(i+1) for i in range(min_len, max_len)]
+ def generator_fn():
+ for record in input_data:
+ yield _format_record(record, sparse=True)
+ dataset = dataset_ops.Dataset.from_generator(
+ generator=generator_fn,
+ output_types=_get_record_type(sparse=True))
+ dataset = dataset.map(_to_sparse_tensor)
+ return dataset
+
+ def _compute_expected_batches():
+ """Computes expected batch outputs and stores in a set."""
+ all_expected_sparse_tensors = set()
+ for bucket_start_len in range(min_len, max_len, bucket_size):
+ for batch_offset in range(0, bucket_size, batch_size):
+ batch_start_len = bucket_start_len + batch_offset
+ batch_end_len = min(batch_start_len + batch_size,
+ bucket_start_len + bucket_size)
+ expected_indices = []
+ expected_values = []
+ for length in range(batch_start_len, batch_end_len):
+ for val in range(length + 1):
+ expected_indices.append((length - batch_start_len, val))
+ expected_values.append(val)
+ expected_sprs_tensor = (tuple(expected_indices),
+ tuple(expected_values))
+ all_expected_sparse_tensors.add(expected_sprs_tensor)
+ return all_expected_sparse_tensors
+
+ def _compute_batches(dataset):
+ """Computes actual batch outputs of dataset and stores in a set."""
+ batch = dataset.make_one_shot_iterator().get_next()
+ all_sparse_tensors = set()
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ output = sess.run(batch)
+ sprs_tensor = (tuple([tuple(idx) for idx in output.indices]),
+ tuple(output.values))
+ all_sparse_tensors.add(sprs_tensor)
+ return all_sparse_tensors
+
+ dataset = _build_dataset()
+ boundaries = range(min_len + bucket_size + 1, max_len, bucket_size)
+ dataset = dataset.apply(grouping.bucket_by_sequence_length(
+ _element_length_fn,
+ boundaries,
+ [batch_size] * (len(boundaries) + 1),
+ no_padding=True))
+ batches = _compute_batches(dataset)
+ expected_batches = _compute_expected_batches()
+ self.assertEqual(batches, expected_batches)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
new file mode 100644
index 0000000000..adfacf1c9f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/copy_to_device_test.py
@@ -0,0 +1,533 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.copy_to_device()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.compat import compat
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class CopyToDeviceTest(test_base.DatasetTestBase):
+
+ def testCopyToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceInt32(self):
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int32, next_element.dtype)
+ self.assertEqual((4,), next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyDictToDeviceWithPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDevice(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopySparseTensorsToDeviceWithPrefetch(self):
+
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[2, 2])
+
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuInt32AndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors([0, 1, 2, 3])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([0, 1, 2, 3], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStrings(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuStringsAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.from_tensors(["a", "b", "c"])
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertAllEqual([b"a", b"b", b"c"], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDevicePingPongCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ with compat.forward_compatibility_horizon(2018, 8, 4):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0", source_device="/cpu:0"))
+ back_to_cpu_dataset = device_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:0", source_device="/gpu:0"))
+
+ with ops.device("/cpu:0"):
+ iterator = back_to_cpu_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1"))
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceWithReInitAndPrefetch(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/cpu:1")).prefetch(1)
+
+ with ops.device("/cpu:1"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testCopyToDeviceGpuWithReInitAndPrefetch(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0")).prefetch(1)
+
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testIteratorGetNextAsOptionalOnGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(3)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.copy_to_device("/gpu:0"))
+ with ops.device("/gpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+ next_elem = iterator_ops.get_next_as_optional(iterator)
+ elem_has_value_t = next_elem.has_value()
+ elem_value_t = next_elem.get_value()
+
+ with self.cached_session() as sess:
+ # Before initializing the iterator, evaluating the optional fails with
+ # a FailedPreconditionError.
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_has_value_t)
+ with self.assertRaises(errors.FailedPreconditionError):
+ sess.run(elem_value_t)
+
+ # For each element of the dataset, assert that the optional evaluates to
+ # the expected value.
+ sess.run(iterator.initializer)
+ for i in range(3):
+ elem_has_value, elem_value = sess.run([elem_has_value_t, elem_value_t])
+ self.assertTrue(elem_has_value)
+ self.assertEqual(i, elem_value)
+
+ # After exhausting the iterator, `next_elem.has_value()` will evaluate to
+ # false, and attempting to get the value will fail.
+ for _ in range(2):
+ self.assertFalse(sess.run(elem_has_value_t))
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(elem_value_t)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/counter_test.py b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
new file mode 100644
index 0000000000..4e114ac479
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/counter_test.py
@@ -0,0 +1,51 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.Counter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.platform import test
+
+
+class CounterTest(test_base.DatasetTestBase):
+
+ def testCounter(self):
+ """Test dataset construction using `count`."""
+ iterator = (counter.Counter(start=3, step=4)
+ .make_one_shot_iterator())
+ get_next = iterator.get_next()
+ self.assertEqual([], get_next.shape.as_list())
+ self.assertEqual(dtypes.int64, get_next.dtype)
+
+ negative_iterator = (counter.Counter(start=0, step=-1)
+ .make_one_shot_iterator())
+ negative_get_next = negative_iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(3, sess.run(get_next))
+ self.assertEqual(3 + 4, sess.run(get_next))
+ self.assertEqual(3 + 2 * 4, sess.run(get_next))
+
+ self.assertEqual(0, sess.run(negative_get_next))
+ self.assertEqual(-1, sess.run(negative_get_next))
+ self.assertEqual(-2, sess.run(negative_get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
new file mode 100644
index 0000000000..fb75be1fbc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_test.py
@@ -0,0 +1,632 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.CsvDataset`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import string
+import tempfile
+import time
+import zlib
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import googletest
+from tensorflow.python.platform import test
+
+
+@test_util.run_all_in_graph_and_eager_modes
+class CsvDatasetTest(test_base.DatasetTestBase):
+
+ def _setup_files(self, inputs, linebreak='\n', compression_type=None):
+ filenames = []
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), 'temp_%d.csv' % i)
+ contents = linebreak.join(ip).encode('utf-8')
+ if compression_type is None:
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'GZIP':
+ with gzip.GzipFile(fn, 'wb') as f:
+ f.write(contents)
+ elif compression_type == 'ZLIB':
+ contents = zlib.compress(contents)
+ with open(fn, 'wb') as f:
+ f.write(contents)
+ else:
+ raise ValueError('Unsupported compression_type', compression_type)
+ filenames.append(fn)
+ return filenames
+
+ def _make_test_datasets(self, inputs, **kwargs):
+ # Test by comparing its output to what we could get with map->decode_csv
+ filenames = self._setup_files(inputs)
+ dataset_expected = core_readers.TextLineDataset(filenames)
+ dataset_expected = dataset_expected.map(
+ lambda l: parsing_ops.decode_csv(l, **kwargs))
+ dataset_actual = readers.CsvDataset(filenames, **kwargs)
+ return (dataset_actual, dataset_expected)
+
+ def _test_by_comparison(self, inputs, **kwargs):
+ """Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
+ dataset_actual, dataset_expected = self._make_test_datasets(
+ inputs, **kwargs)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
+
+ def _verify_output_or_err(self,
+ dataset,
+ expected_output=None,
+ expected_err_re=None):
+ if expected_err_re is None:
+ # Verify that output is expected, without errors
+ nxt = self.getNext(dataset)
+ expected_output = [[
+ v.encode('utf-8') if isinstance(v, str) else v for v in op
+ ] for op in expected_output]
+ for value in expected_output:
+ op = self.evaluate(nxt())
+ self.assertAllEqual(op, value)
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(nxt())
+ else:
+ # Verify that OpError is produced as expected
+ with self.assertRaisesOpError(expected_err_re):
+ nxt = self.getNext(dataset)
+ while True:
+ try:
+ self.evaluate(nxt())
+ except errors.OutOfRangeError:
+ break
+
+ def _test_dataset(
+ self,
+ inputs,
+ expected_output=None,
+ expected_err_re=None,
+ linebreak='\n',
+ compression_type=None, # Used for both setup and parsing
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Convert str type because py3 tf strings are bytestrings
+ filenames = self._setup_files(inputs, linebreak, compression_type)
+ kwargs['compression_type'] = compression_type
+ dataset = readers.CsvDataset(filenames, **kwargs)
+ self._verify_output_or_err(dataset, expected_output, expected_err_re)
+
+ def testCsvDataset_requiredFields(self):
+ record_defaults = [[]] * 4
+ inputs = [['1,2,3,4']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_int(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_float(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['1.0,2.1,3.2,4.3', '5.4,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_string(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,hello,4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withEmptyFields(self):
+ record_defaults = [[0]] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Unquoted fields cannot have quotes inside',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['"a"b","c","d"']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Quote inside a string has to be escaped by another quote',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_ignoreErrWithUnescapedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,"2"3",4', '1,"2"3",4",5,5', 'a,b,"c"d"', 'e,f,g']]
+ filenames = self._setup_files(inputs)
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_ignoreErrWithUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4', 'a,b,c"d', '9,8"7,6,5', 'e,f,g']]
+ filenames = self._setup_files(inputs)
+ dataset = readers.CsvDataset(filenames, record_defaults=record_defaults)
+ dataset = dataset.apply(error_ops.ignore_errors())
+ self._verify_output_or_err(dataset, [['e', 'f', 'g']])
+
+ def testCsvDataset_withNoQuoteDelimAndUnquotedQuotes(self):
+ record_defaults = [['']] * 3
+ inputs = [['1,2"3,4']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
+
+ def testCsvDataset_mixedTypes(self):
+ record_defaults = [
+ constant_op.constant([], dtype=dtypes.int32),
+ constant_op.constant([], dtype=dtypes.float32),
+ constant_op.constant([], dtype=dtypes.string),
+ constant_op.constant([], dtype=dtypes.float64)
+ ]
+ inputs = [['1,2.1,3.2,4.3', '5,6.5,7.6,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withUseQuoteDelimFalse(self):
+ record_defaults = [['']] * 4
+ inputs = [['1,2,"3,4"', '"5,6",7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, use_quote_delim=False)
+
+ def testCsvDataset_withFieldDelim(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1:2:3:4', '5:6:7:8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, field_delim=':')
+
+ def testCsvDataset_withNaValue(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,NA,3,4', 'NA,6,7,8']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, na_value='NA')
+
+ def testCsvDataset_withSelectCols(self):
+ record_defaults = [['']] * 2
+ inputs = [['1,2,3,4', '"5","6","7","8"']]
+ self._test_by_comparison(
+ inputs, record_defaults=record_defaults, select_cols=[1, 2])
+
+ def testCsvDataset_withSelectColsTooHigh(self):
+ record_defaults = [[0]] * 2
+ inputs = [['1,2,3,4', '5,6,7,8']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 1 in record',
+ record_defaults=record_defaults,
+ select_cols=[3, 4])
+
+ def testCsvDataset_withOneCol(self):
+ record_defaults = [['NA']]
+ inputs = [['0', '', '2']]
+ self._test_dataset(
+ inputs, [['0'], ['NA'], ['2']], record_defaults=record_defaults)
+
+ def testCsvDataset_withMultipleFiles(self):
+ record_defaults = [[0]] * 4
+ inputs = [['1,2,3,4', '5,6,7,8'], ['5,6,7,8']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withLeadingAndTrailingSpaces(self):
+ record_defaults = [[0.0]] * 4
+ inputs = [['0, 1, 2, 3']]
+ expected = [[0.0, 1.0, 2.0, 3.0]]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMissingDefault(self):
+ record_defaults = [[]] * 2
+ inputs = [['0,']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Field 1 is required but missing in record!',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithFewerDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 2
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have more in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithMoreDefaultsThanFields(self):
+ record_defaults = [[0.0]] * 5
+ inputs = [['0,1,2,3']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 5 fields but have 4 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withHeader(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2', '1,2']]
+ expected = [[1, 2]]
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withHeaderAndNoRecords(self):
+ record_defaults = [[0]] * 2
+ inputs = [['col1,col2']]
+ expected = []
+ self._test_dataset(
+ inputs,
+ expected,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_errorWithHeaderEmptyFile(self):
+ record_defaults = [[0]] * 2
+ inputs = [[]]
+ expected_err_re = "Can't read header of file"
+ self._test_dataset(
+ inputs,
+ expected_err_re=expected_err_re,
+ record_defaults=record_defaults,
+ header=True,
+ )
+
+ def testCsvDataset_withEmptyFile(self):
+ record_defaults = [['']] * 2
+ inputs = [['']] # Empty file
+ self._test_dataset(
+ inputs, expected_output=[], record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithEmptyRecord(self):
+ record_defaults = [['']] * 2
+ inputs = [['', '1,2']] # First record is empty
+ self._test_dataset(
+ inputs,
+ expected_err_re='Expect 2 fields but have 1 in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withChainedOps(self):
+ # Testing that one dataset can create multiple iterators fine.
+ # `repeat` creates multiple iterators from the same C++ Dataset.
+ record_defaults = [[0]] * 4
+ inputs = [['1,,3,4', '5,6,,8']]
+ ds_actual, ds_expected = self._make_test_datasets(
+ inputs, record_defaults=record_defaults)
+ self.assertDatasetsEqual(
+ ds_actual.repeat(5).prefetch(1),
+ ds_expected.repeat(5).prefetch(1))
+
+ def testCsvDataset_withTypeDefaults(self):
+ # Testing using dtypes as record_defaults for required fields
+ record_defaults = [dtypes.float32, [0.0]]
+ inputs = [['1.0,2.0', '3.0,4.0']]
+ self._test_dataset(
+ inputs,
+ [[1.0, 2.0], [3.0, 4.0]],
+ record_defaults=record_defaults,
+ )
+
+ def testMakeCsvDataset_fieldOrder(self):
+ data = [[
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19',
+ '1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19'
+ ]]
+ file_path = self._setup_files(data)
+
+ ds = readers.make_csv_dataset(
+ file_path, batch_size=1, shuffle=False, num_epochs=1)
+ nxt = self.getNext(ds)
+
+ result = list(self.evaluate(nxt()).values())
+
+ self.assertEqual(result, sorted(result))
+
+## The following tests exercise parsing logic for quoted fields
+
+ def testCsvDataset_withQuoted(self):
+ record_defaults = [['']] * 4
+ inputs = [['"a","b","c :)","d"', '"e","f","g :(","h"']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+ def testCsvDataset_withOneColAndQuotes(self):
+ record_defaults = [['']]
+ inputs = [['"0"', '"1"', '"2"']]
+ self._test_dataset(
+ inputs, [['0'], ['1'], ['2']], record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLine(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,b,"""c""\n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b', '"c"\n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_withNewLineInUnselectedCol(self):
+ record_defaults = [['']]
+ inputs = [['1,"2\n3",4', '5,6,7']]
+ self._test_dataset(
+ inputs,
+ expected_output=[['1'], ['5']],
+ record_defaults=record_defaults,
+ select_cols=[0])
+
+ def testCsvDataset_withMultipleNewLines(self):
+ # In this case, we expect it to behave differently from
+ # TextLineDataset->map(decode_csv) since that flow has bugs
+ record_defaults = [['']] * 4
+ inputs = [['a,"b\n\nx","""c""\n \n0","d\ne"', 'f,g,h,i']]
+ expected = [['a', 'b\n\nx', '"c"\n \n0', 'd\ne'], ['f', 'g', 'h', 'i']]
+ self._test_dataset(inputs, expected, record_defaults=record_defaults)
+
+ def testCsvDataset_errorWithTerminateMidRecord(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,"a']]
+ self._test_dataset(
+ inputs,
+ expected_err_re=
+ 'Reached end of file without closing quoted field in record',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withEscapedQuotes(self):
+ record_defaults = [['']] * 4
+ inputs = [['1.0,2.1,"she said: ""hello""",4.3', '5.4,6.5,goodbye,8.7']]
+ self._test_by_comparison(inputs, record_defaults=record_defaults)
+
+
+## Testing that parsing works with all buffer sizes, quoted/unquoted fields,
+## and different types of line breaks
+
+ def testCsvDataset_withInvalidBufferSize(self):
+ record_defaults = [['']] * 4
+ inputs = [['a,b,c,d']]
+ self._test_dataset(
+ inputs,
+ expected_err_re='buffer_size should be positive',
+ record_defaults=record_defaults,
+ buffer_size=0)
+
+ def _test_dataset_on_buffer_sizes(self,
+ inputs,
+ expected,
+ linebreak,
+ record_defaults,
+ compression_type=None,
+ num_sizes_to_test=20):
+ # Testing reading with a range of buffer sizes that should all work.
+ for i in list(range(1, 1 + num_sizes_to_test)) + [None]:
+ self._test_dataset(
+ inputs,
+ expected,
+ linebreak=linebreak,
+ compression_type=compression_type,
+ record_defaults=record_defaults,
+ buffer_size=i)
+
+ def testCsvDataset_withLF(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withCR(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRLF(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['abc,def,ghi', '0,1,2', ',,']]
+ expected = [['abc', 'def', 'ghi'], ['0', '1', '2'], ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withBufferSizeAndQuoted(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRAndQuoted(self):
+ # Test that when the line separator is '\r', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r', record_defaults=record_defaults)
+
+ def testCsvDataset_withCRLFAndQuoted(self):
+ # Test that when the line separator is '\r\n', parsing works with all buffer
+ # sizes
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs, expected, linebreak='\r\n', record_defaults=record_defaults)
+
+ def testCsvDataset_withGzipCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='GZIP',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withZlibCompressionType(self):
+ record_defaults = [['NA']] * 3
+ inputs = [['"\n\n\n","\r\r\r","abc"', '"0","1","2"', '"","",""']]
+ expected = [['\n\n\n', '\r\r\r', 'abc'], ['0', '1', '2'],
+ ['NA', 'NA', 'NA']]
+ self._test_dataset_on_buffer_sizes(
+ inputs,
+ expected,
+ linebreak='\r\n',
+ compression_type='ZLIB',
+ record_defaults=record_defaults)
+
+ def testCsvDataset_withScalarDefaults(self):
+ record_defaults = [constant_op.constant(0, dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+ def testCsvDataset_with2DDefaults(self):
+ record_defaults = [constant_op.constant([[0]], dtype=dtypes.int64)] * 4
+ inputs = [[',,,', '1,1,1,', ',2,2,2']]
+
+ if context.executing_eagerly():
+ err_spec = errors.InvalidArgumentError, (
+ 'Each record default should be at '
+ 'most rank 1.')
+ else:
+ err_spec = ValueError, 'Shape must be at most rank 1 but is rank 2'
+
+ with self.assertRaisesWithPredicateMatch(*err_spec):
+ self._test_dataset(
+ inputs, [[0, 0, 0, 0], [1, 1, 1, 0], [0, 2, 2, 2]],
+ record_defaults=record_defaults)
+
+
+class CsvDatasetBenchmark(test.Benchmark):
+ """Benchmarks for the various ways of creating a dataset from CSV files.
+ """
+ FLOAT_VAL = '1.23456E12'
+ STR_VAL = string.ascii_letters * 10
+
+ def _setUp(self, str_val):
+ # Since this isn't test.TestCase, have to manually create a test dir
+ gfile.MakeDirs(googletest.GetTempDir())
+ self._temp_dir = tempfile.mkdtemp(dir=googletest.GetTempDir())
+
+ self._num_cols = [4, 64, 256]
+ self._num_per_iter = 5000
+ self._filenames = []
+ for n in self._num_cols:
+ fn = os.path.join(self._temp_dir, 'file%d.csv' % n)
+ with open(fn, 'wb') as f:
+ # Just write 100 rows and use `repeat`... Assumes the cost
+ # of creating an iterator is not significant
+ row = ','.join([str_val for _ in range(n)])
+ f.write('\n'.join([row for _ in range(100)]))
+ self._filenames.append(fn)
+
+ def _tearDown(self):
+ gfile.DeleteRecursively(self._temp_dir)
+
+ def _runBenchmark(self, dataset, num_cols, prefix):
+ dataset = dataset.skip(self._num_per_iter - 1)
+ deltas = []
+ for _ in range(10):
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with session.Session() as sess:
+ start = time.time()
+ # NOTE: This depends on the underlying implementation of skip, to have
+ # the net effect of calling `GetNext` num_per_iter times on the
+ # input dataset. We do it this way (instead of a python for loop, or
+ # batching N inputs in one iter) so that the overhead from session.run
+ # or batch doesn't dominate. If we eventually optimize skip, this has
+ # to change.
+ sess.run(next_element)
+ end = time.time()
+ deltas.append(end - start)
+ # Median wall time per CSV record read and decoded
+ median_wall_time = np.median(deltas) / self._num_per_iter
+ print('%s num_cols: %d Median wall time: %f' % (prefix, num_cols,
+ median_wall_time))
+ self.report_benchmark(
+ iters=self._num_per_iter,
+ wall_time=median_wall_time,
+ name='%s_with_cols_%d' % (prefix, num_cols))
+
+ def benchmarkMapWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_float_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkMapWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = dataset.map(lambda l: parsing_ops.decode_csv(l, **kwargs)) # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_strings_map_decode_csv')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithFloats(self):
+ self._setUp(self.FLOAT_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [[0.0]] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_float_fused_dataset')
+ self._tearDown()
+
+ def benchmarkCsvDatasetWithStrings(self):
+ self._setUp(self.STR_VAL)
+ for i in range(len(self._filenames)):
+ num_cols = self._num_cols[i]
+ kwargs = {'record_defaults': [['']] * num_cols}
+ dataset = core_readers.TextLineDataset(self._filenames[i]).repeat()
+ dataset = readers.CsvDataset(self._filenames[i], **kwargs).repeat() # pylint: disable=cell-var-from-loop
+ self._runBenchmark(dataset, num_cols, 'csv_strings_fused_dataset')
+ self._tearDown()
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
new file mode 100644
index 0000000000..73be6cbcca
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/dense_to_sparse_batch_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.dense_to_sparse_batch()."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DenseToSparseBatchTest(test_base.DatasetTestBase):
+
+ def testDenseToSparseBatchDataset(self):
+ components = np.random.randint(12, size=(100,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)], results.indices)
+ self.assertAllEqual(
+ [c for c in components[start:start + 4] for _ in range(c)],
+ results.values)
+ self.assertAllEqual([min(4,
+ len(components) - start), 12],
+ results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithUnknownShape(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([x, x], x)).apply(
+ batching.dense_to_sparse_batch(
+ 4, [5, None])).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ for start in range(0, len(components), 4):
+ results = sess.run(get_next)
+ self.assertAllEqual([[i, j, z]
+ for i, c in enumerate(components[start:start + 4])
+ for j in range(c)
+ for z in range(c)], results.indices)
+ self.assertAllEqual([
+ c
+ for c in components[start:start + 4] for _ in range(c)
+ for _ in range(c)
+ ], results.values)
+ self.assertAllEqual([
+ min(4,
+ len(components) - start), 5,
+ np.max(components[start:start + 4])
+ ], results.dense_shape)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testDenseToSparseBatchDatasetWithInvalidShape(self):
+ input_tensor = array_ops.constant([[1]])
+ with self.assertRaisesRegexp(ValueError, "Dimension -2 must be >= 0"):
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [-2])).make_initializable_iterator()
+
+ def testDenseToSparseBatchDatasetShapeErrors(self):
+ input_tensor = array_ops.placeholder(dtypes.int32)
+ iterator = (
+ dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Initialize with an input tensor of incompatible rank.
+ sess.run(init_op, feed_dict={input_tensor: [[1]]})
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "incompatible with the row shape"):
+ sess.run(get_next)
+
+ # Initialize with an input tensor that is larger than `row_shape`.
+ sess.run(init_op, feed_dict={input_tensor: range(13)})
+ with self.assertRaisesRegexp(errors.DataLossError,
+ "larger than the row shape"):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
new file mode 100644
index 0000000000..796a692c56
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import random_seed
+from tensorflow.python.platform import test
+
+
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
+
+ def testBasic(self):
+ selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
+ input_datasets = [
+ dataset_ops.Dataset.from_tensors(i).repeat(100) for i in range(10)
+ ]
+ dataset = interleave_ops._DirectedInterleaveDataset(selector_dataset,
+ input_datasets)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(100):
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def _normalize(self, vec):
+ return vec / vec.sum()
+
+ def _chi2(self, expected, actual):
+ actual = np.asarray(actual)
+ expected = np.asarray(expected)
+ diff = actual - expected
+ chi2 = np.sum(diff * diff / expected, axis=0)
+ return chi2
+
+ def _testSampleFromDatasetsHelper(self, weights, num_datasets, num_samples):
+ # Create a dataset that samples each integer in `[0, num_datasets)`
+ # with probability given by `weights[i]`.
+ dataset = interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(num_datasets)
+ ], weights)
+ dataset = dataset.take(num_samples)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ freqs = np.zeros([num_datasets])
+ for _ in range(num_samples):
+ freqs[sess.run(next_element)] += 1
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ return freqs
+
+ def testSampleFromDatasets(self):
+ random_seed.set_random_seed(1619)
+ num_samples = 5000
+ rand_probs = self._normalize(np.random.random_sample((15,)))
+
+ # Use chi-squared test to assert that the observed distribution matches the
+ # expected distribution. Based on the implementation in
+ # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
+ for probs in [[.85, .05, .1], rand_probs, [1.]]:
+ probs = np.asarray(probs)
+ classes = len(probs)
+ freqs = self._testSampleFromDatasetsHelper(probs, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ # Also check that `weights` as a dataset samples correctly.
+ probs_ds = dataset_ops.Dataset.from_tensors(probs).repeat()
+ freqs = self._testSampleFromDatasetsHelper(probs_ds, classes, num_samples)
+ self.assertLess(self._chi2(probs, freqs / num_samples), 1e-2)
+
+ def testSelectFromDatasets(self):
+ words = [b"foo", b"bar", b"baz"]
+ datasets = [dataset_ops.Dataset.from_tensors(w).repeat() for w in words]
+ choice_array = np.random.randint(3, size=(15,), dtype=np.int64)
+ choice_dataset = dataset_ops.Dataset.from_tensor_slices(choice_array)
+ dataset = interleave_ops.choose_from_datasets(datasets, choice_dataset)
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in choice_array:
+ self.assertEqual(words[i], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testErrors(self):
+ with self.assertRaisesRegexp(ValueError,
+ r"vector of length `len\(datasets\)`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[0.25, 0.25, 0.25, 0.25])
+
+ with self.assertRaisesRegexp(TypeError, "`tf.float32` or `tf.float64`"):
+ interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.range(10),
+ dataset_ops.Dataset.range(20)],
+ weights=[1, 1])
+
+ with self.assertRaisesRegexp(TypeError, "must have the same type"):
+ interleave_ops.sample_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(0.0)
+ ])
+
+ with self.assertRaisesRegexp(TypeError, "tf.int64"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors(1.0))
+
+ with self.assertRaisesRegexp(TypeError, "scalar"):
+ interleave_ops.choose_from_datasets([
+ dataset_ops.Dataset.from_tensors(0),
+ dataset_ops.Dataset.from_tensors(1)
+ ], choice_dataset=dataset_ops.Dataset.from_tensors([1.0]))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
new file mode 100644
index 0000000000..e54235d9f8
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/enumerate_dataset_test.py
@@ -0,0 +1,56 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.enumerate_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.platform import test
+
+
+class EnumerateDatasetTest(test_base.DatasetTestBase):
+
+ def testEnumerateDataset(self):
+ components = (["a", "b"], [1, 2], [37.0, 38])
+ start = constant_op.constant(20, dtype=dtypes.int64)
+
+ iterator = (dataset_ops.Dataset.from_tensor_slices(components).apply(
+ enumerate_ops.enumerate_dataset(start)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual(dtypes.int64, get_next[0].dtype)
+ self.assertEqual((), get_next[0].shape)
+ self.assertEqual([tensor_shape.TensorShape([])] * 3,
+ [t.shape for t in get_next[1]])
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertEqual((20, (b"a", 1, 37.0)), sess.run(get_next))
+ self.assertEqual((21, (b"b", 2, 38.0)), sess.run(get_next))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
new file mode 100644
index 0000000000..c6ee88c676
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
@@ -0,0 +1,76 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmarks FilterDataset input pipeline op."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterBenchmark(test.Benchmark):
+
+ # This benchmark compares the performance of pipeline with multiple chained
+ # filter with and without filter fusion.
+ def benchmarkFilters(self):
+ chain_lengths = [0, 1, 2, 5, 10, 20, 50]
+ for chain_length in chain_lengths:
+ self._benchmarkFilters(chain_length, False)
+ self._benchmarkFilters(chain_length, True)
+
+ def _benchmarkFilters(self, chain_length, optimize_dataset):
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors(5).repeat(None)
+ for _ in range(chain_length):
+ dataset = dataset.filter(lambda x: math_ops.greater_equal(x - 5, 0))
+ if optimize_dataset:
+ dataset = dataset.apply(optimization.optimize(["filter_fusion"]))
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for _ in range(10):
+ sess.run(next_element.op)
+ deltas = []
+ for _ in range(100):
+ start = time.time()
+ for _ in range(100):
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ median_wall_time = np.median(deltas) / 100
+ opt_mark = "opt" if optimize_dataset else "no-opt"
+ print("Filter dataset {} chain length: {} Median wall time: {}".format(
+ opt_mark, chain_length, median_wall_time))
+ self.report_benchmark(
+ iters=1000,
+ wall_time=median_wall_time,
+ name="benchmark_filter_dataset_chain_latency_{}_{}".format(
+ opt_mark, chain_length))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
new file mode 100644
index 0000000000..399fd284f4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/function_buffering_resource_test.py
@@ -0,0 +1,247 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `FunctionBufferingResource` used in prefetching."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import test
+
+
+class FunctionBufferingResourceTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ self._event = threading.Event()
+
+ def _create_ds_and_iterator(self, device0, initializable=False):
+
+ def gen():
+ for i in range(1, 10):
+ yield [float(i)]
+ if i == 6:
+ self._event.set()
+
+ with ops.device(device0):
+ ds = dataset_ops.Dataset.from_generator(gen, (dtypes.float32))
+ if initializable:
+ ds_iterator = ds.make_initializable_iterator()
+ else:
+ ds_iterator = ds.make_one_shot_iterator()
+ return (ds, ds_iterator)
+
+ def _create_ops(self, ds, ds_iterator, buffer_name, device0, device1):
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.float32],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name=buffer_name)
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.float32])
+ reset_op = prefetching_ops.function_buffering_resource_reset(
+ function_buffer_resource=buffer_resource_handle)
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ return (prefetch_op, reset_op, destroy_op)
+
+ def _prefetch_fn_helper_one_shot(self, buffer_name, device0, device1):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=False)
+ prefetch_op, _, destroy_op = self._create_ops(ds, ds_iterator, buffer_name,
+ device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testSameDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("same_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:0")
+
+ def testDifferentDeviceCPU(self):
+ self._prefetch_fn_helper_one_shot("diff_device_cpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/cpu:1")
+
+ def testDifferentDeviceCPUGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ self._prefetch_fn_helper_one_shot("cpu_gpu",
+ "/job:localhost/replica:0/task:0/cpu:0",
+ "/job:localhost/replica:0/task:0/gpu:0")
+
+ def testReinitialization(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ # Lets reset the function buffering resource and reinitialize the
+ # iterator. Should be able to go through this again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [1.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [2.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [3.0])
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [4.0])
+ self._event.wait()
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [5.0])
+ sess.run(destroy_op)
+
+ def testReinitializationOutOfRange(self):
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/cpu:1"
+ ds, ds_iterator = self._create_ds_and_iterator(device0, initializable=True)
+ prefetch_op, reset_op, destroy_op = self._create_ops(
+ ds, ds_iterator, "reinit", device0, device1)
+
+ with self.test_session(config=worker_config) as sess:
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+ # Try fetching after its over twice to test out end of sequence.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ # Now reset everything and try it out again.
+ self._event.clear()
+ sess.run(reset_op)
+ sess.run(ds_iterator.initializer)
+ for i in range(1, 10):
+ elem = sess.run(prefetch_op)
+ self.assertEqual(elem, [float(i)])
+ # Try fetching after its over twice to test out end of sequence.
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+ def testStringsGPU(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ device0 = "/job:localhost/replica:0/task:0/cpu:0"
+ device1 = "/job:localhost/replica:0/task:0/gpu:0"
+
+ ds = dataset_ops.Dataset.from_tensor_slices(["a", "b", "c"])
+ ds_iterator = ds.make_one_shot_iterator()
+ ds_iterator_handle = ds_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _remote_fn(h):
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ h, ds.output_types, ds.output_shapes)
+ return remote_iterator.get_next()
+
+ target = constant_op.constant(device0)
+ with ops.device(device1):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_remote_fn,
+ output_types=[dtypes.string],
+ target_device=target,
+ string_arg=ds_iterator_handle,
+ buffer_size=3,
+ shared_name="strings")
+
+ with ops.device(device1):
+ prefetch_op = prefetching_ops.function_buffering_resource_get_next(
+ function_buffer_resource=buffer_resource_handle,
+ output_types=[dtypes.string])
+ destroy_op = resource_variable_ops.destroy_resource_op(
+ buffer_resource_handle, ignore_lookup_error=True)
+
+ with self.cached_session() as sess:
+ self.assertEqual([b"a"], sess.run(prefetch_op))
+ self.assertEqual([b"b"], sess.run(prefetch_op))
+ self.assertEqual([b"c"], sess.run(prefetch_op))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(prefetch_op)
+
+ sess.run(destroy_op)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
new file mode 100644
index 0000000000..8c07afbac5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("Zero", 0, 1),
+ ("Five", 5, 1),
+ ("Ten", 10, 1),
+ ("Empty", 100, 1, errors.InvalidArgumentError, "Dataset was empty."),
+ ("MoreThanOne", 0, 2, errors.InvalidArgumentError,
+ "Dataset had more than one element."),
+ )
+ def testGetSingleElement(self, skip, take, error=None, error_msg=None):
+ skip_t = array_ops.placeholder(dtypes.int64, shape=[])
+ take_t = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def make_sparse(x):
+ x_1d = array_ops.reshape(x, [1])
+ x_2d = array_ops.reshape(x, [1, 1])
+ return sparse_tensor.SparseTensor(x_2d, x_1d, x_1d)
+
+ dataset = dataset_ops.Dataset.range(100).skip(skip_t).map(
+ lambda x: (x * x, make_sparse(x))).take(take_t)
+ element = get_single_element.get_single_element(dataset)
+
+ with self.cached_session() as sess:
+ if error is None:
+ dense_val, sparse_val = sess.run(
+ element, feed_dict={
+ skip_t: skip,
+ take_t: take
+ })
+ self.assertEqual(skip * skip, dense_val)
+ self.assertAllEqual([[skip]], sparse_val.indices)
+ self.assertAllEqual([skip], sparse_val.values)
+ self.assertAllEqual([skip], sparse_val.dense_shape)
+ else:
+ with self.assertRaisesRegexp(error, error_msg):
+ sess.run(element, feed_dict={skip_t: skip, take_t: take})
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
new file mode 100644
index 0000000000..9030328593
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_reducer_test.py
@@ -0,0 +1,199 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_reducer()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerTest(test_base.DatasetTestBase):
+
+ def checkResults(self, dataset, shapes, values):
+ self.assertEqual(shapes, dataset.output_shapes)
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ for expected in values:
+ got = sess.run(get_next)
+ self.assertEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testSum(self):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(lambda x: x % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testAverage(self):
+
+ def reduce_fn(x, y):
+ return (x[0] * x[1] + math_ops.cast(y, dtypes.float32)) / (
+ x[1] + 1), x[1] + 1
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: (0.0, 0.0),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, _: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).apply(
+ grouping.group_by_reducer(
+ lambda x: math_ops.cast(x, dtypes.int64) % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[i - 1, i])
+
+ def testConcat(self):
+ components = np.array(list("abcdefghijklmnopqrst")).view(np.chararray)
+ reducer = grouping.Reducer(
+ init_func=lambda x: "",
+ reduce_func=lambda x, y: x + y[0],
+ finalize_func=lambda x: x)
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensor_slices(components),
+ dataset_ops.Dataset.range(2 * i))).apply(
+ grouping.group_by_reducer(lambda x, y: y % 2, reducer))
+ self.checkResults(
+ dataset,
+ shapes=tensor_shape.scalar(),
+ values=[b"acegikmoqs" [:i], b"bdfhjlnprt" [:i]])
+
+ def testSparseSum(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1], dtype=np.int64)),
+ dense_shape=np.array([1, 1]))
+
+ reducer = grouping.Reducer(
+ init_func=lambda _: _sparse(np.int64(0)),
+ reduce_func=lambda x, y: _sparse(x.values[0] + y.values[0]),
+ finalize_func=lambda x: x.values[0])
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.range(2 * i).map(_sparse).apply(
+ grouping.group_by_reducer(lambda x: x.values[0] % 2, reducer))
+ self.checkResults(
+ dataset, shapes=tensor_shape.scalar(), values=[(i - 1) * i, i * i])
+
+ def testChangingStateShape(self):
+
+ def reduce_fn(x, _):
+ # Statically known rank, but dynamic length.
+ larger_dim = array_ops.concat([x[0], x[0]], 0)
+ # Statically unknown rank.
+ larger_rank = array_ops.expand_dims(x[1], 0)
+ return larger_dim, larger_rank
+
+ reducer = grouping.Reducer(
+ init_func=lambda x: ([0], 1),
+ reduce_func=reduce_fn,
+ finalize_func=lambda x, y: (x, y))
+
+ for i in range(1, 11):
+ dataset = dataset_ops.Dataset.from_tensors(np.int64(0)).repeat(i).apply(
+ grouping.group_by_reducer(lambda x: x, reducer))
+ self.assertEqual([None], dataset.output_shapes[0].as_list())
+ self.assertIs(None, dataset.output_shapes[1].ndims)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual([0] * (2**i), x)
+ self.assertAllEqual(np.array(1, ndmin=i), y)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testTypeMismatch(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: constant_op.constant(1, dtype=dtypes.int32),
+ reduce_func=lambda x, y: constant_op.constant(1, dtype=dtypes.int64),
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64(0), reducer))
+
+ # TODO(b/78665031): Remove once non-scalar keys are supported.
+ def testInvalidKeyShape(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: np.int64((0, 0)), reducer))
+
+ # TODO(b/78665031): Remove once non-int64 keys are supported.
+ def testInvalidKeyType(self):
+ reducer = grouping.Reducer(
+ init_func=lambda x: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ ValueError, "`key_func` must return a single tf.int64 tensor."):
+ dataset.apply(
+ grouping.group_by_reducer(lambda _: "wrong", reducer))
+
+ def testTuple(self):
+ def init_fn(_):
+ return np.array([], dtype=np.int64), np.int64(0)
+
+ def reduce_fn(state, value):
+ s1, s2 = state
+ v1, v2 = value
+ return array_ops.concat([s1, [v1]], 0), s2 + v2
+
+ def finalize_fn(s1, s2):
+ return s1, s2
+
+ reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+ dataset = dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.range(10), dataset_ops.Dataset.range(10))).apply(
+ grouping.group_by_reducer(lambda x, y: np.int64(0), reducer))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ with self.cached_session() as sess:
+ x, y = sess.run(get_next)
+ self.assertAllEqual(x, np.asarray([x for x in range(10)]))
+ self.assertEqual(y, 45)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
new file mode 100644
index 0000000000..557d56e8b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/group_by_window_test.py
@@ -0,0 +1,367 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.group_by_window()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
+# Currently, they use a constant batch size, though should be made to use a
+# different batch size per key.
+class GroupByWindowTest(test_base.DatasetTestBase):
+
+ def _dynamicPad(self, bucket, window, window_size):
+ # TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
+ # generic form of padded_batch that pads every component
+ # dynamically and does not rely on static shape information about
+ # the arguments.
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, (tensor_shape.TensorShape([]), tensor_shape.TensorShape(
+ [None]), tensor_shape.TensorShape([3])))))
+
+ def testSingleBucket(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(32)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: 0,
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ which_bucket, bucketed_values = sess.run(get_next)
+
+ self.assertEqual(0, which_bucket)
+
+ expected_scalar_int = np.arange(32, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31)).astype(np.int64)
+ for i in range(32):
+ expected_unk_int64[i, :i] = i
+ expected_vec3_str = np.vstack(3 * [np.arange(32).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values[2])
+
+ def testEvenOddBuckets(self):
+
+ def _map_fn(v):
+ return (v, array_ops.fill([v], v),
+ array_ops.fill([3], string_ops.as_string(v)))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(64)).map(_map_fn))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda x, y, z: math_ops.cast(x % 2, dtypes.int64),
+ lambda k, bucket: self._dynamicPad(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches (one containing even values, one containing odds)
+ which_bucket_even, bucketed_values_even = sess.run(get_next)
+ which_bucket_odd, bucketed_values_odd = sess.run(get_next)
+
+ # Count number of bucket_tensors.
+ self.assertEqual(3, len(bucketed_values_even))
+ self.assertEqual(3, len(bucketed_values_odd))
+
+ # Ensure bucket 0 was used for all minibatch entries.
+ self.assertAllEqual(0, which_bucket_even)
+ self.assertAllEqual(1, which_bucket_odd)
+
+ # Test the first bucket outputted, the events starting at 0
+ expected_scalar_int = np.arange(0, 32 * 2, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i] = 2 * i
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(0, 32 * 2, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_even[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_even[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_even[2])
+
+ # Test the second bucket outputted, the odds starting at 1
+ expected_scalar_int = np.arange(1, 32 * 2 + 1, 2, dtype=np.int64)
+ expected_unk_int64 = np.zeros((32, 31 * 2 + 1)).astype(np.int64)
+ for i in range(0, 32):
+ expected_unk_int64[i, :2 * i + 1] = 2 * i + 1
+ expected_vec3_str = np.vstack(
+ 3 * [np.arange(1, 32 * 2 + 1, 2).astype(bytes)]).T
+
+ self.assertAllEqual(expected_scalar_int, bucketed_values_odd[0])
+ self.assertAllEqual(expected_unk_int64, bucketed_values_odd[1])
+ self.assertAllEqual(expected_vec3_str, bucketed_values_odd[2])
+
+ def testEvenOddBucketsFilterOutAllOdd(self):
+
+ def _map_fn(v):
+ return {
+ "x": v,
+ "y": array_ops.fill([v], v),
+ "z": array_ops.fill([3], string_ops.as_string(v))
+ }
+
+ def _dynamic_pad_fn(bucket, window, _):
+ return dataset_ops.Dataset.zip(
+ (dataset_ops.Dataset.from_tensors(bucket),
+ window.padded_batch(
+ 32, {
+ "x": tensor_shape.TensorShape([]),
+ "y": tensor_shape.TensorShape([None]),
+ "z": tensor_shape.TensorShape([3])
+ })))
+
+ input_dataset = (
+ dataset_ops.Dataset.from_tensor_slices(math_ops.range(128)).map(_map_fn)
+ .filter(lambda d: math_ops.equal(d["x"] % 2, 0)))
+
+ bucketed_dataset = input_dataset.apply(
+ grouping.group_by_window(
+ lambda d: math_ops.cast(d["x"] % 2, dtypes.int64),
+ lambda k, bucket: _dynamic_pad_fn(k, bucket, 32), 32))
+
+ iterator = bucketed_dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+
+ # Get two minibatches ([0, 2, ...] and [64, 66, ...])
+ which_bucket0, bucketed_values_even0 = sess.run(get_next)
+ which_bucket1, bucketed_values_even1 = sess.run(get_next)
+
+ # Ensure that bucket 1 was completely filtered out
+ self.assertAllEqual(0, which_bucket0)
+ self.assertAllEqual(0, which_bucket1)
+ self.assertAllEqual(
+ np.arange(0, 64, 2, dtype=np.int64), bucketed_values_even0["x"])
+ self.assertAllEqual(
+ np.arange(64, 128, 2, dtype=np.int64), bucketed_values_even1["x"])
+
+ def testDynamicWindowSize(self):
+ components = np.arange(100).astype(np.int64)
+
+ # Key fn: even/odd
+ # Reduce fn: batches of 5
+ # Window size fn: even=5, odd=10
+
+ def window_size_func(key):
+ window_sizes = constant_op.constant([5, 10], dtype=dtypes.int64)
+ return window_sizes[key]
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(20),
+ None, window_size_func))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.OutOfRangeError):
+ batches = 0
+ while True:
+ result = sess.run(get_next)
+ is_even = all(x % 2 == 0 for x in result)
+ is_odd = all(x % 2 == 1 for x in result)
+ self.assertTrue(is_even or is_odd)
+ expected_batch_size = 5 if is_even else 10
+ self.assertEqual(expected_batch_size, result.shape[0])
+ batches += 1
+
+ self.assertEqual(batches, 15)
+
+ def testSimple(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(lambda x: x * x)
+ .apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ result = sess.run(get_next)
+ self.assertTrue(
+ all(x % 2 == 0
+ for x in result) or all(x % 2 == 1)
+ for x in result)
+ counts.append(result.shape[0])
+
+ self.assertEqual(len(components), sum(counts))
+ num_full_batches = len([c for c in counts if c == 4])
+ self.assertGreaterEqual(num_full_batches, 24)
+ self.assertTrue(all(c == 4 for c in counts[:num_full_batches]))
+
+ def testImmediateOutput(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ # The input is infinite, so this test demonstrates that:
+ # 1. We produce output without having to consume the entire input,
+ # 2. Different buckets can produce output at different rates, and
+ # 3. For deterministic input, the output is deterministic.
+ for _ in range(3):
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ self.assertAllEqual([2, 2, 2, 2], sess.run(get_next))
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+
+ def testSmallGroups(self):
+ components = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0], dtype=np.int64)
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_window(lambda x: x % 2, lambda _, xs: xs.batch(4),
+ 4)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ self.assertAllEqual([0, 0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1, 1, 1, 1], sess.run(get_next))
+ # The small outputs at the end are deterministically produced in key
+ # order.
+ self.assertAllEqual([0, 0, 0], sess.run(get_next))
+ self.assertAllEqual([1], sess.run(get_next))
+
+ def testEmpty(self):
+ iterator = (
+ dataset_ops.Dataset.range(4).apply(
+ grouping.group_by_window(lambda _: 0, lambda _, xs: xs, 0))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Window size must be greater than zero, but got 0."):
+ print(sess.run(get_next))
+
+ def testReduceFuncError(self):
+ components = np.random.randint(100, size=(200,)).astype(np.int64)
+
+ def reduce_func(_, xs):
+ # Introduce an incorrect padded shape that cannot (currently) be
+ # detected at graph construction time.
+ return xs.padded_batch(
+ 4,
+ padded_shapes=(tensor_shape.TensorShape([]),
+ constant_op.constant([5], dtype=dtypes.int64) * -1))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: (x, ops.convert_to_tensor([x * x]))).apply(
+ grouping.group_by_window(lambda x, _: x % 2, reduce_func,
+ 32)).make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def testConsumeWindowDatasetMoreThanOnce(self):
+ components = np.random.randint(50, size=(200,)).astype(np.int64)
+
+ def reduce_func(key, window):
+ # Apply two different kinds of padding to the input: tight
+ # padding, and quantized (to a multiple of 10) padding.
+ return dataset_ops.Dataset.zip((
+ window.padded_batch(
+ 4, padded_shapes=tensor_shape.TensorShape([None])),
+ window.padded_batch(
+ 4, padded_shapes=ops.convert_to_tensor([(key + 1) * 10])),
+ ))
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.fill([math_ops.cast(x, dtypes.int32)], x))
+ .apply(grouping.group_by_window(
+ lambda x: math_ops.cast(array_ops.shape(x)[0] // 10, dtypes.int64),
+ reduce_func, 4))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ counts = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ tight_result, multiple_of_10_result = sess.run(get_next)
+ self.assertEqual(0, multiple_of_10_result.shape[1] % 10)
+ self.assertAllEqual(tight_result,
+ multiple_of_10_result[:, :tight_result.shape[1]])
+ counts.append(tight_result.shape[0])
+ self.assertEqual(len(components), sum(counts))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
new file mode 100644
index 0000000000..c0ec1486ab
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/ignore_errors_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.ignore_errors()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+_NUMPY_RANDOM_SEED = 42
+
+
+class IgnoreErrorsTest(test_base.DatasetTestBase):
+
+ def testMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components)
+ .map(lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testParallelMapIgnoreError(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message"),
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for x in [1., 2., 3., 5.]:
+ self.assertEqual(x, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testReadFileIgnoreError(self):
+
+ def write_string_to_file(value, filename):
+ with open(filename, "w") as f:
+ f.write(value)
+
+ filenames = [
+ os.path.join(self.get_temp_dir(), "file_%d.txt" % i) for i in range(5)
+ ]
+ for filename in filenames:
+ write_string_to_file(filename, filename)
+
+ dataset = (
+ dataset_ops.Dataset.from_tensor_slices(filenames).map(
+ io_ops.read_file,
+ num_parallel_calls=2).prefetch(2).apply(error_ops.ignore_errors()))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # All of the files are present.
+ sess.run(init_op)
+ for filename in filenames:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Delete one of the files.
+ os.remove(filenames[0])
+
+ # Attempting to read filenames[0] will fail, but ignore_errors()
+ # will catch the error.
+ sess.run(init_op)
+ for filename in filenames[1:]:
+ self.assertEqual(compat.as_bytes(filename), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
new file mode 100644
index 0000000000..c93a8353ce
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental indexed dataset ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+
+from tensorflow.python.data.experimental.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.platform import test
+
+
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
+
+ def testLowLevelIndexedDatasetOps(self):
+ identity = ged_ops.experimental_identity_indexed_dataset(
+ ops.convert_to_tensor(16, dtype=dtypes.uint64))
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
+ container="",
+ shared_name="",
+ output_types=[dtypes.uint64],
+ output_shapes=[[]])
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
+ index = array_ops.placeholder(dtypes.uint64)
+ get_op = ged_ops.experimental_indexed_dataset_get(
+ handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
+
+ with self.cached_session() as sess:
+ sess.run(materialize)
+ self.assertEqual([3], sess.run(get_op, feed_dict={index: 3}))
+
+ def testIdentityIndexedDataset(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ materialized = ds.materialize()
+ with self.cached_session() as sess:
+ sess.run(materialized.initializer)
+ placeholder = array_ops.placeholder(dtypes.uint64, shape=[])
+ for i in range(16):
+ output = sess.run(
+ materialized.get(placeholder), feed_dict={placeholder: i})
+ self.assertEqual([i], output)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(materialized.get(placeholder), feed_dict={placeholder: 16})
+
+ @unittest.skip("Requisite functionality currently unimplemented.")
+ def testIdentityIndexedDatasetIterator(self):
+ ds = indexed_dataset_ops.IdentityIndexedDataset(16)
+ itr = ds.make_initializable_iterator()
+ n = itr.get_next()
+ with self.cached_session() as sess:
+ sess.run(itr.initializer)
+ for i in range(16):
+ output = sess.run(n)
+ self.assertEqual(i, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(n)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
new file mode 100644
index 0000000000..5ee94e14dc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_batched_features_dataset_test.py
@@ -0,0 +1,239 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_batched_features_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+
+
+class MakeBatchedFeaturesDatasetTest(
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 0.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 0,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from file 1.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames[1],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ 1,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Basic test: read from both files.
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size).make_one_shot_iterator().get_next()
+ self.verify_records(sess, batch_size, num_epochs=num_epochs)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testReadWithEquivalentDataset(self):
+ features = {
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ }
+ dataset = (
+ core_readers.TFRecordDataset(self.test_filenames)
+ .map(lambda x: parsing_ops.parse_single_example(x, features))
+ .repeat(10).batch(2))
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for file_batch, _, _, _, record_batch, _ in self._next_expected_batch(
+ range(self._num_files), 2, 10):
+ actual_batch = sess.run(next_element)
+ self.assertAllEqual(file_batch, actual_batch["file"])
+ self.assertAllEqual(record_batch, actual_batch["record"])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testReadWithFusedShuffleRepeatDataset(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ for batch_size in [1, 2]:
+ # Test that shuffling with same seed produces the same result.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ # Test that shuffling with different seeds produces a different order.
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs1 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5).make_one_shot_iterator().get_next()
+ outputs2 = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=15).make_one_shot_iterator().get_next()
+ all_equal = True
+ for _ in range(total_records // batch_size):
+ batch1 = self._run_actual_batch(outputs1, sess)
+ batch2 = self._run_actual_batch(outputs2, sess)
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testParallelReadersAndParsers(self):
+ num_epochs = 5
+ for batch_size in [1, 2]:
+ for reader_num_threads in [2, 4]:
+ for parser_num_threads in [2, 4]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ label_key_provided=True,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess, label_key_provided=True)
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ self.outputs = self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads).make_one_shot_iterator(
+ ).get_next()
+ self.verify_records(
+ sess,
+ batch_size,
+ num_epochs=num_epochs,
+ interleave_cycle_length=reader_num_threads)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._next_actual_batch(sess)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 10]:
+ with ops.Graph().as_default():
+ # Basic test: read from file 0.
+ outputs = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ drop_final_batch=True).make_one_shot_iterator().get_next()
+ for tensor in nest.flatten(outputs):
+ if isinstance(tensor, ops.Tensor): # Guard against SparseTensor.
+ self.assertEqual(tensor.shape[0], batch_size)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ label_key="label",
+ num_epochs=None,
+ batch_size=32)
+ for shape, clazz in zip(nest.flatten(dataset.output_shapes),
+ nest.flatten(dataset.output_classes)):
+ if issubclass(clazz, ops.Tensor):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
new file mode 100644
index 0000000000..e4bf089184
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_csv_dataset_test.py
@@ -0,0 +1,660 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_csv_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
+
+ def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
+ return readers.make_csv_dataset(
+ filenames, batch_size=batch_size, num_epochs=num_epochs, **kwargs)
+
+ def _setup_files(self, inputs, linebreak="\n", compression_type=None):
+ filenames = []
+ for i, ip in enumerate(inputs):
+ fn = os.path.join(self.get_temp_dir(), "temp_%d.csv" % i)
+ contents = linebreak.join(ip).encode("utf-8")
+ if compression_type is None:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+ filenames.append(fn)
+ return filenames
+
+ def _next_expected_batch(self, expected_output, expected_keys, batch_size,
+ num_epochs):
+ features = {k: [] for k in expected_keys}
+ for _ in range(num_epochs):
+ for values in expected_output:
+ for n, key in enumerate(expected_keys):
+ features[key].append(values[n])
+ if len(features[expected_keys[0]]) == batch_size:
+ yield features
+ features = {k: [] for k in expected_keys}
+ if features[expected_keys[0]]: # Leftover from the last batch
+ yield features
+
+ def _verify_output(
+ self,
+ sess,
+ dataset,
+ batch_size,
+ num_epochs,
+ label_name,
+ expected_output,
+ expected_keys,
+ ):
+ nxt = dataset.make_one_shot_iterator().get_next()
+
+ for expected_features in self._next_expected_batch(
+ expected_output,
+ expected_keys,
+ batch_size,
+ num_epochs,
+ ):
+ actual_features = sess.run(nxt)
+
+ if label_name is not None:
+ expected_labels = expected_features.pop(label_name)
+ self.assertAllEqual(expected_labels, actual_features[1])
+ actual_features = actual_features[0]
+
+ for k in expected_features.keys():
+ # Compare features
+ self.assertAllEqual(expected_features[k], actual_features[k])
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(nxt)
+
+ def _test_dataset(self,
+ inputs,
+ expected_output,
+ expected_keys,
+ batch_size=1,
+ num_epochs=1,
+ label_name=None,
+ **kwargs):
+ """Checks that elements produced by CsvDataset match expected output."""
+ # Convert str type because py3 tf strings are bytestrings
+ filenames = self._setup_files(
+ inputs, compression_type=kwargs.get("compression_type", None))
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = self._make_csv_dataset(
+ filenames,
+ batch_size=batch_size,
+ num_epochs=num_epochs,
+ label_name=label_name,
+ **kwargs)
+ self._verify_output(sess, dataset, batch_size, num_epochs, label_name,
+ expected_output, expected_keys)
+
+ def testMakeCSVDataset(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withBatchSizeAndEpochs(self):
+ """Tests making a CSV dataset with keys and defaults provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=3,
+ num_epochs=10,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withCompressionType(self):
+ """Tests `compression_type` argument."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ for compression_type in ("GZIP", "ZLIB"):
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ compression_type=compression_type,
+ )
+
+ def testMakeCSVDataset_withBadInputs(self):
+ """Tests that exception is raised when input is malformed.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+
+ # Duplicate column names
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="col0",
+ column_names=column_names * 2)
+
+ # Label key not one of column names
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ label_name="not_a_real_label",
+ column_names=column_names)
+
+ def testMakeCSVDataset_withNoLabel(self):
+ """Tests making a CSV dataset with no label provided."""
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withNoHeader(self):
+ """Tests that datasets can be created from CSV files with no header line.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [["0,1,2,3,4", "5,6,7,8,9"], ["10,11,12,13,14", "15,16,17,18,19"]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=False,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withTypes(self):
+ """Tests that defaults can be a dtype instead of a Tensor for required vals.
+ """
+ record_defaults = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64,
+ dtypes.string
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x[0] for x in column_names), "0,1,2,3,4", "5,6,7,8,9"],
+ [
+ ",".join(x[0] for x in column_names), "10,11,12,13,14",
+ "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withNoColNames(self):
+ """Tests that datasets can be created when column names are not specified.
+
+ In that case, we should infer the column names from the header lines.
+ """
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ expected_output = [[0, 1, 2, 3, b"4"], [5, 6, 7, 8, b"9"],
+ [10, 11, 12, 13, b"14"], [15, 16, 17, 18, b"19"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ column_defaults=record_defaults,
+ )
+
+ def testMakeCSVDataset_withTypeInferenceMismatch(self):
+ # Test that error is thrown when num fields doesn't match columns
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+ with self.assertRaises(ValueError):
+ self._make_csv_dataset(
+ filenames,
+ column_names=column_names + ["extra_name"],
+ column_defaults=None,
+ batch_size=2,
+ num_epochs=10)
+
+ def testMakeCSVDataset_withTypeInference(self):
+ """Tests that datasets can be created when no defaults are specified.
+
+ In that case, we should infer the types from the first N records.
+ """
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
+ def testMakeCSVDataset_withTypeInferenceFallthrough(self):
+ """Tests that datasets can be created when no defaults are specified.
+
+ Tests on a deliberately tricky file.
+ """
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ ",,,,",
+ "0,0,0.0,0.0,0.0",
+ "0,%s,2.0,3e50,rabbit" % str_int32_max,
+ ",,,,",
+ ]]
+ expected_output = [[0, 0, 0, 0, b""], [0, 0, 0, 0, b"0.0"],
+ [0, 2**33, 2.0, 3e50, b"rabbit"], [0, 0, 0, 0, b""]]
+ label = "col0"
+
+ self._test_dataset(
+ inputs,
+ expected_output=expected_output,
+ expected_keys=column_names,
+ column_names=column_names,
+ label_name=label,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ )
+
+ def testMakeCSVDataset_withSelectCols(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+ expected_output = [[0, 2**33, 2.0, 3e50, b"rabbit"]]
+
+ select_cols = [1, 3, 4]
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ column_defaults=[record_defaults[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do inference without provided defaults
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can still do column name inference
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=select_cols,
+ )
+
+ # Can specify column names instead of indices
+ self._test_dataset(
+ inputs,
+ expected_output=[[x[i] for i in select_cols] for x in expected_output],
+ expected_keys=[column_names[i] for i in select_cols],
+ column_names=column_names,
+ batch_size=1,
+ num_epochs=1,
+ shuffle=False,
+ header=True,
+ select_columns=[column_names[i] for i in select_cols],
+ )
+
+ def testMakeCSVDataset_withSelectColsError(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+ column_names = ["col%d" % i for i in range(5)]
+ str_int32_max = str(2**33)
+ inputs = [[
+ ",".join(x for x in column_names),
+ "0,%s,2.0,3e50,rabbit" % str_int32_max
+ ]]
+
+ select_cols = [1, 3, 4]
+ filenames = self._setup_files(inputs)
+
+ with self.assertRaises(ValueError):
+ # Mismatch in number of defaults and number of columns selected,
+ # should raise an error
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ select_columns=select_cols)
+
+ with self.assertRaises(ValueError):
+ # Invalid column name should raise an error
+ self._make_csv_dataset(
+ filenames,
+ batch_size=1,
+ column_defaults=[[0]],
+ column_names=column_names,
+ label_name=None,
+ select_columns=["invalid_col_name"])
+
+ def testMakeCSVDataset_withShuffle(self):
+ record_defaults = [
+ constant_op.constant([], dtypes.int32),
+ constant_op.constant([], dtypes.int64),
+ constant_op.constant([], dtypes.float32),
+ constant_op.constant([], dtypes.float64),
+ constant_op.constant([], dtypes.string)
+ ]
+
+ def str_series(st):
+ return ",".join(str(i) for i in range(st, st + 5))
+
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [
+ [",".join(x for x in column_names)
+ ] + [str_series(5 * i) for i in range(15)],
+ [",".join(x for x in column_names)] +
+ [str_series(5 * i) for i in range(15, 20)],
+ ]
+
+ filenames = self._setup_files(inputs)
+
+ total_records = 20
+ for batch_size in [1, 2]:
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Test that shuffling with the same seed produces the same result
+ dataset1 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ dataset2 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ outputs1 = dataset1.make_one_shot_iterator().get_next()
+ outputs2 = dataset2.make_one_shot_iterator().get_next()
+ for _ in range(total_records // batch_size):
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
+ for i in range(len(batch1)):
+ self.assertAllEqual(batch1[i], batch2[i])
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ # Test that shuffling with a different seed produces different results
+ dataset1 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=5,
+ num_epochs=2,
+ )
+ dataset2 = self._make_csv_dataset(
+ filenames,
+ column_defaults=record_defaults,
+ column_names=column_names,
+ batch_size=batch_size,
+ header=True,
+ shuffle=True,
+ shuffle_seed=6,
+ num_epochs=2,
+ )
+ outputs1 = dataset1.make_one_shot_iterator().get_next()
+ outputs2 = dataset2.make_one_shot_iterator().get_next()
+ all_equal = False
+ for _ in range(total_records // batch_size):
+ batch1 = nest.flatten(sess.run(outputs1))
+ batch2 = nest.flatten(sess.run(outputs2))
+ for i in range(len(batch1)):
+ all_equal = all_equal and np.array_equal(batch1[i], batch2[i])
+ self.assertFalse(all_equal)
+
+ def testIndefiniteRepeatShapeInference(self):
+ column_names = ["col%d" % i for i in range(5)]
+ inputs = [[",".join(x for x in column_names), "0,1,2,3,4", "5,6,7,8,9"], [
+ ",".join(x for x in column_names), "10,11,12,13,14", "15,16,17,18,19"
+ ]]
+ filenames = self._setup_files(inputs)
+ dataset = self._make_csv_dataset(filenames, batch_size=32, num_epochs=None)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
new file mode 100644
index 0000000000..657cf3c00e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/make_tf_record_dataset_test.py
@@ -0,0 +1,243 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.make_tf_record_dataset()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class MakeTFRecordDatasetTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase):
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length,
+ drop_final_batch,
+ use_parser_fn):
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ record_batch = []
+ batch_index = 0
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for f, r in next_records:
+ record = self._record(f, r)
+ if use_parser_fn:
+ record = record[1:]
+ record_batch.append(record)
+ batch_index += 1
+ if len(record_batch) == batch_size:
+ yield record_batch
+ record_batch = []
+ batch_index = 0
+ if record_batch and not drop_final_batch:
+ yield record_batch
+
+ def _verify_records(self,
+ sess,
+ outputs,
+ batch_size,
+ file_index,
+ num_epochs,
+ interleave_cycle_length,
+ drop_final_batch,
+ use_parser_fn):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices, batch_size, num_epochs, interleave_cycle_length,
+ drop_final_batch, use_parser_fn):
+ actual_batch = sess.run(outputs)
+ self.assertAllEqual(expected_batch, actual_batch)
+
+ def _read_test(self, batch_size, num_epochs, file_index=None,
+ num_parallel_reads=1, drop_final_batch=False, parser_fn=False):
+ if file_index is None:
+ file_pattern = self.test_filenames
+ else:
+ file_pattern = self.test_filenames[file_index]
+
+ if parser_fn:
+ fn = lambda x: string_ops.substr(x, 1, 999)
+ else:
+ fn = None
+
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ outputs = readers.make_tf_record_dataset(
+ file_pattern=file_pattern,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ parser_fn=fn,
+ num_parallel_reads=num_parallel_reads,
+ drop_final_batch=drop_final_batch,
+ shuffle=False).make_one_shot_iterator().get_next()
+ self._verify_records(
+ sess, outputs, batch_size, file_index, num_epochs=num_epochs,
+ interleave_cycle_length=num_parallel_reads,
+ drop_final_batch=drop_final_batch, use_parser_fn=parser_fn)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(outputs)
+
+ def testRead(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ # Basic test: read from file 0.
+ self._read_test(batch_size, num_epochs, 0)
+
+ # Basic test: read from file 1.
+ self._read_test(batch_size, num_epochs, 1)
+
+ # Basic test: read from both files.
+ self._read_test(batch_size, num_epochs)
+
+ # Basic test: read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8)
+
+ def testDropFinalBatch(self):
+ for batch_size in [1, 2, 10]:
+ for num_epochs in [1, 3]:
+ # Read from file 0.
+ self._read_test(batch_size, num_epochs, 0, drop_final_batch=True)
+
+ # Read from both files.
+ self._read_test(batch_size, num_epochs, drop_final_batch=True)
+
+ # Read from both files, with parallel reads.
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ drop_final_batch=True)
+
+ def testParserFn(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for drop_final_batch in [False, True]:
+ self._read_test(batch_size, num_epochs, parser_fn=True,
+ drop_final_batch=drop_final_batch)
+ self._read_test(batch_size, num_epochs, num_parallel_reads=8,
+ parser_fn=True, drop_final_batch=drop_final_batch)
+
+ def _shuffle_test(self, batch_size, num_epochs, num_parallel_reads=1,
+ seed=None):
+ with ops.Graph().as_default() as g:
+ with self.session(graph=g) as sess:
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames,
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ num_parallel_reads=num_parallel_reads,
+ shuffle=True,
+ shuffle_seed=seed)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ sess.run(iterator.initializer)
+ first_batches = []
+ try:
+ while True:
+ first_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ sess.run(iterator.initializer)
+ second_batches = []
+ try:
+ while True:
+ second_batches.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+
+ self.assertEqual(len(first_batches), len(second_batches))
+ if seed is not None:
+ # if you set a seed, should get the same results
+ for i in range(len(first_batches)):
+ self.assertAllEqual(first_batches[i], second_batches[i])
+
+ expected = []
+ for f in range(self._num_files):
+ for r in range(self._num_records):
+ expected.extend([self._record(f, r)] * num_epochs)
+
+ for batches in (first_batches, second_batches):
+ actual = []
+ for b in batches:
+ actual.extend(b)
+ self.assertAllEqual(sorted(expected), sorted(actual))
+
+ def testShuffle(self):
+ for batch_size in [1, 2]:
+ for num_epochs in [1, 3]:
+ for num_parallel_reads in [1, 2]:
+ # Test that all expected elements are produced
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads)
+ # Test that elements are produced in a consistent order if
+ # you specify a seed.
+ self._shuffle_test(batch_size, num_epochs, num_parallel_reads,
+ seed=21345)
+
+ def testIndefiniteRepeatShapeInference(self):
+ dataset = readers.make_tf_record_dataset(
+ file_pattern=self.test_filenames, num_epochs=None, batch_size=32)
+ for shape in nest.flatten(dataset.output_shapes):
+ self.assertEqual(32, shape[0])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
new file mode 100644
index 0000000000..d444c4082e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_and_batch_test.py
@@ -0,0 +1,368 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.map_and_batch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("Default", None, None),
+ ("SequentialCalls", 1, None),
+ ("ParallelCalls", 2, None),
+ ("ParallelBatches", None, 10),
+ )
+ def testMapAndBatch(self, num_parallel_calls, num_parallel_batches):
+ """Test a dataset that maps a TF function across its input elements."""
+ # The pipeline is TensorSliceDataset ->
+ # RepeatDataset(count) -> MapAndBatchDataset(square_3, batch_size).
+ components = (np.arange(7),
+ np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
+ np.array(37.0) * np.arange(7))
+
+ count = array_ops.placeholder(dtypes.int64, shape=[])
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ iterator = (
+ dataset_ops.Dataset.from_tensor_slices(components).repeat(count).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ num_parallel_batches=num_parallel_batches))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ self.assertEqual([[None] + list(c.shape[1:]) for c in components],
+ [t.shape.as_list() for t in get_next])
+
+ with self.cached_session() as sess:
+ # Batch of a finite input, where the batch_size divides the
+ # total number of elements.
+ sess.run(init_op, feed_dict={count: 28, batch_size: 14})
+ num_batches = (28 * 7) // 14
+ for i in range(num_batches):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(14):
+ self.assertAllEqual(component[(i * 14 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of a finite input, where the batch_size does not
+ # divide the total number of elements.
+ sess.run(init_op, feed_dict={count: 14, batch_size: 8})
+
+ # We expect (num_batches - 1) full-sized batches.
+ num_batches = int(math.ceil((14 * 7) / 8))
+ for i in range(num_batches - 1):
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range(8):
+ self.assertAllEqual(component[(i * 8 + j) % 7]**2,
+ result_component[j])
+ result = sess.run(get_next)
+ for component, result_component in zip(components, result):
+ for j in range((14 * 7) % 8):
+ self.assertAllEqual(component[((num_batches - 1) * 8 + j) % 7]**2,
+ result_component[j])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Batch of an empty input should fail straight away.
+ sess.run(init_op, feed_dict={count: 0, batch_size: 8})
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Empty batch should be an initialization time error.
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(init_op, feed_dict={count: 14, batch_size: 0})
+
+ @parameterized.named_parameters(
+ ("Even", False),
+ ("Uneven", True),
+ )
+ def testMapAndBatchPartialBatch(self, drop_remainder):
+ iterator = (
+ dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]),
+ batch_size=4,
+ drop_remainder=drop_remainder)).make_one_shot_iterator())
+ if drop_remainder:
+ self.assertEqual([4, 1], iterator.output_shapes.as_list())
+ else:
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ if not drop_remainder:
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchYieldsPartialBatch(self):
+ iterator = (dataset_ops.Dataset.range(10)
+ .apply(batching.map_and_batch(
+ lambda x: array_ops.reshape(x * x, [1]), 4))
+ .make_one_shot_iterator())
+ self.assertEqual([None, 1], iterator.output_shapes.as_list())
+ next_element = iterator.get_next()
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0], [1], [4], [9]], sess.run(next_element))
+ self.assertAllEqual([[16], [25], [36], [49]], sess.run(next_element))
+ self.assertAllEqual([[64], [81]], sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMapAndBatchParallelGetNext(self):
+ iterator = (dataset_ops.Dataset.range(50000)
+ .apply(batching.map_and_batch(lambda x: x, batch_size=100))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(5):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchParallelGetNextDropRemainder(self):
+ iterator = (
+ dataset_ops.Dataset.range(49999).apply(
+ batching.map_and_batch(
+ lambda x: x, batch_size=100, drop_remainder=True))
+ .make_one_shot_iterator())
+ elements = []
+ for _ in range(100):
+ elements.append(iterator.get_next())
+ with self.cached_session() as sess:
+ for i in range(4):
+ got = sess.run(elements)
+ got.sort(key=lambda x: x[0])
+ expected = []
+ for j in range(100):
+ expected.append(range(i*10000+j*100, i*10000+(j+1)*100))
+ self.assertAllEqual(got, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(elements)
+
+ def testMapAndBatchSparse(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ iterator = dataset_ops.Dataset.range(10).apply(
+ batching.map_and_batch(_sparse, 5)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(2):
+ actual = sess.run(get_next)
+ expected = sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
+ values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
+ dense_shape=[5, 1])
+ self.assertTrue(sparse_tensor.is_sparse(actual))
+ self.assertSparseValuesEqual(actual, expected)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testMapAndBatchFails(self):
+ """Test a dataset that maps a TF function across its input elements."""
+ dataset = dataset_ops.Dataset.from_tensors(
+ array_ops.check_numerics(
+ constant_op.constant(1.0) / constant_op.constant(0.0), "oops"))
+ batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(errors.InvalidArgumentError, "oops"):
+ sess.run(init_op, feed_dict={batch_size: 14})
+
+ def testMapAndBatchShapeMismatch(self):
+ """Test a dataset that maps a TF function across its input elements."""
+
+ def generator():
+ yield [1]
+ yield [2]
+ yield [3]
+ yield [[4, 5, 6]]
+
+ dataset = dataset_ops.Dataset.from_generator(
+ generator, output_types=dtypes.int32)
+ batch_size = 4
+ iterator = (
+ dataset.apply(batching.map_and_batch(lambda x: x, batch_size))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ "number of elements does not match"):
+ sess.run(get_next)
+
+ def testMapAndBatchImplicitDispose(self):
+ # Tests whether a map and batch dataset will be cleaned up correctly when
+ # the pipeline does not run it until exhaustion.
+ # The pipeline is TensorSliceDataset -> RepeatDataset(1000) ->
+ # MapAndBatchDataset(f=square_3, batch_size=100).
+ components = (np.arange(1000),
+ np.array([[1, 2, 3]]) * np.arange(1000)[:, np.newaxis],
+ np.array(37.0) * np.arange(1000))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(components).repeat(
+ 1000).apply(batching.map_and_batch(_map_fn, batch_size=100))
+ dataset = dataset.prefetch(5)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(3):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", 0),
+ ("2", 5),
+ ("3", 10),
+ ("4", 90),
+ ("5", 95),
+ ("6", 99),
+ )
+ def testMapAndBatchOutOfRangeError(self, threshold):
+
+ def raising_py_fn(i):
+ if i >= threshold:
+ raise StopIteration()
+ else:
+ return i
+
+ iterator = (
+ dataset_ops.Dataset.range(100).apply(
+ batching.map_and_batch(
+ lambda x: script_ops.py_func(raising_py_fn, [x], dtypes.int64),
+ batch_size=10)).make_one_shot_iterator())
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(threshold // 10):
+ self.assertAllEqual([i * 10 + j for j in range(10)], sess.run(get_next))
+ if threshold % 10 != 0:
+ self.assertAllEqual(
+ [threshold // 10 * 10 + j for j in range(threshold % 10)],
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @parameterized.named_parameters(
+ ("1", False, dtypes.bool),
+ ("2", -42, dtypes.int8),
+ ("3", -42, dtypes.int16),
+ ("4", -42, dtypes.int32),
+ ("5", -42, dtypes.int64),
+ ("6", 42, dtypes.uint8),
+ ("7", 42, dtypes.uint16),
+ ("8", 42.0, dtypes.float16),
+ ("9", 42.0, dtypes.float32),
+ ("10", 42.0, dtypes.float64),
+ ("11", b"hello", dtypes.string),
+ )
+ def testMapAndBatchTypes(self, element, dtype):
+ def gen():
+ yield element
+
+ dataset = dataset_ops.Dataset.from_generator(gen, dtype).repeat(100).apply(
+ batching.map_and_batch(lambda x: x, batch_size=10))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ for _ in range(10):
+ self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Identity", None, lambda x: x, None),
+ ("Replicate", None, lambda x: (x, x), None),
+ ("Swap", (None, None), lambda x, y: (y, x), None),
+ ("Project", (None, None), lambda x, y: x, None),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().apply(
+ batching.map_and_batch(map_fn, batch_size=10))
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(
+ *sess.run(self.structuredElement(structure, shape=[10])))
+ else:
+ expected = map_fn(
+ sess.run(self.structuredElement(structure, shape=[10])))
+ self.assertAllEqual(expected, sess.run(get_next))
+
+ def testShortCircuitCapturedInput(self):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().apply(
+ batching.map_and_batch(lambda x: captured_t, batch_size=10))
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertAllEqual([42] * 10, sess.run(get_next))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
new file mode 100644
index 0000000000..ae9dedb0ab
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -0,0 +1,293 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for MapDefunOp."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import map_defun
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import data_flow_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapDefunTest(test_base.DatasetTestBase):
+
+ def testMapDefunSimple(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(2,)])[0]
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunMismatchedTypes(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return math_ops.cast(x, dtypes.float64)
+
+ nums = [1, 2, 3, 4, 5, 6]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunReduceDim(self):
+ # Tests where the output has a different rank from the input
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return array_ops.gather(x, 0)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])[0]
+ expected = constant_op.constant([1, 3, 5])
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunMultipleOutputs(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return (x, math_ops.cast(x * 2 + 3, dtypes.float64))
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(fn, [elems], [dtypes.int32, dtypes.float64], [(2,),
+ (2,)])
+ expected = [elems, elems * 2 + 3]
+ self.assertAllEqual(self.evaluate(r), self.evaluate(expected))
+
+ def testMapDefunShapeInference(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])[0]
+ self.assertEqual(result.get_shape(), (3, 2))
+
+ def testMapDefunPartialShapeInference(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x
+
+ elems = array_ops.placeholder(dtypes.int64, (None, 2))
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [(2,)])
+ self.assertEqual(result[0].get_shape().as_list(), [None, 2])
+
+ def testMapDefunRaisesErrorOnRuntimeShapeMismatch(self):
+
+ @function.Defun(dtypes.int32, dtypes.int32)
+ def fn(x, y):
+ return x, y
+
+ elems1 = array_ops.placeholder(dtypes.int32)
+ elems2 = array_ops.placeholder(dtypes.int32)
+ result = map_defun.map_defun(fn, [elems1, elems2],
+ [dtypes.int32, dtypes.int32], [(), ()])
+ with self.cached_session() as sess:
+ with self.assertRaisesWithPredicateMatch(
+ errors.InvalidArgumentError,
+ "All inputs must have the same dimension 0."):
+ sess.run(result, feed_dict={elems1: [1, 2, 3, 4, 5], elems2: [1, 2, 3]})
+
+ def testMapDefunRaisesDefunError(self):
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ elems = constant_op.constant([0, 0, 0, 37, 0])
+ result = map_defun.map_defun(fn, [elems], [dtypes.int32], [()])
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(result)
+
+ def testMapDefunCancelledCorrectly(self):
+
+ @function.Defun(dtypes.int64)
+ def defun(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ c = array_ops.tile(
+ array_ops.expand_dims(
+ constant_op.constant([1, 2, 3, 4, 5], dtype=dtypes.int64), 0),
+ [100, 1])
+ map_defun_op = map_defun.map_defun(defun, [c], [dtypes.int64], [()])[0]
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(map_defun_op)
+
+ def testMapDefunWithUnspecifiedOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ res = x * 2 + 3
+ return (res, res + 1, res + 2)
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems],
+ [dtypes.int32, dtypes.int32, dtypes.int32],
+ [None, (None,), (2,)])
+ expected = elems * 2 + 3
+ self.assertAllEqual(self.evaluate(r[0]), self.evaluate(expected))
+ self.assertAllEqual(self.evaluate(r[1]), self.evaluate(expected + 1))
+ self.assertAllEqual(self.evaluate(r[2]), self.evaluate(expected + 2))
+
+ def testMapDefunWithDifferentOutputShapeEachRun(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ elems = array_ops.placeholder(dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ self.assertAllEqual(sess.run(r, feed_dict={elems: [0]}), [3])
+ self.assertAllEqual(
+ sess.run(r, feed_dict={elems: [[0], [1]]}), [[3], [5]])
+
+ def testMapDefunWithWrongOutputShape(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2 + 3
+
+ nums = [[1, 2], [3, 4], [5, 6]]
+ elems = constant_op.constant(nums, dtype=dtypes.int32, name="data")
+ r = map_defun.map_defun(simple_fn, [elems], [dtypes.int32], [(1,)])[0]
+ with self.assertRaises(errors.InvalidArgumentError):
+ self.evaluate(r)
+
+ def testMapDefunWithInvalidInput(self):
+
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ return x * 2
+
+ c = constant_op.constant(2)
+ with self.assertRaises(ValueError):
+ # Fails at graph construction time for inputs with known shapes.
+ r = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [None])[0]
+ p = array_ops.placeholder(dtypes.int32)
+ r = map_defun.map_defun(simple_fn, [p], [dtypes.int32], [None])[0]
+ with session.Session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(r, feed_dict={p: 0})
+
+ def _assert_op_cancelled(self, sess, map_defun_op):
+ with self.assertRaisesRegexp(errors.CancelledError, "was cancelled"):
+ sess.run(map_defun_op)
+
+ def testMapDefunWithParentCancellation(self):
+ # Checks that a cancellation of the parent graph is threaded through to
+ # MapDefunOp correctly.
+ @function.Defun(dtypes.int32)
+ def simple_fn(x):
+ del x
+ queue = data_flow_ops.FIFOQueue(10, dtypes.int32, ())
+ # Blocking
+ return queue.dequeue_many(5)
+
+ c = constant_op.constant([1, 2, 3, 4, 5])
+ map_defun_op = map_defun.map_defun(simple_fn, [c], [dtypes.int32], [()])[0]
+
+ with self.cached_session() as sess:
+ thread = self.checkedThread(
+ self._assert_op_cancelled, args=(sess, map_defun_op))
+ thread.start()
+ time.sleep(0.1)
+ sess.close()
+ thread.join()
+
+ def testMapDefunWithCapturedInputs(self):
+ c = constant_op.constant(2)
+
+ @function.Defun(dtypes.int32)
+ def fn(x):
+ return x + c
+
+ x = constant_op.constant([1, 2, 3, 4])
+ map_defun_op = map_defun.map_defun(fn, [x], [dtypes.int32], [()])[0]
+ expected = x + c
+ self.assertAllEqual(self.evaluate(expected), self.evaluate(map_defun_op))
+
+
+class MapDefunBenchmark(test.Benchmark):
+
+ def _run(self, op, name=None, num_iters=3000):
+ with session.Session() as sess:
+ # Warm up the session
+ for _ in range(5):
+ sess.run(op)
+ start = time.time()
+ for _ in range(num_iters):
+ sess.run(op)
+ end = time.time()
+ mean_us = (end - start) * 1e6 / num_iters
+ self.report_benchmark(
+ name=name,
+ iters=num_iters,
+ wall_time=mean_us,
+ extras={"examples_per_sec": num_iters / (end - start)})
+
+ def benchmarkDefunVsMapFn(self):
+ """Benchmarks to compare the performance of MapDefun vs tf.map_fn."""
+
+ @function.Defun(dtypes.int32)
+ def defun(x):
+ return array_ops.identity(x)
+
+ def map_fn(x):
+ return array_ops.identity(x)
+
+ base = math_ops.range(100)
+ for input_size in [10, 100, 1000, 10000]:
+ num_iters = 100000 // input_size
+ map_defun_op = map_defun.map_defun(defun, [base], [dtypes.int32], [()])
+ map_fn_op = functional_ops.map_fn(map_fn, base)
+
+ self._run(
+ map_defun_op,
+ "benchmarkMapDefun_size_%d" % input_size,
+ num_iters=num_iters)
+ self._run(
+ map_fn_op, "benchmarkMapFn_size_%d" % input_size, num_iters=num_iters)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
new file mode 100644
index 0000000000..c92bb8b9bc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -0,0 +1,207 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "assert_next_dataset_op_test",
+ size = "medium",
+ srcs = ["assert_next_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "hoist_random_uniform_test",
+ size = "small",
+ srcs = ["hoist_random_uniform_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "latency_all_edges_test",
+ size = "small",
+ srcs = ["latency_all_edges_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_vectorization_test",
+ size = "small",
+ srcs = ["map_vectorization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_filter_fusion_test",
+ size = "medium",
+ srcs = ["map_and_filter_fusion_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_parallelization_test",
+ size = "small",
+ srcs = ["map_parallelization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "model_dataset_op_test",
+ size = "medium",
+ srcs = ["model_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "noop_elimination_test",
+ size = "small",
+ srcs = ["noop_elimination_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_op_test",
+ size = "small",
+ srcs = ["optimize_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
new file mode 100644
index 0000000000..45b77b5c20
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -0,0 +1,65 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class AssertNextDatasetTest(test_base.DatasetTestBase):
+
+ def testAssertNext(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertEqual(0, sess.run(get_next))
+
+ def testAssertNextInvalid(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted Whoops transformation at offset 0 but encountered "
+ "Map transformation instead."):
+ sess.run(get_next)
+
+ def testAssertNextShort(self):
+ dataset = dataset_ops.Dataset.from_tensors(0).apply(
+ optimization.assert_next(["Map", "Whoops"])).map(lambda x: x)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaisesRegexp(
+ errors.InvalidArgumentError,
+ "Asserted next 2 transformations but encountered only 1."):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
new file mode 100644
index 0000000000..81437c0aec
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for HostState optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ plus_one = lambda x: x + 1
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=1,
+ maxval=10,
+ dtype=dtypes.float32,
+ seed=42)
+
+ def random_with_assert(x):
+ y = random(x)
+ assert_op = control_flow_ops.Assert(math_ops.greater_equal(y, 1), [y])
+ with ops.control_dependencies([assert_op]):
+ return y
+
+ twice_random = lambda x: (random(x) + random(x)) / 2.
+
+ tests = [("PlusOne", plus_one, False), ("RandomUniform", random, True),
+ ("RandomWithAssert", random_with_assert, True),
+ ("TwiceRandom", twice_random, False)]
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testHoisting(self, function, will_optimize):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(
+ ["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
+
+ options = dataset_ops.Options()
+ options.experimental_hoist_random_uniform = True
+ dataset = dataset.with_options(options)
+ self._testDataset(dataset)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(1, dtype=dtypes.float32)
+ b = constant_op.constant(0, dtype=dtypes.float32)
+ some_tensor = math_ops.mul(a, b)
+
+ def random_with_capture(_):
+ return some_tensor + random_ops.random_uniform(
+ [], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
+
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
+ options = dataset_ops.Options()
+ options.experimental_hoist_random_uniform = True
+ dataset = dataset.with_options(options)
+ self._testDataset(dataset)
+
+ def _testDataset(self, dataset):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ previous_result = 0
+ with self.cached_session() as sess:
+ for _ in range(5):
+ result = sess.run(get_next)
+ self.assertLessEqual(1, result)
+ self.assertLessEqual(result, 10)
+ # This checks if the result is somehow random by checking if we are not
+ # generating the same values.
+ self.assertNotEqual(previous_result, result)
+ previous_result = result
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
new file mode 100644
index 0000000000..26fec0414e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -0,0 +1,59 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the LatencyAllEdges optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testLatencyStatsOptimization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.from_tensors(1).apply(
+ optimization.assert_next(
+ ["LatencyStats", "Map", "LatencyStats", "Prefetch",
+ "LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ options = dataset_ops.Options()
+ options.experimental_latency_all_edges = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ self.assertEqual(1 * 1, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_TensorDataset/_1", 1)
+ self._assertSummaryHasCount(summary_str, "record_latency_MapDataset/_4",
+ 1)
+ self._assertSummaryHasCount(summary_str,
+ "record_latency_PrefetchDataset/_6", 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
new file mode 100644
index 0000000000..7f8a4e6406
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -0,0 +1,232 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndFilterFusion optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ functions = [identity, increment, increment_and_square]
+ tests = []
+ for i, fun1 in enumerate(functions):
+ for j, fun2 in enumerate(functions):
+ tests.append((
+ "Test{}{}".format(i, j),
+ [fun1, fun2],
+ ))
+ for k, fun3 in enumerate(functions):
+ tests.append((
+ "Test{}{}{}".format(i, j, k),
+ [fun1, fun2, fun3],
+ ))
+
+ swap = lambda x, n: (n, x)
+ tests.append((
+ "Swap1",
+ [lambda x: (x, 42), swap],
+ ))
+ tests.append((
+ "Swap2",
+ [lambda x: (x, 42), swap, swap],
+ ))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapFusion(self, functions):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Prefetch"]))
+ for function in functions:
+ dataset = dataset.map(function)
+
+ dataset = dataset.prefetch(0)
+ options = dataset_ops.Options()
+ options.experimental_map_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ r = x
+ for function in functions:
+ if isinstance(r, tuple):
+ r = function(*r) # Pass tuple as multiple arguments.
+ else:
+ r = function(r)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ @staticmethod
+ def map_and_filter_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+ minus_five = lambda x: x - 5
+
+ def increment_and_square(x):
+ y = x + 1
+ return y * y
+
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ is_odd = lambda x: math_ops.equal(x % 2, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ functions = [identity, increment, minus_five, increment_and_square]
+ filters = [take_all, is_zero, is_odd, greater]
+ tests = []
+
+ for x, fun in enumerate(functions):
+ for y, predicate in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), fun, predicate))
+
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ lambda x, y: constant_op.constant(True)))
+ tests.append(
+ ("Multi2", lambda x: (x, 2),
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*map_and_filter_functions.__func__())
+ def testMapFilterFusion(self, function, predicate):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["Map", "FilterByLastComponent"])).map(function).filter(predicate)
+ options = dataset_ops.Options()
+ options.experimental_map_and_filter_fusion = True
+ dataset = dataset.with_options(options)
+ self._testMapAndFilter(dataset, function, predicate)
+
+ def _testMapAndFilter(self, dataset, function, predicate):
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(10):
+ r = function(x)
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if sess.run(b):
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testAdditionalInputs(self):
+ a = constant_op.constant(3, dtype=dtypes.int64)
+ b = constant_op.constant(4, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+ function = lambda x: x * x
+
+ def predicate(y):
+ return math_ops.less(math_ops.cast(y, dtypes.int64), some_tensor)
+
+ # We are currently not supporting functions with additional inputs.
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(["Map",
+ "Filter"])).map(function).filter(predicate)
+ options = dataset_ops.Options()
+ options.experimental_map_and_filter_fusion = True
+ dataset = dataset.with_options(options)
+ self._testMapAndFilter(dataset, function, predicate)
+
+ @staticmethod
+ def filter_functions():
+ take_all = lambda x: constant_op.constant(True)
+ is_zero = lambda x: math_ops.equal(x, 0)
+ greater = lambda x: math_ops.greater(x + 5, 0)
+
+ tests = []
+ filters = [take_all, is_zero, greater]
+ identity = lambda x: x
+ for x, predicate_1 in enumerate(filters):
+ for y, predicate_2 in enumerate(filters):
+ tests.append(("Mixed{}{}".format(x, y), identity,
+ [predicate_1, predicate_2]))
+ for z, predicate_3 in enumerate(filters):
+ tests.append(("Mixed{}{}{}".format(x, y, z), identity,
+ [predicate_1, predicate_2, predicate_3]))
+
+ take_all_multiple = lambda x, y: constant_op.constant(True)
+ # Multi output
+ tests.append(("Multi1", lambda x: (x, x),
+ [take_all_multiple, take_all_multiple]))
+ tests.append(("Multi2", lambda x: (x, 2), [
+ take_all_multiple,
+ lambda x, y: math_ops.equal(x * math_ops.cast(y, dtypes.int64), 0)
+ ]))
+ return tuple(tests)
+
+ @parameterized.named_parameters(*filter_functions.__func__())
+ def testFilterFusion(self, map_function, predicates):
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(["Map", "Filter",
+ "Prefetch"])).map(map_function)
+ for predicate in predicates:
+ dataset = dataset.filter(predicate)
+
+ dataset = dataset.prefetch(0)
+ options = dataset_ops.Options()
+ options.experimental_filter_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ for x in range(5):
+ r = map_function(x)
+ filtered = False
+ for predicate in predicates:
+ if isinstance(r, tuple):
+ b = predicate(*r) # Pass tuple as multiple arguments.
+ else:
+ b = predicate(r)
+ if not sess.run(b):
+ filtered = True
+ break
+
+ if not filtered:
+ result = sess.run(get_next)
+ self.assertAllEqual(r, result)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
new file mode 100644
index 0000000000..ce9c9bc47b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -0,0 +1,87 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @staticmethod
+ def map_functions():
+ identity = lambda x: x
+ increment = lambda x: x + 1
+
+ def assert_greater(x):
+ assert_op = control_flow_ops.Assert(math_ops.greater(x, -1), [x])
+ with ops.control_dependencies([assert_op]):
+ return x
+
+ def random(_):
+ return random_ops.random_uniform([],
+ minval=0,
+ maxval=10,
+ dtype=dtypes.int64,
+ seed=42)
+
+ def assert_with_random(x):
+ x = assert_greater(x)
+ return random(x)
+
+ return (("Identity", identity, True), ("Increment", increment, True),
+ ("AssertGreater", assert_greater, True), ("Random", random, False),
+ ("AssertWithRandom", assert_with_random, False))
+
+ @parameterized.named_parameters(*map_functions.__func__())
+ def testMapParallelization(self, function, should_optimize):
+ next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
+ dataset = dataset_ops.Dataset.range(5).apply(
+ optimization.assert_next(next_nodes)).map(function)
+ options = dataset_ops.Options()
+ options.experimental_map_parallelization = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ # No need to run the pipeline if it was not optimized. Also the results
+ # might be hard to check because of random.
+ if not should_optimize:
+ return
+ r = function(x)
+ self.assertAllEqual(r, result)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
new file mode 100644
index 0000000000..971a2d94b9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -0,0 +1,231 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapVectorization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def _get_test_datasets(self,
+ base_dataset,
+ map_fn,
+ num_parallel_calls=None,
+ expect_optimized=True):
+ """Given base dataset and map fn, creates test datasets.
+
+ Returns a tuple of (unoptimized, dataset, optimized dataset). The
+ unoptimized dataset has the assertion that Batch follows Map. The optimized
+ dataset has the assertion that Map follows Batch, and has the
+ "map_vectorization" optimization applied.
+
+ Args:
+ base_dataset: Input dataset to map->batch
+ map_fn: Map function to use
+ num_parallel_calls: (Optional.) num_parallel_calls argument for map
+ expect_optimized: (Optional.) Whether we expect the optimization to take
+ place, in which case we will assert that Batch is followed by Map,
+ otherwise Map followed by Batch. Defaults to True.
+
+ Returns:
+ Tuple of (unoptimized dataset, optimized dataset).
+ """
+ map_node_name = "Map" if num_parallel_calls is None else "ParallelMap"
+ batch_size = 100
+
+ def _make_dataset(node_names):
+ return base_dataset.apply(optimization.assert_next(node_names)).map(
+ map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
+
+ unoptimized = _make_dataset([map_node_name, "Batch"])
+ optimized = _make_dataset(["Batch", map_node_name]
+ if expect_optimized else [map_node_name, "Batch"])
+ options = dataset_ops.Options()
+ options.experimental_map_vectorization = True
+ optimized = optimized.with_options(options)
+ return unoptimized, optimized
+
+ @parameterized.named_parameters(
+ ("Basic", lambda x: (x, x + 1), None),
+ ("Const", lambda x: 2, 12),
+ ("Parallel", lambda x: (x, x + 1), 12),
+ ("Gather", lambda x: array_ops.gather(x, 0), 12),
+ )
+ def testOptimization(self, map_fn, num_parallel_calls):
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
+ num_parallel_calls)
+ self.assertDatasetsEqual(unoptimized, optimized)
+
+ def testOptimizationBadMapFn(self):
+ # Test map functions that give an error
+ def map_fn(x):
+ # x has leading dimension 5, this will raise an error
+ return array_ops.gather(x, 10)
+
+ base_dataset = dataset_ops.Dataset.range(5).repeat(5).batch(
+ 5, drop_remainder=True)
+ _, optimized = self._get_test_datasets(base_dataset, map_fn)
+ nxt = optimized.make_one_shot_iterator().get_next()
+ with self.assertRaisesRegexp(errors.InvalidArgumentError,
+ r"indices = 10 is not in \[0, 5\)"):
+ self.evaluate(nxt)
+
+ def testOptimizationWithCapturedInputs(self):
+ # Tests that vectorization works with captured inputs
+ def map_fn(x):
+ return x + y
+
+ y = constant_op.constant(1, shape=(2,))
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ # TODO(rachelim): when this optimization works, turn on expect_optimized
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsEqual(optimized, unoptimized)
+
+ def testOptimizationIgnoreStateful(self):
+
+ def map_fn(x):
+ with ops.control_dependencies([check_ops.assert_equal(x, 0)]):
+ return array_ops.identity(x)
+
+ base_dataset = dataset_ops.Dataset.from_tensor_slices([[1, 2],
+ [3, 4]]).repeat(5)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsRaiseSameError(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+ def testOptimizationIgnoreRagged(self):
+ # Make sure we ignore inputs that might not be uniformly sized
+ def map_fn(x):
+ return array_ops.gather(x, 0)
+
+ # output_shape = (?,)
+ base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsEqual(unoptimized, optimized)
+
+ def testOptimizationIgnoreRaggedMap(self):
+ # Don't optimize when the output of the map fn shapes are unknown.
+ def map_fn(x):
+ return array_ops.tile(x, x)
+
+ base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
+ unoptimized, optimized = self._get_test_datasets(
+ base_dataset, map_fn, expect_optimized=False)
+ self.assertDatasetsRaiseSameError(
+ unoptimized, optimized, errors.InvalidArgumentError,
+ [("OneShotIterator", "OneShotIterator_1", 1),
+ ("IteratorGetNext", "IteratorGetNext_1", 1)])
+
+
+class MapVectorizationBenchmark(test.Benchmark):
+ # TODO(rachelim): Add a benchmark for more expensive transformations, such as
+ # vgg_preprocessing.
+
+ def _run(self, x, num_iters=100, name=None):
+ deltas = []
+ with session.Session() as sess:
+ for _ in range(5):
+ # Warm up session...
+ sess.run(x)
+ for _ in range(num_iters):
+ start = time.time()
+ sess.run(x)
+ end = time.time()
+ deltas.append(end - start)
+ median_time = np.median(deltas)
+ self.report_benchmark(iters=num_iters, wall_time=median_time, name=name)
+ return median_time
+
+ def _compare(self, input_dataset, map_fn, batch_size, input_size, str_id):
+ num_elems = np.prod(input_size)
+ name_template = "{}__batch_size_{}_input_size_{}_{}"
+ unoptimized = input_dataset.map(map_fn).batch(batch_size)
+ unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
+
+ optimized = input_dataset.map(map_fn).batch(batch_size)
+ options = dataset_ops.Options()
+ options.experimental_map_vectorization = True
+ optimized = optimized.with_options(options)
+ optimized_op = optimized.make_one_shot_iterator().get_next()
+
+ unoptimized_time = self._run(
+ unoptimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "unoptimized"))
+ optimized_time = self._run(
+ optimized_op,
+ name=name_template.format(str_id, batch_size, num_elems, "optimized"))
+
+ print("Batch size: {}\n"
+ "Input size: {}\n"
+ "Transformation: {}\n"
+ "Speedup: {}\n".format(batch_size, input_size, str_id,
+ (unoptimized_time / optimized_time)))
+
+ # Known cheap functions
+ def benchmarkIdentity(self):
+ self._benchmark_helper(lambda *args: [array_ops.identity(x) for x in args],
+ "identity")
+
+ def benchmarkAddConst(self):
+ self._benchmark_helper(lambda *args: [x + 1 for x in args], "add_const")
+
+ def benchmarkReturnConst(self):
+ self._benchmark_helper(lambda *args: [constant_op.constant(2)], "ret_const")
+
+ def benchmarkSelect(self):
+ self._benchmark_helper(lambda *args: args[0], "select")
+
+ def benchmarkCast(self):
+ self._benchmark_helper(
+ lambda *args: [math_ops.cast(x, dtypes.float64) for x in args], "cast")
+
+ def _benchmark_helper(self, map_fn, str_id):
+ input_sizes = [(10, 10, 3), (10, 100, 300)]
+ batch_size = 1000
+ for input_size in input_sizes:
+ input_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (np.random.rand(*input_size), np.random.rand(*input_size))).repeat()
+ self._compare(input_dataset, map_fn, batch_size, input_size, str_id)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
new file mode 100644
index 0000000000..82516356df
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -0,0 +1,193 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ModelDatasetTest(test_base.DatasetTestBase):
+
+ def testModelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelMap(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(
+ math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelMapAndBatch(self):
+ batch_size = 16
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.apply(
+ batching.map_and_batch(
+ math_ops.matmul,
+ num_parallel_calls=optimization.AUTOTUNE,
+ batch_size=batch_size))
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(10):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelParallelInterleave(self):
+ k = 1024 * 1024
+ dataset = dataset_ops.Dataset.from_tensors((np.random.rand(1, 4 * k),
+ np.random.rand(4 * k,
+ 1))).repeat()
+ dataset = dataset.map(math_ops.matmul)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset,
+ cycle_length=10,
+ num_parallel_calls=optimization.AUTOTUNE)
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next.op)
+ for _ in range(1000):
+ start = time.time()
+ sess.run(get_next.op)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+ def testModelNested(self):
+ k = 1024 * 1024
+ a = (np.random.rand(1, 8 * k), np.random.rand(8 * k, 1))
+ b = (np.random.rand(1, 4 * k), np.random.rand(4 * k, 1))
+ c = (np.random.rand(1, 2 * k), np.random.rand(2 * k, 1))
+ dataset = dataset_ops.Dataset.from_tensors((a, b, c)).repeat()
+
+ def f1(a, b, c):
+ x, y = a
+ return math_ops.matmul(x, y), b, c
+
+ def f2(a, b, c):
+ x, y = b
+ return a, math_ops.matmul(x, y), c
+
+ def f3(a, b, c):
+ x, y = c
+ return a, b, math_ops.matmul(x, y)
+
+ dataset = dataset.map(f1, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f2, num_parallel_calls=optimization.AUTOTUNE)
+ dataset = dataset_ops.Dataset.range(1).repeat().interleave(
+ lambda _: dataset, cycle_length=2)
+
+ dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ deltas = []
+ with self.cached_session() as sess:
+ for _ in range(5):
+ sess.run(get_next)
+ for _ in range(100):
+ start = time.time()
+ sess.run(get_next)
+ end = time.time()
+ deltas.append(end - start)
+
+ print("%f (median), %f (mean), %f (stddev), %f (min), %f (max)\n" %
+ (np.median(deltas), np.mean(deltas), np.std(deltas), np.min(deltas),
+ np.max(deltas)))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
new file mode 100644
index 0000000000..fb0640fe9f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapParallelization optimization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class NoopEliminationTest(test_base.DatasetTestBase):
+
+ def testNoopElimination(self):
+ a = constant_op.constant(1, dtype=dtypes.int64)
+ b = constant_op.constant(2, dtype=dtypes.int64)
+ some_tensor = math_ops.mul(a, b)
+
+ dataset = dataset_ops.Dataset.range(5)
+ dataset = dataset.apply(
+ optimization.assert_next(
+ ["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
+ dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
+ 0).repeat(1).prefetch(0)
+ options = dataset_ops.Options()
+ options.experimental_noop_elimination = True
+ dataset = dataset.with_options(options)
+
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.test_session() as sess:
+ for x in range(5):
+ result = sess.run(get_next)
+ self.assertAllEqual(result, x)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
new file mode 100644
index 0000000000..760cd8cc4e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetTest(test_base.DatasetTestBase):
+
+ def testOptimizationDefault(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(["Map",
+ "Batch"])).map(lambda x: x * x).batch(10)
+ iterator = dataset.with_options(
+ dataset_ops.Options()).make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationFusion(self):
+ dataset = dataset_ops.Dataset.range(10).apply(
+ optimization.assert_next(
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testOptimizationStatefulFunction(self):
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda _: random_ops.random_uniform([])).batch(10)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_one_shot_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensor(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
+ dataset = dataset_ops.Dataset.from_tensors(input_t)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op, {input_t: np.ones([512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+ def testOptimizationLargeInputFromTensorSlices(self):
+ input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op, {input_t: np.ones([1, 512, 1024, 1025], np.int32)})
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
new file mode 100644
index 0000000000..5e419a9b2f
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/override_threadpool_test.py
@@ -0,0 +1,91 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `override_threadpool()` transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class OverrideThreadpoolTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("1", 1, None),
+ ("2", 2, None),
+ ("3", 4, None),
+ ("4", 8, None),
+ ("5", 16, None),
+ ("6", 4, -1),
+ ("7", 4, 0),
+ ("8", 4, 1),
+ ("9", 4, 4),
+ )
+ def testNumThreads(self, num_threads, max_intra_op_parallelism):
+
+ def get_thread_id(_):
+ # Python creates a dummy thread object to represent the current
+ # thread when called from an "alien" thread (such as a
+ # `PrivateThreadPool` thread in this case). It does not include
+ # the TensorFlow-given display name, but it has a unique
+ # identifier that maps one-to-one with the underlying OS thread.
+ return np.array(threading.current_thread().ident).astype(np.int64)
+
+ dataset = (
+ dataset_ops.Dataset.range(1000).map(
+ lambda x: script_ops.py_func(get_thread_id, [x], dtypes.int64),
+ num_parallel_calls=32).apply(unique.unique()))
+
+ dataset = threadpool.override_threadpool(
+ dataset,
+ threadpool.PrivateThreadPool(
+ num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name="private_thread_pool_%d" % num_threads))
+
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ thread_ids = []
+ try:
+ while True:
+ thread_ids.append(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ self.assertEqual(len(thread_ids), len(set(thread_ids)))
+ self.assertGreater(len(thread_ids), 0)
+ # NOTE(mrry): We don't control the thread pool scheduling, and
+ # so cannot guarantee that all of the threads in the pool will
+ # perform work.
+ self.assertLessEqual(len(thread_ids), num_threads)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
new file mode 100644
index 0000000000..90ac250df7
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/parallel_interleave_test.py
@@ -0,0 +1,811 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.parallel_interleave()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+import math
+import threading
+import time
+
+from six.moves import zip_longest
+
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class ParallelInterleaveTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+
+ self.input_values = array_ops.placeholder(dtypes.int64, shape=[None])
+ self.cycle_length = array_ops.placeholder(dtypes.int64, shape=[])
+ self.block_length = array_ops.placeholder(dtypes.int64, shape=[])
+ self.sloppy = array_ops.placeholder(dtypes.bool, shape=[])
+ self.buffer_output_elements = array_ops.placeholder(dtypes.int64, shape=[])
+ self.prefetch_input_elements = array_ops.placeholder(dtypes.int64, shape=[])
+
+ self.error = None
+ self.repeat_count = 2
+
+ # Set up threading events used to sequence when items are produced that
+ # are subsequently interleaved. These events allow us to deterministically
+ # simulate slowdowns and force sloppiness.
+ self.read_coordination_events = {}
+ self.write_coordination_events = {}
+ # input values [4, 5, 6] are the common case for the tests; set defaults
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i] = threading.Event()
+
+ def map_py_fn(x):
+ self.write_coordination_events[x].wait()
+ self.write_coordination_events[x].clear()
+ self.read_coordination_events[x].release()
+ if self.error:
+ err = self.error
+ self.error = None
+ raise err # pylint: disable=raising-bad-type
+ return x * x
+
+ def map_fn(x):
+ return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(x)
+ return dataset.map(map_fn)
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ def _interleave(self, lists, cycle_length, block_length):
+ """Python implementation of interleave used for testing."""
+ num_open = 0
+
+ # `all_iterators` acts as a queue of iterators over each element of `lists`.
+ all_iterators = [iter(l) for l in lists]
+
+ # `open_iterators` are the iterators whose elements are currently being
+ # interleaved.
+ open_iterators = []
+ for i in range(cycle_length):
+ if all_iterators:
+ open_iterators.append(all_iterators.pop(0))
+ num_open += 1
+ else:
+ open_iterators.append(None)
+
+ while num_open or all_iterators:
+ for i in range(cycle_length):
+ if open_iterators[i] is None:
+ if all_iterators:
+ open_iterators[i] = all_iterators.pop(0)
+ num_open += 1
+ else:
+ continue
+ for _ in range(block_length):
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ open_iterators[i] = None
+ num_open -= 1
+ break
+
+ def testPythonImplementation(self):
+ input_lists = [[4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6],
+ [4, 4, 4, 4], [5, 5, 5, 5, 5], [6, 6, 6, 6, 6, 6]]
+
+ # Cycle length 1 acts like `Dataset.flat_map()`.
+ expected_elements = itertools.chain(*input_lists)
+ for expected, produced in zip(expected_elements,
+ self._interleave(input_lists, 1, 1)):
+ self.assertEqual(expected, produced)
+
+ # Cycle length > 1.
+ expected_elements = [
+ 4, 5, 4, 5, 4, 5, 4, 5, 5, 6, 6, 4, 6, 4, 6, 4, 6, 4, 6, 5, 6, 5, 6, 5,
+ 6, 5, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationBlockLength(self):
+ input_lists = [[4] * 4, [5] * 5, [6] * 6] * 2
+ expected_elements = [
+ 4, 4, 5, 5, 4, 4, 5, 5, 5, 6, 6, 4, 4, 6, 6, 4, 4, 6, 6, 5, 5, 6, 6, 5,
+ 5, 6, 6, 5, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 2))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def testPythonImplementationEmptyLists(self):
+ input_lists = [[4, 4, 4, 4], [], [6, 6, 6, 6, 6, 6], [4, 4, 4, 4], [],
+ [6, 6, 6, 6, 6, 6]]
+
+ expected_elements = [
+ 4, 4, 6, 4, 6, 4, 6, 6, 4, 6, 4, 6, 4, 4, 6, 6, 6, 6, 6, 6
+ ]
+ for index, (expected, produced) in enumerate(
+ zip_longest(expected_elements, self._interleave(input_lists, 2, 1))):
+ self.assertEqual(expected, produced, "Values differ at %s. %s != %s" %
+ (index, expected, produced))
+
+ def _clear_coordination_events(self):
+ for i in range(4, 7):
+ self.read_coordination_events[i] = threading.Semaphore(0)
+ self.write_coordination_events[i].clear()
+
+ def _allow_all_map_threads(self):
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+
+ def _testSingleThreaded(self, sloppy=False, prefetch_input_elements=0):
+ # cycle_length=1,block_length=1 acts like `Dataset.interleave()` and
+ # `Dataset.flat_map()` and is single-threaded. No synchronization required.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 1,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: prefetch_input_elements,
+ })
+
+ for expected_element in self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 1):
+ self.write_coordination_events[expected_element].set()
+ self.assertEqual(expected_element * expected_element,
+ sess.run(self.next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testSingleThreaded(self):
+ self._testSingleThreaded()
+
+ def testSingleThreadedSloppy(self):
+ self._testSingleThreaded(sloppy=True)
+
+ def testSingleThreadedPrefetch1Itr(self):
+ self._testSingleThreaded(prefetch_input_elements=1)
+
+ def testSingleThreadedPrefetch1ItrSloppy(self):
+ self._testSingleThreaded(prefetch_input_elements=1, sloppy=True)
+
+ def testSingleThreadedRagged(self):
+ # Tests a sequence with wildly different elements per iterator.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [3, 7, 4],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+
+ # Add coordination values for 3 and 7
+ self.read_coordination_events[3] = threading.Semaphore(0)
+ self.write_coordination_events[3] = threading.Event()
+ self.read_coordination_events[7] = threading.Semaphore(0)
+ self.write_coordination_events[7] = threading.Event()
+
+ for expected_element in self._interleave(
+ [[3] * 3, [7] * 7, [4] * 4] * self.repeat_count, 2, 1):
+ self.write_coordination_events[expected_element].set()
+ output = sess.run(self.next_element)
+ self.assertEqual(expected_element * expected_element, output)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def _testTwoThreadsNoContention(self, sloppy=False):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContention(self):
+ self._testTwoThreadsNoContention()
+
+ def testTwoThreadsNoContentionSloppy(self):
+ self._testTwoThreadsNoContention(sloppy=True)
+
+ def _testTwoThreadsNoContentionWithRaces(self, sloppy=False):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the previous test which carefully sequences
+ the execution of the map functions.
+
+ Args:
+ sloppy: Whether to be sloppy or not.
+ """
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRaces(self):
+ self._testTwoThreadsNoContentionWithRaces()
+
+ def testTwoThreadsNoContentionWithRacesSloppy(self):
+ self._testTwoThreadsNoContentionWithRaces(sloppy=True)
+
+ def _testTwoThreadsNoContentionBlockLength(self, sloppy=False):
+ # num_threads > 1.
+ # Explicit coordination should result in `Dataset.interleave()` behavior
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionBlockLength(self):
+ self._testTwoThreadsNoContentionBlockLength()
+
+ def testTwoThreadsNoContentionBlockLengthSloppy(self):
+ self._testTwoThreadsNoContentionBlockLength(sloppy=True)
+
+ def _testTwoThreadsNoContentionWithRacesAndBlocking(self, sloppy=False):
+ """Tests where all the workers race in producing elements.
+
+ Note: this is in contrast with the previous test which carefully sequences
+ the execution of the map functions.
+
+
+ Args:
+ sloppy: Whether to be sloppy or not.
+ """
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 2)):
+ if done_first_event: # First event starts the worker threads.
+ self._allow_all_map_threads()
+ self.read_coordination_events[expected_element].acquire()
+ else:
+ self.write_coordination_events[expected_element].set()
+ time.sleep(0.5) # Sleep to consistently "avoid" the race condition.
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ done_first_event = True
+ self.assertTrue(
+ self.read_coordination_events[expected_element].acquire(False))
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testTwoThreadsNoContentionWithRacesAndBlocking(self):
+ self._testTwoThreadsNoContentionWithRacesAndBlocking()
+
+ def testTwoThreadsNoContentionWithRacesAndBlockingSloppy(self):
+ self._testTwoThreadsNoContentionWithRacesAndBlocking(sloppy=True)
+
+ def _testEmptyInput(self, sloppy=False):
+ with self.cached_session() as sess:
+ # Empty input.
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [],
+ self.cycle_length: 2,
+ self.block_length: 3,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testEmptyInput(self):
+ self._testEmptyInput()
+
+ def testEmptyInputSloppy(self):
+ self._testEmptyInput(sloppy=True)
+
+ def _testNonEmptyInputIntoEmptyOutputs(self, sloppy=False):
+ # Non-empty input leading to empty output.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [0, 0, 0],
+ self.cycle_length: 2,
+ self.block_length: 3,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testNonEmptyInputIntoEmptyOutputs(self):
+ self._testNonEmptyInputIntoEmptyOutputs()
+
+ def testNonEmptyInputIntoEmptyOutputsSloppy(self):
+ self._testNonEmptyInputIntoEmptyOutputs(sloppy=True)
+
+ def _testPartiallyEmptyOutputs(self, sloppy=False, prefetch_input_elements=1):
+ race_indices = {2, 8, 14} # Sequence points when sloppy mode has race conds
+ # Mixture of non-empty and empty interleaved datasets.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 0, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: prefetch_input_elements,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [], [6] * 6] * self.repeat_count, 2, 1)):
+ self.write_coordination_events[expected_element].set()
+ # First event starts the worker threads. Additionally, when running the
+ # sloppy case with prefetch_input_elements=0, we get stuck if we wait
+ # for the read coordination event for certain event orderings in the
+ # presence of finishing iterators.
+ if done_first_event and not (sloppy and (i in race_indices)):
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event or (sloppy and (i in race_indices)):
+ done_first_event = True
+ self.read_coordination_events[expected_element].acquire()
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+
+ def testPartiallyEmptyOutputs(self):
+ self._testPartiallyEmptyOutputs()
+
+ def testPartiallyEmptyOutputsSloppy(self):
+ self._testPartiallyEmptyOutputs(sloppy=True, prefetch_input_elements=0)
+
+ def testDelayedOutputSloppy(self):
+ # Explicitly control the sequence of events to ensure we correctly avoid
+ # head-of-line blocking.
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: True,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+
+ mis_ordering = [
+ 4, 4, 5, 4, 5, 5, 4, 5, 6, 6, 6, 5, 4, 4, 6, 6, 4, 4, 6, 5, 6, 6, 6,
+ 6, 5, 5, 5, 5, 6, 6
+ ]
+ for element in mis_ordering:
+ self.write_coordination_events[element].set()
+ self.assertEqual(element * element, sess.run(self.next_element))
+ self.assertTrue(self.read_coordination_events[element].acquire(False))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testBlockLengthWithContentionSloppy(self):
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ done_first_event = False
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: True,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 1,
+ })
+ # Test against a generating sequence that differs from the uncontended
+ # case, in order to prove sloppy correctness.
+ for i, expected_element in enumerate(
+ self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count,
+ cycle_length=2,
+ block_length=3)):
+ self.write_coordination_events[expected_element].set()
+ if done_first_event: # First event starts the worker threads.
+ self.read_coordination_events[expected_element].acquire()
+ actual_element = sess.run(self.next_element)
+ if not done_first_event:
+ self.read_coordination_events[expected_element].acquire()
+ done_first_event = True
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def _testEarlyExit(self, sloppy=False):
+ # Exiting without consuming all input should not block
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 3,
+ self.block_length: 2,
+ self.sloppy: sloppy,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i in range(4, 7):
+ self.write_coordination_events[i].set()
+ elem = sess.run(self.next_element) # Start all workers
+ # Allow the one successful worker to progress beyond the py_func again.
+ elem = int(math.sqrt(elem))
+ self.write_coordination_events[elem].set()
+ self.read_coordination_events[elem].acquire()
+ # Allow the prefetch to succeed
+ for i in range(4, 7):
+ self.read_coordination_events[i].acquire()
+ self.write_coordination_events[i].set()
+
+ def testEarlyExit(self):
+ self._testEarlyExit()
+
+ def testEarlyExitSloppy(self):
+ self._testEarlyExit(sloppy=True)
+
+ def _testTooManyReaders(self, sloppy=False):
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(math_ops.cast(x, dtype=dtypes.int64))
+ return dataset
+
+ dataset = dataset_ops.Dataset.from_tensor_slices([4, 5, 6])
+ dataset = dataset.repeat(self.repeat_count)
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ interleave_fn, cycle_length=16, block_length=2, sloppy=sloppy))
+ iterator = dataset.make_one_shot_iterator()
+
+ with self.cached_session() as sess:
+ output_values = []
+ for _ in range(30):
+ output_values.append(sess.run(iterator.get_next()))
+
+ expected_values = self._interleave(
+ [[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 1, 2)
+ self.assertItemsEqual(output_values, expected_values)
+
+ def testTooManyReaders(self):
+ self._testTooManyReaders()
+
+ def testTooManyReadersSloppy(self):
+ self._testTooManyReaders(sloppy=True)
+
+ def testSparse(self):
+ def _map_fn(i):
+ return sparse_tensor.SparseTensor(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ dataset = dataset_ops.Dataset.range(10).map(_map_fn)
+ iterator = dataset.apply(
+ interleave_ops.parallel_interleave(
+ _interleave_fn, cycle_length=1)).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for i in range(10):
+ for j in range(2):
+ expected = [i, 0] if j % 2 == 0 else [0, -i]
+ self.assertAllEqual(expected, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def testErrorsInOutputFn(self):
+ with self.cached_session() as sess:
+ self._clear_coordination_events()
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+
+ except_on_element_indices = set([3])
+
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5] * 5, [6] * 6] * self.repeat_count, 2,
+ 1)):
+ if i in except_on_element_indices:
+ self.error = ValueError()
+ self.write_coordination_events[expected_element].set()
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ self.write_coordination_events[expected_element].set()
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element * expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testErrorsInInputFn(self):
+
+ def map_py_fn(x):
+ if x == 5:
+ raise ValueError()
+ return x
+
+ def map_fn(x):
+ return script_ops.py_func(map_py_fn, [x], x.dtype)
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ dataset = dataset.repeat(x)
+ return dataset
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values).map(map_fn)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
+ if expected_element == 5:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testErrorsInInterleaveFn(self):
+
+ def map_py_fn(x):
+ if x == 5:
+ raise ValueError()
+ return x
+
+ def interleave_fn(x):
+ dataset = dataset_ops.Dataset.from_tensors(x)
+ y = script_ops.py_func(map_py_fn, [x], x.dtype)
+ dataset = dataset.repeat(y)
+ return dataset
+
+ self.dataset = (
+ dataset_ops.Dataset.from_tensor_slices(self.input_values)
+ .repeat(self.repeat_count).apply(
+ interleave_ops.parallel_interleave(interleave_fn, self.cycle_length,
+ self.block_length, self.sloppy,
+ self.buffer_output_elements,
+ self.prefetch_input_elements)))
+
+ self.iterator = self.dataset.make_initializable_iterator()
+ self.init_op = self.iterator.initializer
+ self.next_element = self.iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(
+ self.init_op,
+ feed_dict={
+ self.input_values: [4, 5, 6],
+ self.cycle_length: 2,
+ self.block_length: 1,
+ self.sloppy: False,
+ self.buffer_output_elements: 1,
+ self.prefetch_input_elements: 0,
+ })
+ for i, expected_element in enumerate(
+ self._interleave([[4] * 4, [5], [6] * 6] * self.repeat_count, 2, 1)):
+ if expected_element == 5:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(self.next_element)
+ else:
+ actual_element = sess.run(self.next_element)
+ self.assertEqual(expected_element, actual_element,
+ "At index %s: %s expected, got: %s" %
+ (i, expected_element, actual_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(self.next_element)
+
+ def testShutdownRace(self):
+ dataset = dataset_ops.Dataset.range(20)
+ map_fn = lambda x: dataset_ops.Dataset.range(20 * x, 20 * (x + 1))
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ map_fn,
+ cycle_length=3,
+ sloppy=False,
+ buffer_output_elements=1,
+ prefetch_input_elements=0))
+ dataset = dataset.batch(32)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ results = []
+ with self.cached_session() as sess:
+ for _ in range(2):
+ elements = []
+ sess.run(iterator.initializer)
+ try:
+ while True:
+ elements.extend(sess.run(next_element))
+ except errors.OutOfRangeError:
+ pass
+ results.append(elements)
+
+ self.assertAllEqual(results[0], results[1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
new file mode 100644
index 0000000000..723e709ae8
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/parse_example_dataset_test.py
@@ -0,0 +1,850 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.parse_example_dataset()."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+
+import numpy as np
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.platform import test
+from tensorflow.python.platform import tf_logging
+
+# Helpers for creating Example objects
+example = example_pb2.Example
+feature = feature_pb2.Feature
+features = lambda d: feature_pb2.Features(feature=d)
+bytes_feature = lambda v: feature(bytes_list=feature_pb2.BytesList(value=v))
+int64_feature = lambda v: feature(int64_list=feature_pb2.Int64List(value=v))
+float_feature = lambda v: feature(float_list=feature_pb2.FloatList(value=v))
+# Helpers for creating SequenceExample objects
+feature_list = lambda l: feature_pb2.FeatureList(feature=l)
+feature_lists = lambda d: feature_pb2.FeatureLists(feature_list=d)
+sequence_example = example_pb2.SequenceExample
+
+
+def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
+ flat_output):
+ tester.assertEqual(set(dict_tensors.keys()), set(expected_tensors.keys()))
+
+ i = 0 # Index into the flattened output of session.run()
+ for k, v in sorted(dict_tensors.items()):
+ # TODO(shivaniagrawal): flat_output is same as v.
+ expected_v = expected_tensors[k]
+ tf_logging.info("Comparing key: %s", k)
+ print("i", i, "flat_output", flat_output[i], "expected_v", expected_v)
+ if sparse_tensor.is_sparse(v):
+ # Three outputs for SparseTensor : indices, values, shape.
+ tester.assertEqual([k, len(expected_v)], [k, 3])
+ print("i", i, "flat_output", flat_output[i].indices, "expected_v",
+ expected_v[0])
+ tester.assertAllEqual(expected_v[0], flat_output[i].indices)
+ tester.assertAllEqual(expected_v[1], flat_output[i].values)
+ tester.assertAllEqual(expected_v[2], flat_output[i].dense_shape)
+ else:
+ # One output for standard Tensor.
+ tester.assertAllEqual(expected_v, flat_output[i])
+ i += 1
+
+
+class ParseExampleDatasetTest(test_base.DatasetTestBase):
+
+ def _test(self,
+ input_tensor,
+ feature_val,
+ expected_values=None,
+ expected_err=None):
+
+ with self.cached_session() as sess:
+ if expected_err:
+ with self.assertRaisesWithPredicateMatch(expected_err[0],
+ expected_err[1]):
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ sess.run(get_next)
+ return
+ else:
+ # Returns dict w/ Tensors and SparseTensors.
+ # Check values.
+ dataset = dataset_ops.Dataset.from_tensors(input_tensor).apply(
+ contrib_parsing_ops.parse_example_dataset(feature_val))
+ get_next = dataset.make_one_shot_iterator().get_next()
+ result = sess.run(get_next)
+ flattened = nest.flatten(result)
+ print("result", result, "expected_values", expected_values)
+ _compare_output_to_expected(self, result, expected_values, flattened)
+
+ # Check shapes; if serialized is a Tensor we need its size to
+ # properly check.
+ batch_size = (
+ input_tensor.eval().size if isinstance(input_tensor, ops.Tensor) else
+ np.asarray(input_tensor).size)
+ for k, f in feature_val.items():
+ print("output_shapes as list ",
+ tuple(dataset.output_shapes[k].as_list()))
+ if isinstance(f, parsing_ops.FixedLenFeature) and f.shape is not None:
+ self.assertEqual(dataset.output_shapes[k].as_list()[0], batch_size)
+ elif isinstance(f, parsing_ops.VarLenFeature):
+ self.assertEqual(dataset.output_shapes[k].as_list()[1], None)
+
+ def testEmptySerializedWithAllDefaults(self):
+ sparse_name = "st_a"
+ a_name = "a"
+ b_name = "b"
+ c_name = "c:has_a_tricky_name"
+ a_default = [0, 42, 0]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ c_default = np.random.rand(2).astype(np.float32)
+
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+
+ expected_output = {
+ sparse_name: expected_st_a,
+ a_name: np.array(2 * [[a_default]]),
+ b_name: np.array(2 * [b_default]),
+ c_name: np.array(2 * [c_default]),
+ }
+
+ self._test(
+ ops.convert_to_tensor(["", ""]), {
+ sparse_name:
+ parsing_ops.VarLenFeature(dtypes.int64),
+ a_name:
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ b_name:
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ c_name:
+ parsing_ops.FixedLenFeature(
+ (2,), dtypes.float32, default_value=c_default),
+ },
+ expected_values=expected_output)
+
+ def testEmptySerializedWithoutDefaultsShouldFail(self):
+ input_features = {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=[0, 42, 0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3),
+ dtypes.string,
+ default_value=np.random.rand(3, 3).astype(bytes)),
+ # Feature "c" is missing a default, this gap will cause failure.
+ "c":
+ parsing_ops.FixedLenFeature(
+ (2,), dtype=dtypes.float32),
+ }
+
+ # Edge case where the key is there but the feature value is empty
+ original = example(features=features({"c": feature()}))
+ self._test(
+ [original.SerializeToString()],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ # Standard case of missing key and value.
+ self._test(
+ ["", ""],
+ input_features,
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Feature: c \\(data type: float\\) is required"))
+
+ def testDenseNotMatchingShapeShouldFail(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1, 3]),
+ })), example(features=features({
+ "a": float_feature([-1, -1]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature((1, 3), dtypes.float32)},
+ expected_err=(errors_impl.InvalidArgumentError,
+ "Key: a, Index: 1. Number of float values"))
+
+ def testDenseDefaultNoShapeShouldFail(self):
+ original = [example(features=features({"a": float_feature([1, 1, 3]),})),]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"a": parsing_ops.FixedLenFeature(None, dtypes.float32)},
+ expected_err=(ValueError, "Missing shape for feature a"))
+
+ def testSerializedContainingSparse(self):
+ original = [
+ example(features=features({
+ "st_c": float_feature([3, 4])
+ })),
+ example(features=features({
+ "st_c": float_feature([]), # empty float list
+ })),
+ example(features=features({
+ "st_d": feature(), # feature with nothing in it
+ })),
+ example(features=features({
+ "st_c": float_feature([1, 2, -1]),
+ "st_d": bytes_feature([b"hi"])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_st_c = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [3, 0], [3, 1], [3, 2]], dtype=np.int64), np.array(
+ [3.0, 4.0, 1.0, 2.0, -1.0], dtype=np.float32), np.array(
+ [4, 3], dtype=np.int64)) # batch == 2, max_elems = 3
+
+ expected_st_d = ( # indices, values, shape
+ np.array(
+ [[3, 0]], dtype=np.int64), np.array(
+ ["hi"], dtype=bytes), np.array(
+ [4, 1], dtype=np.int64)) # batch == 2, max_elems = 1
+
+ expected_output = {
+ "st_c": expected_st_c,
+ "st_d": expected_st_d,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "st_c": parsing_ops.VarLenFeature(dtypes.float32),
+ "st_d": parsing_ops.VarLenFeature(dtypes.string)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx":
+ int64_feature([0, 9, 3]) # unsorted
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10], [3, 0], [3, 3], [3, 9]], dtype=np.int64),
+ np.array(
+ [3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32), np.array(
+ [4, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {"sp": parsing_ops.SparseFeature(["idx"], "val", dtypes.float32, [13])},
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseFeatureReuse(self):
+ original = [
+ example(features=features({
+ "val1": float_feature([3, 4]),
+ "val2": float_feature([5, 6]),
+ "idx": int64_feature([5, 10])
+ })),
+ example(features=features({
+ "val1": float_feature([]), # empty float list
+ "idx": int64_feature([])
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp1 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [3.0, 4.0], dtype=np.float32), np.array(
+ [2, 13], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_sp2 = ( # indices, values, shape
+ np.array(
+ [[0, 5], [0, 10]], dtype=np.int64), np.array(
+ [5.0, 6.0], dtype=np.float32), np.array(
+ [2, 7], dtype=np.int64)) # batch == 2, max_elems = 13
+
+ expected_output = {
+ "sp1": expected_sp1,
+ "sp2": expected_sp2,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp1":
+ parsing_ops.SparseFeature("idx", "val1", dtypes.float32, 13),
+ "sp2":
+ parsing_ops.SparseFeature(
+ "idx", "val2", dtypes.float32, size=7, already_sorted=True)
+ },
+ expected_values=expected_output)
+
+ def testSerializedContaining3DSparseFeature(self):
+ original = [
+ example(features=features({
+ "val": float_feature([3, 4]),
+ "idx0": int64_feature([5, 10]),
+ "idx1": int64_feature([0, 2]),
+ })),
+ example(features=features({
+ "val": float_feature([]), # empty float list
+ "idx0": int64_feature([]),
+ "idx1": int64_feature([]),
+ })),
+ example(features=features({
+ "val": feature(), # feature with nothing in it
+ # missing idx feature
+ })),
+ example(features=features({
+ "val": float_feature([1, 2, -1]),
+ "idx0": int64_feature([0, 9, 3]), # unsorted
+ "idx1": int64_feature([1, 0, 2]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_sp = (
+ # indices
+ np.array(
+ [[0, 5, 0], [0, 10, 2], [3, 0, 1], [3, 3, 2], [3, 9, 0]],
+ dtype=np.int64),
+ # values
+ np.array([3.0, 4.0, 1.0, -1.0, 2.0], dtype=np.float32),
+ # shape batch == 4, max_elems = 13
+ np.array([4, 13, 3], dtype=np.int64))
+
+ expected_output = {"sp": expected_sp,}
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "sp":
+ parsing_ops.SparseFeature(["idx0", "idx1"], "val",
+ dtypes.float32, [13, 3])
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDense(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ original = [
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ })), example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b""]),
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", ""], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ # This test is identical as the previous one except
+ # for the creation of 'serialized'.
+ def testSerializedContainingDenseWithConcat(self):
+ aname = "a"
+ bname = "b*has+a:tricky_name"
+ # TODO(lew): Feature appearing twice should be an error in future.
+ original = [
+ (example(features=features({
+ aname: float_feature([10, 10]),
+ })), example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str"]),
+ }))),
+ (
+ example(features=features({
+ bname: bytes_feature([b"b100"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1]),
+ bname: bytes_feature([b"b1"]),
+ })),),
+ ]
+
+ serialized = [
+ m.SerializeToString() + n.SerializeToString() for (m, n) in original
+ ]
+
+ expected_output = {
+ aname:
+ np.array(
+ [[1, 1], [-1, -1]], dtype=np.float32).reshape(2, 1, 2, 1),
+ bname:
+ np.array(
+ ["b0_str", "b1"], dtype=bytes).reshape(2, 1, 1, 1, 1),
+ }
+
+ # No defaults, values required
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((1, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenFeature((1, 1, 1, 1), dtype=dtypes.string),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseScalar(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1]),
+ })), example(features=features({}))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1], [-1]], dtype=np.float32) # 2x1 (column vector)
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1,), dtype=dtypes.float32, default_value=-1),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingDenseWithDefaults(self):
+ original = [
+ example(features=features({
+ "a": float_feature([1, 1]),
+ })),
+ example(features=features({
+ "b": bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ "b": feature()
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "a":
+ np.array(
+ [[1, 1], [3, -3], [3, -3]], dtype=np.float32).reshape(3, 1, 2,
+ 1),
+ "b":
+ np.array(
+ ["tmp_str", "b1", "tmp_str"], dtype=bytes).reshape(3, 1, 1, 1,
+ 1),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 2, 1), dtype=dtypes.float32, default_value=[3.0, -3.0]),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (1, 1, 1, 1), dtype=dtypes.string, default_value="tmp_str"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureAndDenseWithNoDefault(self):
+ expected_st_a = ( # indices, values, shape
+ np.empty(
+ (0, 2), dtype=np.int64), # indices
+ np.empty(
+ (0,), dtype=np.int64), # sp_a is DT_INT64
+ np.array(
+ [2, 0], dtype=np.int64)) # batch == 2, max_elems = 0
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "c"], dtype="|S"), np.array(
+ [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ original = [
+ example(features=features({
+ "c": float_feature([3, 4]),
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "c": float_feature([1, 2]),
+ "val": bytes_feature([b"c"]),
+ "idx": int64_feature([7])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ a_default = [1, 2, 3]
+ b_default = np.random.rand(3, 3).astype(bytes)
+ expected_output = {
+ "st_a": expected_st_a,
+ "sp": expected_sp,
+ "a": np.array(2 * [[a_default]]),
+ "b": np.array(2 * [b_default]),
+ "c": np.array(
+ [[3, 4], [1, 2]], dtype=np.float32),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized),
+ {
+ "st_a":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature("idx", "val", dtypes.string, 13),
+ "a":
+ parsing_ops.FixedLenFeature(
+ (1, 3), dtypes.int64, default_value=a_default),
+ "b":
+ parsing_ops.FixedLenFeature(
+ (3, 3), dtypes.string, default_value=b_default),
+ # Feature "c" must be provided, since it has no default_value.
+ "c":
+ parsing_ops.FixedLenFeature((2,), dtypes.float32),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingSparseAndSparseFeatureWithReuse(self):
+ expected_idx = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 1], [1, 0], [1, 1]], dtype=np.int64),
+ np.array([0, 3, 7, 1]), np.array(
+ [2, 2], dtype=np.int64)) # batch == 4, max_elems = 2
+
+ expected_sp = ( # indices, values, shape
+ np.array(
+ [[0, 0], [0, 3], [1, 1], [1, 7]], dtype=np.int64), np.array(
+ ["a", "b", "d", "c"], dtype="|S"), np.array(
+ [2, 13], dtype=np.int64)) # batch == 4, max_elems = 13
+
+ original = [
+ example(features=features({
+ "val": bytes_feature([b"a", b"b"]),
+ "idx": int64_feature([0, 3])
+ })), example(features=features({
+ "val": bytes_feature([b"c", b"d"]),
+ "idx": int64_feature([7, 1])
+ }))
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ "idx": expected_idx,
+ "sp": expected_sp,
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ "idx":
+ parsing_ops.VarLenFeature(dtypes.int64),
+ "sp":
+ parsing_ops.SparseFeature(["idx"], "val", dtypes.string, [13]),
+ },
+ expected_values=expected_output)
+
+ def _testSerializedContainingVarLenDenseLargerBatch(self, batch_size):
+ # During parsing, data read from the serialized proto is stored in buffers.
+ # For small batch sizes, a buffer will contain one minibatch entry.
+ # For larger batch sizes, a buffer may contain several minibatch
+ # entries. This test identified a bug where the code that copied
+ # data out of the buffers and into the output tensors assumed each
+ # buffer only contained one minibatch entry. The bug has since been fixed.
+ truth_int = [i for i in range(batch_size)]
+ truth_str = [[("foo%d" % i).encode(), ("bar%d" % i).encode()]
+ for i in range(batch_size)]
+
+ expected_str = copy.deepcopy(truth_str)
+
+ # Delete some intermediate entries
+ for i in range(batch_size):
+ col = 1
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry
+ expected_str[i][col] = b"default"
+ col -= 1
+ truth_str[i].pop()
+ if np.random.rand() < 0.25:
+ # w.p. 25%, drop out the second entry (possibly again)
+ expected_str[i][col] = b"default"
+ truth_str[i].pop()
+
+ expected_output = {
+ # Batch size batch_size, 1 time step.
+ "a": np.array(truth_int, dtype=np.int64).reshape(batch_size, 1),
+ # Batch size batch_size, 2 time steps.
+ "b": np.array(expected_str, dtype="|S").reshape(batch_size, 2),
+ }
+
+ original = [
+ example(features=features(
+ {"a": int64_feature([truth_int[i]]),
+ "b": bytes_feature(truth_str[i])}))
+ for i in range(batch_size)
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ self._test(
+ ops.convert_to_tensor(serialized, dtype=dtypes.string), {
+ "a":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=(),
+ dtype=dtypes.int64,
+ allow_missing=True,
+ default_value=-1),
+ "b":
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[],
+ dtype=dtypes.string,
+ allow_missing=True,
+ default_value="default"),
+ },
+ expected_values=expected_output)
+
+ def testSerializedContainingVarLenDenseLargerBatch(self):
+ np.random.seed(3456)
+ for batch_size in (1, 10, 20, 100, 256):
+ self._testSerializedContainingVarLenDenseLargerBatch(batch_size)
+
+ def testSerializedContainingVarLenDense(self):
+ aname = "a"
+ bname = "b"
+ cname = "c"
+ dname = "d"
+ original = [
+ example(features=features({
+ cname: int64_feature([2]),
+ })),
+ example(features=features({
+ aname: float_feature([1, 1]),
+ bname: bytes_feature([b"b0_str", b"b1_str"]),
+ })),
+ example(features=features({
+ aname: float_feature([-1, -1, 2, 2]),
+ bname: bytes_feature([b"b1"]),
+ })),
+ example(features=features({
+ aname: float_feature([]),
+ cname: int64_feature([3]),
+ })),
+ ]
+
+ serialized = [m.SerializeToString() for m in original]
+
+ expected_output = {
+ aname:
+ np.array(
+ [
+ [0, 0, 0, 0],
+ [1, 1, 0, 0],
+ [-1, -1, 2, 2],
+ [0, 0, 0, 0],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1),
+ bname:
+ np.array(
+ [["", ""], ["b0_str", "b1_str"], ["b1", ""], ["", ""]],
+ dtype=bytes).reshape(4, 2, 1, 1, 1),
+ cname:
+ np.array([2, 0, 0, 3], dtype=np.int64).reshape(4, 1),
+ dname:
+ np.empty(shape=(4, 0), dtype=bytes),
+ }
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_values=expected_output)
+
+ # Test with padding values.
+ expected_output_custom_padding = dict(expected_output)
+ expected_output_custom_padding[aname] = np.array(
+ [
+ [-2, -2, -2, -2],
+ [1, 1, -2, -2],
+ [-1, -1, 2, 2],
+ [-2, -2, -2, -2],
+ ],
+ dtype=np.float32).reshape(4, 2, 2, 1)
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=-2.0),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=True),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ }, expected_output_custom_padding)
+
+ # Change number of required values so the inputs are not a
+ # multiple of this size.
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(
+ errors_impl.OpError, "Key: b, Index: 2. "
+ "Number of bytes values is not a multiple of stride length."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1),
+ dtype=dtypes.float32,
+ allow_missing=True,
+ default_value=[]),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Cannot reshape a tensor with 0 elements to shape"))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenFeature((None, 2, 1), dtype=dtypes.float32),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1, 1), dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "First dimension of shape for feature a unknown. "
+ "Consider using FixedLenSequenceFeature."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ cname:
+ parsing_ops.FixedLenFeature(
+ (1, None), dtype=dtypes.int64, default_value=[[1]]),
+ },
+ expected_err=(ValueError,
+ "All dimensions of shape for feature c need to be known "
+ r"but received \(1, None\)."))
+
+ self._test(
+ ops.convert_to_tensor(serialized), {
+ aname:
+ parsing_ops.FixedLenSequenceFeature(
+ (2, 1), dtype=dtypes.float32, allow_missing=True),
+ bname:
+ parsing_ops.FixedLenSequenceFeature(
+ (1, 1, 1), dtype=dtypes.string, allow_missing=True),
+ cname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.int64, allow_missing=False),
+ dname:
+ parsing_ops.FixedLenSequenceFeature(
+ shape=[], dtype=dtypes.string, allow_missing=True),
+ },
+ expected_err=(ValueError,
+ "Unsupported: FixedLenSequenceFeature requires "
+ "allow_missing to be True."))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
new file mode 100644
index 0000000000..f73725366c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py
@@ -0,0 +1,234 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.prefetch_to_device()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
+
+ def testPrefetchToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToSameDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device(
+ "/job:localhost/replica:0/task:0/device:CPU:0"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchDictToDevice(self):
+ host_dataset = dataset_ops.Dataset.range(10).map(lambda x: {"a": x})
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element["a"].dtype)
+ self.assertEqual([], next_element["a"].shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ self.assertEqual({"a": i}, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchSparseTensorsToDevice(self):
+ def make_tensor(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0]], values=(i*[1]), dense_shape=[2, 2])
+ host_dataset = dataset_ops.Dataset.range(10).map(make_tensor)
+
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_one_shot_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ for i in range(10):
+ actual = sess.run(next_element)
+ self.assertAllEqual([i], actual.values)
+ self.assertAllEqual([[0, 0]], actual.indices)
+ self.assertAllEqual([2, 2], actual.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpu(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceWithReInit(self):
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/cpu:1"))
+
+ # NOTE(mrry): This device block creates the "host" dataset and iterator on
+ # /cpu:0, and ensures that the prefetching is across devices. In typical use
+ # this would not be necessary, because the GPU device would not support any
+ # of the dataset-related ops.
+ with ops.device("/cpu:0"):
+ iterator = device_dataset.make_initializable_iterator()
+
+ self.assertEqual(host_dataset.output_types, device_dataset.output_types)
+ self.assertEqual(host_dataset.output_types, iterator.output_types)
+ self.assertEqual(host_dataset.output_shapes, device_dataset.output_shapes)
+ self.assertEqual(host_dataset.output_shapes, iterator.output_shapes)
+ self.assertEqual(host_dataset.output_classes, device_dataset.output_classes)
+ self.assertEqual(host_dataset.output_classes, iterator.output_classes)
+
+ next_element = iterator.get_next()
+ self.assertEqual(dtypes.int64, next_element.dtype)
+ self.assertEqual([], next_element.shape)
+
+ worker_config = config_pb2.ConfigProto(device_count={"CPU": 2})
+ with self.test_session(config=worker_config) as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToDeviceGpuWithReInit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops.prefetch_to_device("/gpu:0"))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(5):
+ self.assertEqual(i, sess.run(next_element))
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
new file mode 100644
index 0000000000..77df8310d4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -0,0 +1,353 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing reader datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.core.example import example_pb2
+from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util import compat
+
+
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing FixedLengthRecordDataset."""
+
+ def setUp(self):
+ super(FixedLengthRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self._header_bytes = 5
+ self._record_bytes = 3
+ self._footer_bytes = 2
+
+ def _record(self, f, r):
+ return compat.as_bytes(str(f * 2 + r) * self._record_bytes)
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "fixed_length_record.%d.txt" % i)
+ filenames.append(fn)
+ with open(fn, "wb") as f:
+ f.write(b"H" * self._header_bytes)
+ for j in range(self._num_records):
+ f.write(self._record(i, j))
+ f.write(b"F" * self._footer_bytes)
+ return filenames
+
+
+class MakeBatchedFeaturesDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing `make_batched_features_dataset`."""
+
+ def setUp(self):
+ super(MakeBatchedFeaturesDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+ self.test_filenames = self._createFiles()
+
+ def make_batch_feature(self,
+ filenames,
+ num_epochs,
+ batch_size,
+ label_key=None,
+ reader_num_threads=1,
+ parser_num_threads=1,
+ shuffle=False,
+ shuffle_seed=None,
+ drop_final_batch=False):
+ self.filenames = filenames
+ self.num_epochs = num_epochs
+ self.batch_size = batch_size
+
+ return readers.make_batched_features_dataset(
+ file_pattern=self.filenames,
+ batch_size=self.batch_size,
+ features={
+ "file": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "record": parsing_ops.FixedLenFeature([], dtypes.int64),
+ "keywords": parsing_ops.VarLenFeature(dtypes.string),
+ "label": parsing_ops.FixedLenFeature([], dtypes.string),
+ },
+ label_key=label_key,
+ reader=core_readers.TFRecordDataset,
+ num_epochs=self.num_epochs,
+ shuffle=shuffle,
+ shuffle_seed=shuffle_seed,
+ reader_num_threads=reader_num_threads,
+ parser_num_threads=parser_num_threads,
+ drop_final_batch=drop_final_batch)
+
+ def _record(self, f, r, l):
+ example = example_pb2.Example(
+ features=feature_pb2.Features(
+ feature={
+ "file":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[f])),
+ "record":
+ feature_pb2.Feature(
+ int64_list=feature_pb2.Int64List(value=[r])),
+ "keywords":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=self._get_keywords(f, r))),
+ "label":
+ feature_pb2.Feature(
+ bytes_list=feature_pb2.BytesList(
+ value=[compat.as_bytes(l)]))
+ }))
+ return example.SerializeToString()
+
+ def _get_keywords(self, f, r):
+ num_keywords = 1 + (f + r) % 2
+ keywords = []
+ for index in range(num_keywords):
+ keywords.append(compat.as_bytes("keyword%d" % index))
+ return keywords
+
+ def _sum_keywords(self, num_files):
+ sum_keywords = 0
+ for i in range(num_files):
+ for j in range(self._num_records):
+ sum_keywords += 1 + (i + j) % 2
+ return sum_keywords
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j, "fake-label"))
+ writer.close()
+ return filenames
+
+ def _run_actual_batch(self, outputs, sess, label_key_provided=False):
+ if label_key_provided:
+ # outputs would be a tuple of (feature dict, label)
+ label_op = outputs[1]
+ features_op = outputs[0]
+ else:
+ features_op = outputs
+ label_op = features_op["label"]
+ file_op = features_op["file"]
+ keywords_indices_op = features_op["keywords"].indices
+ keywords_values_op = features_op["keywords"].values
+ keywords_dense_shape_op = features_op["keywords"].dense_shape
+ record_op = features_op["record"]
+ return sess.run([
+ file_op, keywords_indices_op, keywords_values_op,
+ keywords_dense_shape_op, record_op, label_op
+ ])
+
+ def _next_actual_batch(self, sess, label_key_provided=False):
+ return self._run_actual_batch(self.outputs, sess, label_key_provided)
+
+ def _interleave(self, iterators, cycle_length):
+ pending_iterators = iterators
+ open_iterators = []
+ num_open = 0
+ for i in range(cycle_length):
+ if pending_iterators:
+ open_iterators.append(pending_iterators.pop(0))
+ num_open += 1
+
+ while num_open:
+ for i in range(min(cycle_length, len(open_iterators))):
+ if open_iterators[i] is None:
+ continue
+ try:
+ yield next(open_iterators[i])
+ except StopIteration:
+ if pending_iterators:
+ open_iterators[i] = pending_iterators.pop(0)
+ else:
+ open_iterators[i] = None
+ num_open -= 1
+
+ def _next_expected_batch(self,
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=1):
+
+ def _next_record(file_indices):
+ for j in file_indices:
+ for i in range(self._num_records):
+ yield j, i, compat.as_bytes("fake-label")
+
+ def _next_record_interleaved(file_indices, cycle_length):
+ return self._interleave([_next_record([i]) for i in file_indices],
+ cycle_length)
+
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ label_batch = []
+ for _ in range(num_epochs):
+ if cycle_length == 1:
+ next_records = _next_record(file_indices)
+ else:
+ next_records = _next_record_interleaved(file_indices, cycle_length)
+ for record in next_records:
+ f = record[0]
+ r = record[1]
+ label_batch.append(record[2])
+ file_batch.append(f)
+ record_batch.append(r)
+ keywords = self._get_keywords(f, r)
+ keywords_batch_values.extend(keywords)
+ keywords_batch_indices.extend(
+ [[batch_index, i] for i in range(len(keywords))])
+ batch_index += 1
+ keywords_batch_max_len = max(keywords_batch_max_len, len(keywords))
+ if len(file_batch) == batch_size:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [batch_size, keywords_batch_max_len], record_batch, label_batch
+ ]
+ file_batch = []
+ keywords_batch_indices = []
+ keywords_batch_values = []
+ keywords_batch_max_len = 0
+ record_batch = []
+ batch_index = 0
+ label_batch = []
+ if file_batch:
+ yield [
+ file_batch, keywords_batch_indices, keywords_batch_values,
+ [len(file_batch), keywords_batch_max_len], record_batch, label_batch
+ ]
+
+ def verify_records(self,
+ sess,
+ batch_size,
+ file_index=None,
+ num_epochs=1,
+ label_key_provided=False,
+ interleave_cycle_length=1):
+ if file_index is not None:
+ file_indices = [file_index]
+ else:
+ file_indices = range(self._num_files)
+
+ for expected_batch in self._next_expected_batch(
+ file_indices,
+ batch_size,
+ num_epochs,
+ cycle_length=interleave_cycle_length):
+ actual_batch = self._next_actual_batch(
+ sess, label_key_provided=label_key_provided)
+ for i in range(len(expected_batch)):
+ self.assertAllEqual(expected_batch[i], actual_batch[i])
+
+
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing TextLineDataset."""
+
+ def _lineText(self, f, l):
+ return compat.as_bytes("%d: %d" % (f, l))
+
+ def _createFiles(self,
+ num_files,
+ num_lines,
+ crlf=False,
+ compression_type=None):
+ filenames = []
+ for i in range(num_files):
+ fn = os.path.join(self.get_temp_dir(), "text_line.%d.txt" % i)
+ filenames.append(fn)
+ contents = []
+ for j in range(num_lines):
+ contents.append(self._lineText(i, j))
+ # Always include a newline after the record unless it is
+ # at the end of the file, in which case we include it
+ if j + 1 != num_lines or i == 0:
+ contents.append(b"\r\n" if crlf else b"\n")
+ contents = b"".join(contents)
+
+ if not compression_type:
+ with open(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "GZIP":
+ with gzip.GzipFile(fn, "wb") as f:
+ f.write(contents)
+ elif compression_type == "ZLIB":
+ contents = zlib.compress(contents)
+ with open(fn, "wb") as f:
+ f.write(contents)
+ else:
+ raise ValueError("Unsupported compression_type", compression_type)
+
+ return filenames
+
+
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing TFRecordDataset."""
+
+ def setUp(self):
+ super(TFRecordDatasetTestBase, self).setUp()
+ self._num_files = 2
+ self._num_records = 7
+
+ self.test_filenames = self._createFiles()
+
+ self.filenames = array_ops.placeholder(dtypes.string, shape=[None])
+ self.num_epochs = array_ops.placeholder_with_default(
+ constant_op.constant(1, dtypes.int64), shape=[])
+ self.compression_type = array_ops.placeholder_with_default("", shape=[])
+ self.batch_size = array_ops.placeholder(dtypes.int64, shape=[])
+
+ repeat_dataset = core_readers.TFRecordDataset(
+ self.filenames, self.compression_type).repeat(self.num_epochs)
+ batch_dataset = repeat_dataset.batch(self.batch_size)
+
+ iterator = iterator_ops.Iterator.from_structure(batch_dataset.output_types)
+ self.init_op = iterator.make_initializer(repeat_dataset)
+ self.init_batch_op = iterator.make_initializer(batch_dataset)
+ self.get_next = iterator.get_next()
+
+ def _record(self, f, r):
+ return compat.as_bytes("Record %d of file %d" % (r, f))
+
+ def _createFiles(self):
+ filenames = []
+ for i in range(self._num_files):
+ fn = os.path.join(self.get_temp_dir(), "tf_record.%d.txt" % i)
+ filenames.append(fn)
+ writer = python_io.TFRecordWriter(fn)
+ for j in range(self._num_records):
+ writer.write(self._record(i, j))
+ writer.close()
+ return filenames
diff --git a/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
new file mode 100644
index 0000000000..4c879dbae6
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/rejection_resample_test.py
@@ -0,0 +1,182 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.rejection_resample()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+def _time_resampling(
+ test_obj, data_np, target_dist, init_dist, num_to_sample):
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np).repeat()
+
+ # Reshape distribution via rejection sampling.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist,
+ seed=142))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with test_obj.test_session() as sess:
+ start_time = time.time()
+ for _ in xrange(num_to_sample):
+ sess.run(get_next)
+ end_time = time.time()
+
+ return end_time - start_time
+
+
+class RejectionResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("InitialDistributionKnown", True),
+ ("InitialDistributionUnknown", False))
+ def testDistribution(self, initial_known):
+ classes = np.random.randint(5, size=(20000,)) # Uniformly sampled
+ target_dist = [0.9, 0.05, 0.05, 0.0, 0.0]
+ initial_dist = [0.2] * 5 if initial_known else None
+ classes = math_ops.to_int64(classes) # needed for Windows build.
+ dataset = dataset_ops.Dataset.from_tensor_slices(classes).shuffle(
+ 200, seed=21).map(lambda c: (c, string_ops.as_string(c))).repeat()
+
+ get_next = dataset.apply(
+ resampling.rejection_resample(
+ target_dist=target_dist,
+ initial_dist=initial_dist,
+ class_func=lambda c, _: c,
+ seed=27)).make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ while len(returned) < 4000:
+ returned.append(sess.run(get_next))
+
+ returned_classes, returned_classes_and_data = zip(*returned)
+ _, returned_data = zip(*returned_classes_and_data)
+ self.assertAllEqual([compat.as_bytes(str(c))
+ for c in returned_classes], returned_data)
+ total_returned = len(returned_classes)
+ class_counts = np.array([
+ len([True for v in returned_classes if v == c])
+ for c in range(5)])
+ returned_dist = class_counts / total_returned
+ self.assertAllClose(target_dist, returned_dist, atol=1e-2)
+
+ @parameterized.named_parameters(
+ ("OnlyInitial", True),
+ ("NotInitial", False))
+ def testEdgeCasesSampleFromInitialDataset(self, only_initial_dist):
+ init_dist = [0.5, 0.5]
+ target_dist = [0.5, 0.5] if only_initial_dist else [0.0, 1.0]
+ num_classes = len(init_dist)
+ # We don't need many samples to test that this works.
+ num_samples = 100
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
+
+ # Reshape distribution.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ returned.append(sess.run(get_next))
+
+ def testRandomClasses(self):
+ init_dist = [0.25, 0.25, 0.25, 0.25]
+ target_dist = [0.0, 0.0, 0.0, 1.0]
+ num_classes = len(init_dist)
+ # We don't need many samples to test a dirac-delta target distribution.
+ num_samples = 100
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ dataset = dataset_ops.Dataset.from_tensor_slices(data_np)
+
+ # Apply a random mapping that preserves the data distribution.
+ def _remap_fn(_):
+ return math_ops.cast(random_ops.random_uniform([1]) * num_classes,
+ dtypes.int32)[0]
+ dataset = dataset.map(_remap_fn)
+
+ # Reshape distribution.
+ dataset = dataset.apply(
+ resampling.rejection_resample(
+ class_func=lambda x: x,
+ target_dist=target_dist,
+ initial_dist=init_dist))
+
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ returned = []
+ with self.assertRaises(errors.OutOfRangeError):
+ while True:
+ returned.append(sess.run(get_next))
+
+ classes, _ = zip(*returned)
+ bincount = np.bincount(
+ np.array(classes),
+ minlength=num_classes).astype(np.float32) / len(classes)
+
+ self.assertAllClose(target_dist, bincount, atol=1e-2)
+
+
+class ResampleDatasetBenchmark(test.Benchmark):
+
+ def benchmarkResamplePerformance(self):
+ init_dist = [0.25, 0.25, 0.25, 0.25]
+ target_dist = [0.0, 0.0, 0.0, 1.0]
+ num_classes = len(init_dist)
+ # We don't need many samples to test a dirac-delta target distribution
+ num_samples = 1000
+ data_np = np.random.choice(num_classes, num_samples, p=init_dist)
+
+ resample_time = _time_resampling(
+ self, data_np, target_dist, init_dist, num_to_sample=1000)
+
+ self.report_benchmark(
+ iters=1000, wall_time=resample_time, name="benchmark_resample")
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
new file mode 100644
index 0000000000..516e489d04
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/restructured_dataset_test.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the private `_RestructuredDataset` transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class RestructuredDatasetTest(test_base.DatasetTestBase):
+
+ def testRestructureDataset(self):
+ components = (array_ops.placeholder(dtypes.int32),
+ (array_ops.placeholder(dtypes.int32, shape=[None]),
+ array_ops.placeholder(dtypes.int32, shape=[20, 30])))
+ dataset = dataset_ops.Dataset.from_tensors(components)
+
+ i32 = dtypes.int32
+
+ test_cases = [((i32, i32, i32), None),
+ (((i32, i32), i32), None),
+ ((i32, i32, i32), (None, None, None)),
+ ((i32, i32, i32), ([17], [17], [20, 30]))]
+
+ for new_types, new_shape_lists in test_cases:
+ # pylint: disable=protected-access
+ new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+ self.assertEqual(new_types, new.output_types)
+ if new_shape_lists is not None:
+ for expected_shape_list, shape in zip(
+ nest.flatten(new_shape_lists), nest.flatten(new.output_shapes)):
+ if expected_shape_list is None:
+ self.assertIs(None, shape.ndims)
+ else:
+ self.assertEqual(expected_shape_list, shape.as_list())
+
+ fail_cases = [((i32, dtypes.int64, i32), None),
+ ((i32, i32, i32, i32), None),
+ ((i32, i32, i32), ((None, None), None)),
+ ((i32, i32, i32), (None, None, None, None)),
+ ((i32, i32, i32), (None, [None], [21, 30]))]
+
+ for new_types, new_shape_lists in fail_cases:
+ with self.assertRaises(ValueError):
+ # pylint: disable=protected-access
+ new = batching._RestructuredDataset(dataset, new_types, new_shape_lists)
+ # pylint: enable=protected-access
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/scan_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
new file mode 100644
index 0000000000..0730455431
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_test.py
@@ -0,0 +1,172 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.scan()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import itertools
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ScanTest(test_base.DatasetTestBase):
+
+ def _counting_dataset(self, start, scan_fn):
+ return dataset_ops.Dataset.from_tensors(0).repeat().apply(
+ scan_ops.scan(start, scan_fn))
+
+ def testCount(self):
+ def make_scan_fn(step):
+ return lambda state, _: (state + step, state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ start, make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testFibonacci(self):
+ iterator = dataset_ops.Dataset.from_tensors(1).repeat(None).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1]))
+ ).make_one_shot_iterator()
+
+ if context.executing_eagerly():
+ next_element = iterator.get_next
+ else:
+ get_next = iterator.get_next()
+ next_element = lambda: get_next
+
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(1, self.evaluate(next_element()))
+ self.assertEqual(2, self.evaluate(next_element()))
+ self.assertEqual(3, self.evaluate(next_element()))
+ self.assertEqual(5, self.evaluate(next_element()))
+ self.assertEqual(8, self.evaluate(next_element()))
+
+ def testSparseCount(self):
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def make_scan_fn(step):
+ return lambda state, _: (_sparse(state.values[0] + step), state)
+
+ start = array_ops.placeholder(dtypes.int32, shape=[])
+ step = array_ops.placeholder(dtypes.int32, shape=[])
+ take = array_ops.placeholder(dtypes.int64, shape=[])
+ iterator = self._counting_dataset(
+ _sparse(start),
+ make_scan_fn(step)).take(take).make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+
+ for start_val, step_val, take_val in [(0, 1, 10), (0, 1, 0), (10, 1, 10),
+ (10, 2, 10), (10, -1, 10),
+ (10, -2, 10)]:
+ sess.run(iterator.initializer,
+ feed_dict={start: start_val, step: step_val, take: take_val})
+ for expected, _ in zip(
+ itertools.count(start_val, step_val), range(take_val)):
+ self.assertEqual(expected, sess.run(next_element).values[0])
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testChangingStateShape(self):
+ # Test the fixed-point shape invariant calculations: start with
+ # initial values with known shapes, and use a scan function that
+ # changes the size of the state on each element.
+ def _scan_fn(state, input_value):
+ # Statically known rank, but dynamic length.
+ ret_longer_vector = array_ops.concat([state[0], state[0]], 0)
+ # Statically unknown rank.
+ ret_larger_rank = array_ops.expand_dims(state[1], 0)
+ return (ret_longer_vector, ret_larger_rank), (state, input_value)
+
+ dataset = dataset_ops.Dataset.from_tensors(0).repeat(5).apply(
+ scan_ops.scan(([0], 1), _scan_fn))
+ self.assertEqual([None], dataset.output_shapes[0][0].as_list())
+ self.assertIs(None, dataset.output_shapes[0][1].ndims)
+ self.assertEqual([], dataset.output_shapes[1].as_list())
+
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(5):
+ (longer_vector_val, larger_rank_val), _ = sess.run(next_element)
+ self.assertAllEqual([0] * (2**i), longer_vector_val)
+ self.assertAllEqual(np.array(1, ndmin=i), larger_rank_val)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testIncorrectStateType(self):
+
+ def _scan_fn(state, _):
+ return constant_op.constant(1, dtype=dtypes.int64), state
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The element types for the new state must match the initial state."):
+ dataset.apply(
+ scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
+
+ def testIncorrectReturnType(self):
+
+ def _scan_fn(unused_state, unused_input_value):
+ return constant_op.constant(1, dtype=dtypes.int64)
+
+ dataset = dataset_ops.Dataset.range(10)
+ with self.assertRaisesRegexp(
+ TypeError,
+ "The scan function must return a pair comprising the new state and the "
+ "output value."):
+ dataset.apply(
+ scan_ops.scan(constant_op.constant(1, dtype=dtypes.int32), _scan_fn))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
new file mode 100644
index 0000000000..e556b65b7c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -0,0 +1,719 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_library(
+ name = "dataset_serialization_test_base",
+ srcs = [
+ "dataset_serialization_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lookup_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:training",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "cache_dataset_serialization_test",
+ size = "small",
+ srcs = ["cache_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "checkpoint_input_pipeline_hook_test",
+ size = "small",
+ srcs = ["checkpoint_input_pipeline_hook_test.py"],
+ srcs_version = "PY2AND3",
+ tags = ["no_pip"],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "concatenate_dataset_serialization_test",
+ size = "small",
+ srcs = ["concatenate_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_serialization_test",
+ size = "small",
+ srcs = ["csv_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_serialization_test",
+ size = "medium",
+ srcs = ["dataset_constructor_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_serialization_test",
+ size = "medium",
+ srcs = ["filter_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "fixed_length_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["fixed_length_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "flat_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["flat_map_dataset_serialization_test.py"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "group_by_reducer_serialization_test",
+ size = "medium",
+ srcs = ["group_by_reducer_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "group_by_window_serialization_test",
+ size = "medium",
+ srcs = ["group_by_window_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "ignore_errors_serialization_test",
+ size = "small",
+ srcs = ["ignore_errors_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "map_and_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_and_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "optimize_dataset_serialization_test",
+ size = "small",
+ srcs = ["optimize_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "padded_batch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["padded_batch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_interleave_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_interleave_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parallel_map_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parallel_map_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "parse_example_dataset_serialization_test",
+ size = "medium",
+ srcs = ["parse_example_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ ],
+)
+
+py_test(
+ name = "prefetch_dataset_serialization_test",
+ size = "small",
+ srcs = ["prefetch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "range_dataset_serialization_test",
+ size = "small",
+ srcs = ["range_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sample_from_datasets_serialization_test",
+ size = "medium",
+ srcs = ["sample_from_datasets_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_serialization_test",
+ size = "small",
+ srcs = ["scan_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sequence_dataset_serialization_test",
+ size = "medium",
+ srcs = ["sequence_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "serialization_integration_test",
+ size = "small",
+ srcs = ["serialization_integration_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_and_repeat_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_serialization_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_serialization_test",
+ size = "small",
+ srcs = ["sql_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_test_base",
+ "//tensorflow/python/data/experimental/ops:readers",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_serialization_test",
+ size = "medium",
+ srcs = ["stats_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "textline_dataset_serialization_test",
+ size = "medium",
+ srcs = ["textline_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "tf_record_dataset_serialization_test",
+ size = "medium",
+ srcs = ["tf_record_dataset_serialization_test.py"],
+ shard_count = 4,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "unbatch_dataset_serialization_test",
+ size = "medium",
+ srcs = ["unbatch_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_serialization_test",
+ size = "small",
+ srcs = ["unique_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "zip_dataset_serialization_test",
+ size = "small",
+ srcs = ["zip_dataset_serialization_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":dataset_serialization_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..d72a6df14c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the BatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class BatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(batch_size)
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len // batch_size
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+ def _build_dataset_dense_to_sparse(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.fill([x], x)).apply(
+ batching.dense_to_sparse_batch(4, [12]))
+
+ def testDenseToSparseBatchDatasetCore(self):
+ components = np.random.randint(5, size=(40,)).astype(np.int32)
+ diff_comp = np.random.randint(2, size=(100,)).astype(np.int32)
+
+ num_outputs = len(components) // 4
+ self.run_core_tests(lambda: self._build_dataset_dense_to_sparse(components),
+ lambda: self._build_dataset_dense_to_sparse(diff_comp),
+ num_outputs)
+
+ def _sparse(self, i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0]], values=(i * [1]), dense_shape=[1])
+
+ def _build_dataset_sparse(self, batch_size=5):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(batch_size)
+
+ def testSparseCore(self):
+ self.run_core_tests(self._build_dataset_sparse,
+ lambda: self._build_dataset_sparse(2), 2)
+
+ def _build_dataset_nested_sparse(self):
+ return dataset_ops.Dataset.range(10).map(self._sparse).batch(5).batch(2)
+
+ def testNestedSparseCore(self):
+ self.run_core_tests(self._build_dataset_nested_sparse, None, 1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
new file mode 100644
index 0000000000..2bcf77f5d8
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -0,0 +1,253 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CacheDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from absl.testing import parameterized
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class CacheDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
+
+ def setUp(self):
+ self.range_size = 10
+ self.num_repeats = 3
+ self.num_outputs = self.range_size * self.num_repeats
+ self.cache_file_prefix = 'test'
+
+ def make_dataset_fn(self, is_memory):
+ if is_memory:
+ filename = ''
+ else:
+ filename = os.path.join(self.get_temp_dir(), self.cache_file_prefix)
+
+ def ds_fn():
+ return dataset_ops.Dataset.range(self.range_size).cache(filename).repeat(
+ self.num_repeats)
+
+ return ds_fn
+
+ def expected_outputs(self):
+ return list(range(self.range_size)) * self.num_repeats
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 5 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 8 entries from iterator but save checkpoint after producing 5.
+ outputs = self.gen_outputs(
+ ds_fn, [5], 8, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, range(8))
+
+ if is_memory:
+ outputs = outputs[:5]
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+ # Restoring from checkpoint and running GetNext should return
+ # `AlreadExistsError` now because the lockfile already exists.
+ with self.assertRaises(errors.AlreadyExistsError):
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 15 entries from iterator and save checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Restore from checkpoint and produce the rest of the elements from the
+ # iterator.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointAfterOneEpochThenRunFewSteps(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 18 entries from iterator but save checkpoint after producing 15.
+ outputs = self.gen_outputs(
+ ds_fn, [15], 18, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(8)))
+
+ outputs = list(range(10)) + list(range(5)) + self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 15,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointBeforeOneEpochButRunCompleteEpoch(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Generate 13 entries from iterator but save checkpoint after producing 5.
+ outputs = self.gen_outputs(
+ ds_fn, [5], 13, verify_exhausted=False, save_checkpoint_at_end=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(3)))
+
+ # Since we ran for more than one epoch, the cache was completely written.
+ # The ckpt was saved when the iterator was in cache-write mode. Test that
+ # the iterator falls back to read mode after restoring if the cache has
+ # been completely written.
+
+ outputs = list(range(5)) + self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Checkpoint before get_next is called even once.
+ outputs = self.gen_outputs(ds_fn, [], 0, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, [])
+
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, ckpt_saved=True, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testCheckpointUnusedMidwayWriterIterator(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 5 elements and checkpoint.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint, then produce no elements and checkpoint.
+ outputs.extend(
+ self.gen_outputs(ds_fn, [], 0, ckpt_saved=True, verify_exhausted=False))
+ self.assertSequenceEqual(outputs, range(5))
+
+ # Restore from checkpoint and produce rest of the elements.
+ outputs.extend(
+ self.gen_outputs(
+ ds_fn, [],
+ self.num_outputs - 5,
+ ckpt_saved=True,
+ verify_exhausted=False))
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testUnusedCheckpointError(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 5 elements and save ckpt.
+ outputs = self.gen_outputs(ds_fn, [], 5, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, range(5))
+
+ if is_memory:
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, self.expected_outputs())
+ else:
+ # Since the complete cache has not been written, a new iterator which does
+ # not restore the checkpoint will throw an error since there is a partial
+ # cache shard.
+ with self.assertRaises(errors.AlreadyExistsError):
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+
+ @parameterized.named_parameters(
+ ('Memory', True),
+ ('File', False),
+ )
+ def testIgnoreCheckpointIfCacheWritten(self, is_memory):
+ ds_fn = self.make_dataset_fn(is_memory)
+
+ # Produce 15 elements and save ckpt. This will write the complete cache.
+ outputs = self.gen_outputs(ds_fn, [], 15, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) + list(range(5)))
+
+ # Build the iterator again but do not restore from ckpt. Since the cache
+ # has already been written we should be able to use it.
+ outputs = self.gen_outputs(
+ ds_fn, [], self.num_outputs, verify_exhausted=False)
+ self.assertSequenceEqual(outputs, list(range(10)) * 3)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
new file mode 100644
index 0000000000..94393d6d4b
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/checkpoint_input_pipeline_hook_test.py
@@ -0,0 +1,125 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for experimental iterator_ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.estimator import estimator
+from tensorflow.python.estimator import model_fn
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import training_util
+
+
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
+
+ @staticmethod
+ def _model_fn(features, labels, mode, config):
+ del labels
+ del mode
+ del config
+ global_step = training_util.get_or_create_global_step()
+ update_global_step_op = global_step.assign_add(1)
+ latest_feature = variables.VariableV1(
+ 0, name='latest_feature', dtype=dtypes.int64)
+ store_latest_feature_op = latest_feature.assign(features)
+ ops.add_to_collection('my_vars', global_step)
+ ops.add_to_collection('my_vars', latest_feature)
+ return model_fn.EstimatorSpec(
+ mode='train',
+ train_op=control_flow_ops.group(
+ [update_global_step_op, store_latest_feature_op]),
+ loss=constant_op.constant(2.0))
+
+ def _read_vars(self, model_dir):
+ """Returns (global_step, latest_feature)."""
+ with ops.Graph().as_default() as g:
+ ckpt_path = checkpoint_management.latest_checkpoint(model_dir)
+ meta_filename = ckpt_path + '.meta'
+ saver_lib.import_meta_graph(meta_filename)
+ saver = saver_lib.Saver()
+ with self.session(graph=g) as sess:
+ saver.restore(sess, ckpt_path)
+ return sess.run(ops.get_collection('my_vars'))
+
+ def _build_iterator_saver_hook(self, est):
+ return iterator_ops.CheckpointInputPipelineHook(est)
+
+ def testReturnDatasetFromInputFn(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testBuildIteratorInInputFn(self):
+
+ def _input_fn():
+ ds = dataset_ops.Dataset.range(10)
+ iterator = ds.make_one_shot_iterator()
+ return iterator.get_next()
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+
+ def testDoNotRestore(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.range(10)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (2, 1))
+ est.train(_input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (4, 3))
+ # Hook not provided, input pipeline was not restored.
+ est.train(_input_fn, steps=2)
+ self.assertSequenceEqual(self._read_vars(est.model_dir), (6, 1))
+
+ def testRaiseErrorIfNoIterator(self):
+
+ def _input_fn():
+ return constant_op.constant(1, dtype=dtypes.int64)
+
+ est = estimator.Estimator(model_fn=self._model_fn)
+
+ with self.assertRaises(ValueError):
+ est.train(
+ _input_fn, steps=2, hooks=[self._build_iterator_saver_hook(est)])
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
new file mode 100644
index 0000000000..c075dff8cb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -0,0 +1,49 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ConcatenateDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ConcatenateDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_concatenate_dataset(self, var_array):
+ input_components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 4))
+ to_concatenate_components = (np.tile(
+ np.array([[5], [6], [7], [8], [9]]), 20), var_array)
+
+ return dataset_ops.Dataset.from_tensor_slices(input_components).concatenate(
+ dataset_ops.Dataset.from_tensor_slices(to_concatenate_components))
+
+ def testConcatenateCore(self):
+ num_outputs = 9
+ array = np.tile(np.array([[16], [17], [18], [19], [20]]), 15)
+ diff_array = np.array([[1], [2], [3], [4], [5]])
+ self.run_core_tests(lambda: self._build_concatenate_dataset(array),
+ lambda: self._build_concatenate_dataset(diff_array),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
new file mode 100644
index 0000000000..d4983492e7
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -0,0 +1,73 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the CsvDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.platform import test
+
+
+class CsvDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._num_cols = 7
+ self._num_rows = 10
+ self._num_epochs = 14
+ self._num_outputs = self._num_rows * self._num_epochs
+
+ inputs = [
+ ",".join(str(self._num_cols * j + i)
+ for i in range(self._num_cols))
+ for j in range(self._num_rows)
+ ]
+ contents = "\n".join(inputs).encode("utf-8")
+
+ self._filename = os.path.join(self.get_temp_dir(), "file.csv")
+ self._compressed = os.path.join(self.get_temp_dir(),
+ "comp.csv") # GZip compressed
+
+ with open(self._filename, "wb") as f:
+ f.write(contents)
+ with gzip.GzipFile(self._compressed, "wb") as f:
+ f.write(contents)
+
+ def ds_func(self, **kwargs):
+ compression_type = kwargs.get("compression_type", None)
+ if compression_type == "GZIP":
+ filename = self._compressed
+ elif compression_type is None:
+ filename = self._filename
+ else:
+ raise ValueError("Invalid compression type:", compression_type)
+
+ return readers.CsvDataset(filename, **kwargs).repeat(self._num_epochs)
+
+ def testSerializationCore(self):
+ defs = [[0]] * self._num_cols
+ self.run_core_tests(
+ lambda: self.ds_func(record_defaults=defs, buffer_size=2),
+ lambda: self.ds_func(record_defaults=defs, buffer_size=12),
+ self._num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
new file mode 100644
index 0000000000..41a095fb1a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -0,0 +1,95 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the dataset constructors serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.platform import test
+
+
+class FromTensorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_dataset(self, variable_array):
+ components = (variable_array, np.array([1, 2, 3]), np.array(37.0))
+
+ return dataset_ops.Dataset.from_tensors(components)
+
+ def testFromTensorsCore(self):
+ # Equal length components
+ arr = np.array(1)
+ num_outputs = 1
+ diff_arr = np.array(2)
+ self.run_core_tests(lambda: self._build_tensor_dataset(arr),
+ lambda: self._build_tensor_dataset(diff_arr),
+ num_outputs)
+
+
+class FromTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_tensor_slices_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components)
+
+ def testFromTensorSlicesCore(self):
+ # Equal length components
+ components = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array([37.0, 38.0, 39.0, 40.0]))
+
+ diff_comp = (np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[5], [6], [7], [8]]), 22),
+ np.array([1.0, 2.0, 3.0, 4.0]))
+
+ dict_components = {"foo": [1, 2, 3], "bar": [[4.0], [5.0], [6.0]]}
+
+ self.run_core_tests(lambda: self._build_tensor_slices_dataset(components),
+ lambda: self._build_tensor_slices_dataset(diff_comp), 4)
+ self.run_core_tests(
+ lambda: self._build_tensor_slices_dataset(dict_components), None, 3)
+
+
+class FromSparseTensorSlicesSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_sparse_tensor_slice_dataset(self, slices):
+ indices = np.array(
+ [[i, j] for i in range(len(slices)) for j in range(len(slices[i]))],
+ dtype=np.int64)
+ values = np.array([val for s in slices for val in s], dtype=np.float64)
+ dense_shape = np.array(
+ [len(slices), max(len(s) for s in slices) + 1], dtype=np.int64)
+ sparse_components = sparse_tensor.SparseTensor(indices, values, dense_shape)
+ return dataset_ops.Dataset.from_sparse_tensor_slices(sparse_components)
+
+ def testFromSparseTensorSlicesCore(self):
+ slices = [[1., 2., 3.], [1.], [1.], [1., 2.], [], [1., 2.], [], [], []]
+ diff_slices = [[1., 2.], [2.], [2., 3., 4.], [], [], []]
+
+ self.run_core_tests(
+ lambda: self._build_sparse_tensor_slice_dataset(slices),
+ lambda: self._build_sparse_tensor_slice_dataset(diff_slices),
+ 9,
+ sparse_tensors=True)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+ # TODO(b/72408568): Remove this once session.run can get
+ # variant tensors.
+ """Remove variants from a nest structure, so sess.run will execute."""
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+ # `from_sparse_tensor_slices()`and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+ exhausted iterator, i.e., one from which an OutOfRange error has been
+ returned.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+ An exhausted iterator is one which has returned an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ num_breaks: The number of break points. These are uniformly spread in
+ [0, num_outputs] both inclusive.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn1
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ error: Declared error when trying to save iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+ 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ with stopping at break points.
+
+ Deep matches outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+ break_points: A list of integers. For each `break_point` in
+ `break_points`, we produce outputs till `break_point` number of items
+ have been produced and then checkpoint the state. The current graph
+ and session are destroyed and a new graph and session are used to
+ produce outputs till next checkpoint or till `num_outputs` elements
+ have been produced. `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+ outputs. If False, checkpoints are saved each break point but not at the
+ end. Note that checkpoints overwrite each other so there is always only
+ a single checkpoint available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+ Handles scalars, numpy arrays and other python sequence containers
+ e.g. list, dict.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+ """Generates `num_samples` breaks points in [0, num_outputs]."""
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+ # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
+ # do not support tuples we flatten the tensors and restore the shape in
+ # `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files.
+ prefix = self._ckpt_path()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+ map(gfile.Remove, files)
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
new file mode 100644
index 0000000000..225f6cbac0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FilterDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class FilterDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_filter_range_graph(self, div):
+ return dataset_ops.Dataset.range(100).filter(
+ lambda x: math_ops.not_equal(math_ops.mod(x, div), 2))
+
+ def testFilterCore(self):
+ div = 3
+ num_outputs = np.sum([x % 3 != 2 for x in range(100)])
+ self.run_core_tests(lambda: self._build_filter_range_graph(div),
+ lambda: self._build_filter_range_graph(div * 2),
+ num_outputs)
+
+ def _build_filter_dict_graph(self):
+ return dataset_ops.Dataset.range(10).map(
+ lambda x: {"foo": x * 2, "bar": x ** 2}).filter(
+ lambda d: math_ops.equal(d["bar"] % 2, 0)).map(
+ lambda d: d["foo"] + d["bar"])
+
+ def testFilterDictCore(self):
+ num_outputs = np.sum([(x**2) % 2 == 0 for x in range(10)])
+ self.run_core_tests(self._build_filter_dict_graph, None, num_outputs)
+
+ def _build_sparse_filter(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensor(
+ indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
+
+ def _filter_fn(_, i):
+ return math_ops.equal(i % 2, 0)
+
+ return dataset_ops.Dataset.range(10).map(_map_fn).filter(_filter_fn).map(
+ lambda x, i: x)
+
+ def testSparseCore(self):
+ num_outputs = 5
+ self.run_core_tests(self._build_sparse_filter, None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..70caf3e0d5
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -0,0 +1,45 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FixedLengthRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class FixedLengthRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.FixedLengthRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, num_epochs, compression_type=None):
+ filenames = self._createFiles()
+ return core_readers.FixedLengthRecordDataset(
+ filenames, self._record_bytes, self._header_bytes,
+ self._footer_bytes).repeat(num_epochs)
+
+ def testFixedLengthRecordCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
new file mode 100644
index 0000000000..c30534a9e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -0,0 +1,122 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the FlatMapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class FlatMapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+ # Complicated way of saying range(start, start+25).
+ def build_ds(start):
+
+ def map_fn(x):
+ return dataset_ops.Dataset.range(x, x + 5)
+
+ return dataset_ops.Dataset.range(start, start + 5 * 5, 5).flat_map(map_fn)
+
+ self.run_core_tests(lambda: build_ds(0), lambda: build_ds(10), 25)
+
+ def testMapThenFlatMap(self):
+
+ def build_ds():
+
+ def flat_map_fn(_):
+
+ def map_fn(y):
+ return 10 * math_ops.to_int32(y)
+
+ return dataset_ops.Dataset.range(100).map(map_fn)
+
+ return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
+
+ self.run_core_tests(build_ds, None, 500)
+
+ def testCaptureDefunInMapFn(self):
+
+ def build_ds():
+
+ def map_fn(x):
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.from_tensor_slices([defun_fn(x)])
+
+ return dataset_ops.Dataset.range(100).flat_map(map_fn)
+
+ self.run_core_tests(build_ds, None, 100)
+
+ def testDisallowVariableCapture(self):
+
+ def build_ds():
+ test_var = variable_scope.get_variable(
+ name="test_var", shape=(), use_resource=True)
+ return dataset_ops.Dataset.range(5).flat_map(
+ lambda _: dataset_ops.Dataset.from_tensor_slices([test_var]))
+
+ self.verify_error_on_save(build_ds, 5, errors.InvalidArgumentError)
+
+ def testDisallowCapturingStatefulOps(self):
+
+ def build_ds():
+
+ def flat_map_fn(_):
+
+ def map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(map_fn)
+
+ return dataset_ops.Dataset.range(5).flat_map(flat_map_fn)
+
+ self.verify_error_on_save(build_ds, 500, errors.InvalidArgumentError)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _flat_map_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_ds():
+ return dataset_ops.Dataset.range(10).map(_map_fn).flat_map(_flat_map_fn)
+
+ self.run_core_tests(_build_ds, None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
new file mode 100644
index 0000000000..169c8845d0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -0,0 +1,61 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByReducer serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByReducerSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ reducer = grouping.Reducer(
+ init_func=lambda _: np.int64(0),
+ reduce_func=lambda x, y: x + y,
+ finalize_func=lambda x: x)
+
+ return dataset_ops.Dataset.from_tensor_slices(components).apply(
+ grouping.group_by_reducer(lambda x: x % 5, reducer))
+
+ def testCoreGroupByReducer(self):
+ components = np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 5, verify_exhausted=True)
+ diff_components = np.array([5, 4, 3, 2, 1, 0], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 5,
+ verify_exhausted=True)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
new file mode 100644
index 0000000000..e5bc76288e
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the GroupByWindow serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class GroupByWindowSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).repeat(-1).apply(
+ grouping.group_by_window(lambda x: x % 3, lambda _, xs: xs.batch(4), 4))
+
+ def testCoreGroupByWindow(self):
+ components = np.array(
+ [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0], dtype=np.int64)
+ self.verify_unused_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ self.verify_restore_in_empty_graph(
+ lambda: self._build_dataset(components), 12, verify_exhausted=False)
+ diff_components = np.array([0, 0, 0, 1, 1, 1], dtype=np.int64)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_dataset(components),
+ lambda: self._build_dataset(diff_components),
+ 12,
+ verify_exhausted=False)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
new file mode 100644
index 0000000000..df1f43129a
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the IgnoreErrors input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class IgnoreErrorsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, components):
+ return dataset_ops.Dataset.from_tensor_slices(components).map(
+ lambda x: array_ops.check_numerics(x, "message")).apply(
+ error_ops.ignore_errors())
+
+ def testIgnoreErrorsCore(self):
+ components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
+ diff_components = np.array([1., 2., 3., np.nan]).astype(np.float32)
+ num_outputs = 4
+ self.run_core_tests(lambda: self._build_ds(components),
+ lambda: self._build_ds(diff_components), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..0c1d40ce39
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the InterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class InterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase,
+ parameterized.TestCase):
+
+ def _build_iterator_graph(self, input_values, cycle_length, block_length,
+ num_parallel_calls):
+ repeat_count = 2
+ return dataset_ops.Dataset.from_tensor_slices(input_values).repeat(
+ repeat_count).interleave(
+ lambda x: dataset_ops.Dataset.from_tensors(x).repeat(x),
+ cycle_length, block_length, num_parallel_calls)
+
+ @parameterized.named_parameters(
+ ("1", 2, 3, None),
+ ("2", 2, 3, 1),
+ ("3", 2, 3, 2),
+ ("4", 1, 3, None),
+ ("5", 1, 3, 1),
+ ("6", 2, 1, None),
+ ("7", 2, 1, 1),
+ ("8", 2, 1, 2),
+ )
+ def testSerializationCore(self, cycle_length, block_length,
+ num_parallel_calls):
+ input_values = np.array([4, 5, 6], dtype=np.int64)
+ num_outputs = np.sum(input_values) * 2
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length, block_length, num_parallel_calls),
+ lambda: self._build_iterator_graph(
+ input_values, cycle_length * 2, block_length, num_parallel_calls),
+ num_outputs)
+ # pylint: enable=g-long-lambda
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).interleave(
+ _interleave_fn, cycle_length=1)
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..166ffa99ca
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -0,0 +1,88 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapAndBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class MapAndBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testNumParallelBatches(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_batches = 2
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_batches=num_parallel_batches,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+ def testNumParallelCalls(self):
+ range_size = 11
+ num_repeats = 2
+ batch_size = 5
+ total_outputs = range_size * num_repeats
+ num_outputs_drop_remainder = total_outputs // batch_size
+ num_outputs_keep_remainder = int(math.ceil(total_outputs / batch_size))
+ num_parallel_calls = 7
+
+ def build_ds(range_start, drop_remainder=False):
+
+ def _map_fn(x):
+ return math_ops.square(x)
+
+ return dataset_ops.Dataset.range(
+ range_start, range_start + range_size).repeat(num_repeats).apply(
+ batching.map_and_batch(
+ map_func=_map_fn,
+ batch_size=batch_size,
+ num_parallel_calls=num_parallel_calls,
+ drop_remainder=drop_remainder))
+
+ self.run_core_tests(lambda: build_ds(10), lambda: build_ds(15),
+ num_outputs_keep_remainder)
+ self.run_core_tests(lambda: build_ds(10, True), lambda: build_ds(15, True),
+ num_outputs_drop_remainder)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
new file mode 100644
index 0000000000..b93156a96c
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -0,0 +1,140 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the MapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class MapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 14
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (
+ dataset_ops.Dataset.from_tensor_slices(components).map(_map_fn)
+ .repeat(self._num_epochs))
+
+ def testSaveRestoreCore(self):
+ self.run_core_tests(
+ self._build_ds,
+ lambda: self._build_ds(multiplier=15.0),
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(_map_fn)
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1)))
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(defun_fn)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testSparseCore(self):
+
+ def _sparse(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def _build_ds(num_outputs):
+ return dataset_ops.Dataset.range(num_outputs).map(_sparse)
+
+ num_outputs = 10
+ self.run_core_tests(lambda: _build_ds(num_outputs),
+ lambda: _build_ds(int(num_outputs / 2)), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
new file mode 100644
index 0000000000..ed4a1da596
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the OptimizeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class OptimizeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testCore(self):
+
+ def build_dataset(num_elements, batch_size):
+ return dataset_ops.Dataset.range(num_elements).map(lambda x: x * x).batch(
+ batch_size).apply(optimization.optimize(["map_and_batch_fusion"]))
+
+ self.run_core_tests(lambda: build_dataset(200, 10), None, 20)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
new file mode 100644
index 0000000000..6f72b24673
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -0,0 +1,66 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the PaddedBatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class PaddedBatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testPaddedBatch(self):
+
+ def build_dataset(seq_lens):
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ lambda x: array_ops.fill([x], x)).padded_batch(
+ 4, padded_shapes=[-1])
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+ def testPaddedBatchNonDefaultPadding(self):
+
+ def build_dataset(seq_lens):
+
+ def fill_tuple(x):
+ filled = array_ops.fill([x], x)
+ return (filled, string_ops.as_string(filled))
+
+ padded_shape = [-1]
+ return dataset_ops.Dataset.from_tensor_slices(seq_lens).map(
+ fill_tuple).padded_batch(
+ 4,
+ padded_shapes=(padded_shape, padded_shape),
+ padding_values=(-1, "<end>"))
+
+ seq_lens1 = np.random.randint(1, 20, size=(32,)).astype(np.int32)
+ seq_lens2 = np.random.randint(21, 40, size=(32,)).astype(np.int32)
+ self.run_core_tests(lambda: build_dataset(seq_lens1),
+ lambda: build_dataset(seq_lens2), 8)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
new file mode 100644
index 0000000000..b8f38e8a28
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -0,0 +1,101 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelInterleaveDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.platform import test
+
+
+class ParallelInterleaveDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self.input_values = np.array([4, 5, 6], dtype=np.int64)
+ self.num_repeats = 2
+ self.num_outputs = np.sum(self.input_values) * 2
+
+ def _build_ds(self, cycle_length, block_length, sloppy=False):
+ return (dataset_ops.Dataset.from_tensor_slices(
+ self.input_values).repeat(self.num_repeats).apply(
+ interleave_ops.parallel_interleave(
+ lambda x: dataset_ops.Dataset.range(10 * x, 11 * x),
+ cycle_length, block_length, sloppy)))
+
+ def testSerializationCore(self):
+ # cycle_length > 1, block_length > 1
+ cycle_length = 2
+ block_length = 3
+ self.run_core_tests(
+ lambda: self._build_ds(cycle_length, block_length),
+ lambda: self._build_ds(cycle_length * 2, block_length * 1),
+ self.num_outputs)
+ # cycle_length = 1
+ cycle_length = 1
+ block_length = 3
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+ # block_length = 1
+ cycle_length = 2
+ block_length = 1
+ self.run_core_tests(lambda: self._build_ds(cycle_length, block_length),
+ None, self.num_outputs)
+
+ def testSerializationWithSloppy(self):
+ break_points = self.gen_break_points(self.num_outputs, 10)
+ expected_outputs = np.repeat(
+ np.concatenate([np.arange(10 * x, 11 * x) for x in self.input_values]),
+ self.num_repeats).tolist()
+
+ def run_test(cycle_length, block_length):
+ actual = self.gen_outputs(
+ lambda: self._build_ds(cycle_length, block_length, True),
+ break_points, self.num_outputs)
+ self.assertSequenceEqual(sorted(actual), expected_outputs)
+
+ # cycle_length > 1, block_length > 1
+ run_test(2, 3)
+ # cycle_length = 1
+ run_test(1, 3)
+ # block_length = 1
+ run_test(2, 1)
+
+ def testSparseCore(self):
+
+ def _map_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=[[0, 0], [1, 1]], values=(i * [1, -1]), dense_shape=[2, 2])
+
+ def _interleave_fn(x):
+ return dataset_ops.Dataset.from_tensor_slices(
+ sparse_ops.sparse_to_dense(x.indices, x.dense_shape, x.values))
+
+ def _build_dataset():
+ return dataset_ops.Dataset.range(10).map(_map_fn).apply(
+ interleave_ops.parallel_interleave(_interleave_fn, 1))
+
+ self.run_core_tests(_build_dataset, None, 20)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
new file mode 100644
index 0000000000..a0bdd4fa59
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -0,0 +1,139 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParallelMapDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import function
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.platform import test
+
+
+class ParallelMapDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def setUp(self):
+ self._tensor_slice_len = 7
+ self._num_epochs = 1
+ self._num_outputs = self._tensor_slice_len * self._num_epochs
+
+ def _build_ds(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs))
+
+ def _build_ds_with_prefetch(self, multiplier=37.0):
+ components = (np.arange(self._tensor_slice_len), np.array([[1, 2, 3]]) *
+ np.arange(self._tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(self._tensor_slice_len))
+
+ def _map_fn(x, y, z):
+ return math_ops.square(x), math_ops.square(y), math_ops.square(z)
+
+ return (dataset_ops.Dataset.from_tensor_slices(components).map(
+ _map_fn, num_parallel_calls=3).repeat(self._num_epochs).prefetch(5))
+
+ def testSaveRestoreCore(self):
+ for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
+ self.run_core_tests(
+ ds_fn,
+ lambda: ds_fn(multiplier=15.0), # pylint: disable=cell-var-from-loop
+ self._num_outputs)
+
+ def testSaveStatefulFunction(self):
+
+ def _build_ds():
+
+ def _map_fn(x):
+ return random_ops.random_uniform(
+ (), 0, 10, dtype=dtypes.int32) * math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(100).map(
+ _map_fn, num_parallel_calls=2).prefetch(2)
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureVariableInMapFn(self):
+
+ def _build_ds():
+ counter_var = variable_scope.get_variable(
+ "counter", (), dtypes.int32, use_resource=True)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda _: counter_var.assign_add(1),
+ num_parallel_calls=2).prefetch(2))
+
+ self.verify_error_on_save(_build_ds, 15, errors.InvalidArgumentError)
+
+ def testCaptureConstantInMapFn(self):
+
+ def _build_ds():
+ constant_var = constant_op.constant(5)
+ return (dataset_ops.Dataset.from_tensors(0).repeat(10).map(
+ lambda x: x + constant_var, num_parallel_calls=2).prefetch(2))
+
+ self.run_core_tests(_build_ds, None, 10)
+
+ def testCaptureDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+ def testBuildDefunInMapFn(self):
+ num_outputs = 100
+
+ def _build_ds():
+
+ @function.Defun(dtypes.int64)
+ def defun_fn(x):
+
+ @function.Defun(dtypes.int32)
+ def defun_fn_deep(x):
+ return constant_op.constant(1000) + math_ops.to_int32(x)
+
+ return constant_op.constant(11000) + defun_fn_deep(math_ops.to_int32(x))
+
+ return dataset_ops.Dataset.range(num_outputs).map(
+ defun_fn, num_parallel_calls=2).prefetch(2)
+
+ self.run_core_tests(_build_ds, None, num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
new file mode 100644
index 0000000000..b3dfe21486
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -0,0 +1,50 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ParseExampleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.platform import test
+
+
+class ParseExampleDatasetSerializationTest(
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def ParseExampleDataset(self, num_repeat, batch_size):
+ return self.make_batch_feature(
+ filenames=self.test_filenames,
+ num_epochs=num_repeat,
+ batch_size=batch_size,
+ reader_num_threads=5,
+ parser_num_threads=10)
+
+ def testSerializationCore(self):
+ num_repeat = 5
+ batch_size = 2
+ num_outputs = self._num_records * self._num_files * num_repeat // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self.ParseExampleDataset(
+ num_repeat=num_repeat, batch_size=batch_size),
+ lambda: self.ParseExampleDataset(num_repeat=10, batch_size=4),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
new file mode 100644
index 0000000000..00d74c0025
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the PrefetchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class PrefetchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, seed):
+ return dataset_ops.Dataset.range(100).prefetch(10).shuffle(
+ buffer_size=10, seed=seed, reshuffle_each_iteration=False)
+
+ def testCore(self):
+ num_outputs = 100
+ self.run_core_tests(lambda: self.build_dataset(10),
+ lambda: self.build_dataset(20), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
new file mode 100644
index 0000000000..ef99d01c73
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -0,0 +1,118 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the RangeDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import io_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RangeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _iterator_checkpoint_prefix_local(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _save_op(self, iterator_resource):
+ iterator_state_variant = gen_dataset_ops.serialize_iterator(
+ iterator_resource)
+ save_op = io_ops.write_file(
+ self._iterator_checkpoint_prefix_local(),
+ parsing_ops.serialize_tensor(iterator_state_variant))
+ return save_op
+
+ def _restore_op(self, iterator_resource):
+ iterator_state_variant = parsing_ops.parse_tensor(
+ io_ops.read_file(self._iterator_checkpoint_prefix_local()),
+ dtypes.variant)
+ restore_op = gen_dataset_ops.deserialize_iterator(iterator_resource,
+ iterator_state_variant)
+ return restore_op
+
+ def testSaveRestore(self):
+
+ def _build_graph(start, stop):
+ iterator = dataset_ops.Dataset.range(start,
+ stop).make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ save_op = self._save_op(iterator._iterator_resource)
+ restore_op = self._restore_op(iterator._iterator_resource)
+ return init_op, get_next, save_op, restore_op
+
+ # Saving and restoring in different sessions.
+ start = 2
+ stop = 10
+ break_point = 5
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, _ = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+
+ with ops.Graph().as_default() as g:
+ init_op, get_next, _, restore_op = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(init_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Saving and restoring in same session.
+ with ops.Graph().as_default() as g:
+ init_op, get_next, save_op, restore_op = _build_graph(start, stop)
+ with self.session(graph=g) as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(init_op)
+ for i in range(start, break_point):
+ self.assertEqual(i, sess.run(get_next))
+ sess.run(save_op)
+ sess.run(restore_op)
+ for i in range(break_point, stop):
+ self.assertEqual(i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def _build_range_dataset(self, start, stop):
+ return dataset_ops.Dataset.range(start, stop)
+
+ def testRangeCore(self):
+ start = 2
+ stop = 10
+ stop_1 = 8
+ self.run_core_tests(lambda: self._build_range_dataset(start, stop),
+ lambda: self._build_range_dataset(start, stop_1),
+ stop - start)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
new file mode 100644
index 0000000000..c23c1ecdfb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -0,0 +1,46 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SampleFromDatasets serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class SampleFromDatasetsSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, probs, num_samples):
+ dataset = interleave_ops.sample_from_datasets(
+ [
+ dataset_ops.Dataset.from_tensors(i).repeat(None)
+ for i in range(len(probs))
+ ],
+ probs,
+ seed=1813)
+ return dataset.take(num_samples)
+
+ def testSerializationCore(self):
+ self.run_core_tests(
+ lambda: self._build_dataset([0.5, 0.5], 100),
+ lambda: self._build_dataset([0.25, 0.25, 0.25, 0.25], 1000), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
new file mode 100644
index 0000000000..5f50160619
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ScanDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ScanDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_elements):
+ return dataset_ops.Dataset.from_tensors(1).repeat(num_elements).apply(
+ scan_ops.scan([0, 1], lambda a, _: ([a[1], a[0] + a[1]], a[1])))
+
+ def testScanCore(self):
+ num_output = 5
+ self.run_core_tests(lambda: self._build_dataset(num_output),
+ lambda: self._build_dataset(2), num_output)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
new file mode 100644
index 0000000000..fe99a3d3d9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -0,0 +1,129 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the sequence datasets serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class SkipDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_skip_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).skip(count)
+
+ def testSkipFewerThanInputs(self):
+ count = 4
+ num_outputs = 10 - count
+ self.run_core_tests(lambda: self._build_skip_dataset(count),
+ lambda: self._build_skip_dataset(count + 2),
+ num_outputs)
+
+ def testSkipVarious(self):
+ # Skip more than inputs
+ self.run_core_tests(lambda: self._build_skip_dataset(20), None, 0)
+ # Skip exactly the input size
+ self.run_core_tests(lambda: self._build_skip_dataset(10), None, 0)
+ self.run_core_tests(lambda: self._build_skip_dataset(-1), None, 0)
+ # Skip nothing
+ self.run_core_tests(lambda: self._build_skip_dataset(0), None, 10)
+
+ def testInvalidSkip(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_skip_dataset([1, 2]), None, 0)
+
+
+class TakeDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_take_dataset(self, count):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(count)
+
+ def testTakeFewerThanInputs(self):
+ count = 4
+ self.run_core_tests(
+ lambda: self._build_take_dataset(count),
+ lambda: self._build_take_dataset(count + 2),
+ count,
+ )
+
+ def testTakeVarious(self):
+ # Take more than inputs
+ self.run_core_tests(lambda: self._build_take_dataset(20), None, 10)
+ # Take exactly the input size
+ self.run_core_tests(lambda: self._build_take_dataset(10), None, 10)
+ # Take all
+ self.run_core_tests(lambda: self._build_take_dataset(-1), None, 10)
+ # Take nothing
+ self.run_core_tests(lambda: self._build_take_dataset(0), None, 0)
+
+ def testInvalidTake(self):
+ with self.assertRaisesRegexp(ValueError,
+ 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_take_dataset([1, 2]), None, 0)
+
+
+class RepeatDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_repeat_dataset(self, count, take_count=3):
+ components = (np.arange(10),)
+ return dataset_ops.Dataset.from_tensor_slices(components).take(
+ take_count).repeat(count)
+
+ def testFiniteRepeat(self):
+ count = 10
+ self.run_core_tests(lambda: self._build_repeat_dataset(count),
+ lambda: self._build_repeat_dataset(count + 2),
+ 3 * count)
+
+ def testEmptyRepeat(self):
+ self.run_core_tests(lambda: self._build_repeat_dataset(0), None, 0)
+
+ def testInfiniteRepeat(self):
+ self.verify_unused_iterator(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_init_before_restore(
+ lambda: self._build_repeat_dataset(-1), 10, verify_exhausted=False)
+ self.verify_multiple_breaks(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_reset_restored_iterator(
+ lambda: self._build_repeat_dataset(-1), 20, verify_exhausted=False)
+ self.verify_restore_in_modified_graph(
+ lambda: self._build_repeat_dataset(-1),
+ lambda: self._build_repeat_dataset(2),
+ 20,
+ verify_exhausted=False)
+ # Test repeat empty dataset
+ self.run_core_tests(lambda: self._build_repeat_dataset(-1, 0), None, 0)
+
+ def testInvalidRepeat(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'Shape must be rank 0 but is rank 1'):
+ self.run_core_tests(lambda: self._build_repeat_dataset([1, 2], 0),
+ None, 0)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
new file mode 100644
index 0000000000..88d5c896c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for dataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class SerializationIntegrationTest(test.TestCase):
+
+ def _build_input_pipeline(self, name, num_outputs):
+ with ops.name_scope(name):
+ ds = dataset_ops.Dataset.range(num_outputs).shuffle(
+ 10, reshuffle_each_iteration=False).prefetch(10)
+ iterator = ds.make_initializable_iterator()
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ return iterator.initializer, iterator.get_next()
+
+ def _build_graph(self, num_pipelines, num_outputs):
+ init_ops = []
+ get_next_ops = []
+ for i in range(num_pipelines):
+ name = "input_pipeline_%d" % i
+ init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
+ init_ops.append(init_op)
+ get_next_ops.append(get_next_op)
+ saver = saver_lib.Saver()
+ return init_ops, get_next_ops, saver
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def testConcurrentSaves(self):
+ num_pipelines = 100
+ num_outputs = 100
+ break_point = 10
+ all_outputs = [[] for _ in range(num_pipelines)]
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ sess.run(init_ops)
+ for _ in range(break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+ saver.save(sess, self._ckpt_path())
+
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ saver.restore(sess, self._ckpt_path())
+ for _ in range(num_outputs - break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+
+ for output in all_outputs:
+ self.assertSequenceEqual(sorted(output), range(num_outputs))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
new file mode 100644
index 0000000000..f847ac19f9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -0,0 +1,39 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleAndRepeatDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ShuffleAndRepeatSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, seed):
+ return dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
+
+ def testCore(self):
+ self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
+ 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
new file mode 100644
index 0000000000..a04f1ddafc
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -0,0 +1,148 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ShuffleDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class ShuffleDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_shuffle_dataset(
+ self,
+ range_limit=10,
+ num_repeats=5,
+ buffer_size=5,
+ seed=None,
+ reshuffle_each_iteration=None,
+ ):
+ return dataset_ops.Dataset.range(range_limit).shuffle(
+ buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration).repeat(num_repeats)
+
+ def testShuffleCore(self):
+
+ seed = 55
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ # pylint: disable=cell-var-from-loop
+ # pylint: disable=g-long-lambda
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+ self.run_core_tests(
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=seed,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ lambda: self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=10,
+ reshuffle_each_iteration=reshuffle_each_iteration),
+ num_outputs)
+ # pylint: enable=cell-var-from-loop
+ # pylint: enable=g-long-lambda
+
+ def testNonDeterministicSeeding(self):
+
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ # We checkpoint the initial state of the Dataset so that we can restore
+ # the seeds in the next run. Since the seeding is non-deterministic
+ # the dataset gets initialized with different seeds each time.
+ expected = self.gen_outputs(
+ ds_fn,
+ break_points=[0],
+ num_outputs=num_outputs,
+ ckpt_saved=False,
+ verify_exhausted=False,
+ save_checkpoint_at_end=False)
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points=self.gen_break_points(num_outputs),
+ num_outputs=num_outputs,
+ ckpt_saved=True,
+ verify_exhausted=False)
+ self.match(expected, actual)
+
+ def testMultipleIterators(self):
+ range_limit = 5
+ num_repeats = 2
+ num_outputs = range_limit * num_repeats
+ buffer_sizes = [1, 3, 5, 8, 10]
+
+ for reshuffle_each_iteration in [True, False]:
+ for buffer_size in buffer_sizes:
+
+ def ds_fn():
+ # pylint: disable=cell-var-from-loop
+ return self._build_shuffle_dataset(
+ range_limit=range_limit,
+ num_repeats=num_repeats,
+ buffer_size=buffer_size,
+ seed=None, # Iterator seeds are generated non-deterministically.
+ reshuffle_each_iteration=reshuffle_each_iteration)
+ # pylint: enable=cell-var-from-loop
+
+ with ops.Graph().as_default() as g:
+ ds = ds_fn()
+ iterators = [ds.make_one_shot_iterator(), ds.make_one_shot_iterator()]
+ get_next_ops = [it.get_next() for it in iterators]
+ saveables = [
+ contrib_iterator_ops.make_saveable_from_iterator(it)
+ for it in iterators
+ ]
+ for saveable in saveables:
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ saver = saver_lib.Saver(allow_empty=True)
+ with self.session(graph=g) as sess:
+ self._save(sess, saver)
+ expected = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self._restore(saver, sess)
+ actual = [sess.run(get_next_ops) for _ in range(num_outputs)]
+ self.match(expected, actual)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
new file mode 100644
index 0000000000..006279bbe1
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the SqlDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetSerializationTest(
+ sql_dataset_test_base.SqlDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, num_repeats):
+ data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ query = ("SELECT first_name, last_name, motto FROM students ORDER BY "
+ "first_name DESC")
+ output_types = (dtypes.string, dtypes.string, dtypes.string)
+ return readers.SqlDataset(driver_name, data_source_name, query,
+ output_types).repeat(num_repeats)
+
+ def testSQLSaveable(self):
+ num_repeats = 4
+ num_outputs = num_repeats * 2
+ self.run_core_tests(lambda: self._build_dataset(num_repeats),
+ lambda: self._build_dataset(num_repeats // 2),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
new file mode 100644
index 0000000000..ef7061b190
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -0,0 +1,106 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the StatsDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+# TODO(shivaniagrawal): Can not checkpoint input_pipeline with the
+# transformation `stats_ops.set_stats_aggregator`, since we don't support
+# serializing StatsAggregator yet.
+class StatsDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset_bytes_stats(self, num_elements):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
+ stats_ops.bytes_produced_stats("bytes_produced"))
+
+ def test_bytes_produced_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.bytes_produced_stats(["bytes_produced"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testBytesStatsDatasetSaveableCore(self):
+ num_outputs = 100
+ self.run_core_tests(
+ lambda: self._build_dataset_bytes_stats(num_outputs),
+ lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+
+ def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag))
+
+ def _build_dataset_multiple_tags(self,
+ num_elements,
+ tag1="record_latency",
+ tag2="record_latency_2"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
+
+ def test_latency_stats_invalid_tag_shape(self):
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1"):
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats(["record_latency", "record_latency_2"])),
+ None, 100)
+ # pylint: enable=g-long-lambda
+
+ def testLatencyStatsDatasetSaveableCore(self):
+ num_outputs = 100
+
+ self.run_core_tests(
+ lambda: self._build_dataset_latency_stats(num_outputs),
+ lambda: self._build_dataset_latency_stats(num_outputs // 10),
+ num_outputs)
+
+ self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
+ None, num_outputs)
+
+ tag1 = "record_latency"
+ tag2 = "record_latency"
+ self.run_core_tests(
+ lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
+ None, num_outputs)
+
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
new file mode 100644
index 0000000000..c87a7443a7
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -0,0 +1,53 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TextLineDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TextLineDatasetSerializationTest(
+ reader_dataset_ops_test_base.TextLineDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self, test_filenames, compression_type=None):
+ return core_readers.TextLineDataset(
+ test_filenames, compression_type=compression_type, buffer_size=10)
+
+ def testTextLineCore(self):
+ compression_types = [None, "GZIP", "ZLIB"]
+ num_files = 5
+ lines_per_file = 5
+ num_outputs = num_files * lines_per_file
+ for compression_type in compression_types:
+ test_filenames = self._createFiles(
+ num_files,
+ lines_per_file,
+ crlf=True,
+ compression_type=compression_type)
+ # pylint: disable=cell-var-from-loop
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(test_filenames, compression_type),
+ lambda: self._build_iterator_graph(test_filenames), num_outputs)
+ # pylint: enable=cell-var-from-loop
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
new file mode 100644
index 0000000000..f0dcc131d4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -0,0 +1,99 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the TFRecordDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import gzip
+import os
+import zlib
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.platform import test
+
+
+class TFRecordDatasetSerializationTest(
+ reader_dataset_ops_test_base.TFRecordDatasetTestBase,
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_iterator_graph(self,
+ num_epochs,
+ batch_size=1,
+ compression_type=None,
+ buffer_size=None):
+ filenames = self._createFiles()
+ if compression_type == "ZLIB":
+ zlib_files = []
+ for i, fn in enumerate(filenames):
+ with open(fn, "rb") as f:
+ cdata = zlib.compress(f.read())
+ zfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.z" % i)
+ with open(zfn, "wb") as f:
+ f.write(cdata)
+ zlib_files.append(zfn)
+ filenames = zlib_files
+
+ elif compression_type == "GZIP":
+ gzip_files = []
+ for i, fn in enumerate(self.test_filenames):
+ with open(fn, "rb") as f:
+ gzfn = os.path.join(self.get_temp_dir(), "tfrecord_%s.gz" % i)
+ with gzip.GzipFile(gzfn, "wb") as gzf:
+ gzf.write(f.read())
+ gzip_files.append(gzfn)
+ filenames = gzip_files
+
+ return core_readers.TFRecordDataset(
+ filenames, compression_type,
+ buffer_size=buffer_size).repeat(num_epochs).batch(batch_size)
+
+ def testTFRecordWithoutBufferCore(self):
+ num_epochs = 5
+ batch_size = num_epochs
+ num_outputs = num_epochs * self._num_files * self._num_records // batch_size
+ # pylint: disable=g-long-lambda
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, batch_size,
+ buffer_size=0),
+ lambda: self._build_iterator_graph(num_epochs * 2, batch_size),
+ num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, buffer_size=0), None,
+ num_outputs * batch_size)
+ # pylint: enable=g-long-lambda
+
+ def testTFRecordWithBufferCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(lambda: self._build_iterator_graph(num_epochs),
+ lambda: self._build_iterator_graph(num_epochs * 2),
+ num_outputs)
+
+ def testTFRecordWithCompressionCore(self):
+ num_epochs = 5
+ num_outputs = num_epochs * self._num_files * self._num_records
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="ZLIB"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+ self.run_core_tests(
+ lambda: self._build_iterator_graph(num_epochs, compression_type="GZIP"),
+ lambda: self._build_iterator_graph(num_epochs * 2), num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
new file mode 100644
index 0000000000..528598dfe4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -0,0 +1,51 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UnbatchDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UnbatchDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def build_dataset(self, multiplier=15.0, tensor_slice_len=2, batch_size=2):
+ components = (
+ np.arange(tensor_slice_len),
+ np.array([[1, 2, 3]]) * np.arange(tensor_slice_len)[:, np.newaxis],
+ np.array(multiplier) * np.arange(tensor_slice_len))
+
+ return dataset_ops.Dataset.from_tensor_slices(components).batch(
+ batch_size).apply(batching.unbatch())
+
+ def testCore(self):
+ tensor_slice_len = 8
+ batch_size = 2
+ num_outputs = tensor_slice_len
+ self.run_core_tests(
+ lambda: self.build_dataset(15.0, tensor_slice_len, batch_size),
+ lambda: self.build_dataset(20.0, tensor_slice_len, batch_size),
+ num_outputs)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
new file mode 100644
index 0000000000..e2862af4d6
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -0,0 +1,40 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the UniqueDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class UniqueDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def testUnique(self):
+
+ def build_dataset(num_elements, unique_elem_range):
+ return dataset_ops.Dataset.range(num_elements).map(
+ lambda x: x % unique_elem_range).apply(unique.unique())
+
+ self.run_core_tests(lambda: build_dataset(200, 100),
+ lambda: build_dataset(40, 100), 100)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
new file mode 100644
index 0000000000..4ea6131c22
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the ZipDataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.platform import test
+
+
+class ZipDatasetSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_dataset(self, arr):
+ components = [
+ np.tile(np.array([[1], [2], [3], [4]]), 20),
+ np.tile(np.array([[12], [13], [14], [15]]), 22),
+ np.array(arr)
+ ]
+ datasets = [
+ dataset_ops.Dataset.from_tensor_slices(component)
+ for component in components
+ ]
+ return dataset_ops.Dataset.zip((datasets[0], (datasets[1], datasets[2])))
+
+ def testCore(self):
+ # Equal length components
+ arr = [37.0, 38.0, 39.0, 40.0]
+ num_outputs = len(arr)
+ self.run_core_tests(lambda: self._build_dataset(arr), None, num_outputs)
+ # Variable length components
+ diff_size_arr = [1.0, 2.0]
+ self.run_core_tests(lambda: self._build_dataset(diff_size_arr),
+ lambda: self._build_dataset(arr), 2)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
new file mode 100644
index 0000000000..c208963a86
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_and_repeat_test.py
@@ -0,0 +1,115 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.shuffle_and_repeat()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+
+
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
+
+ def _build_ds(self, seed, count=5, num_elements=20):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
+
+ def _gen_outputs(self, ds_fn, num_outputs, verify_exhausted=True):
+ get_next = ds_fn().make_one_shot_iterator().get_next()
+ outputs = []
+ with self.cached_session() as sess:
+ for _ in range(num_outputs):
+ outputs.append(sess.run(get_next))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ return outputs
+
+ def testCorrectOutput(self):
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
+ self.assertSequenceEqual(
+ sorted(output), sorted(
+ np.array([range(20) for _ in range(5)]).flatten()))
+ for i in range(5):
+ self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
+
+ def testReshuffling(self):
+ # Check that the output orders of different epochs are indeed different.
+ output = self._gen_outputs(lambda: self._build_ds(10), 100)
+ for i in range(4):
+ epoch1 = output[i * 20:(i + 1) * 20]
+ epoch2 = output[(i + 1) * 20:(i + 2) * 20]
+ self.assertNotEqual(epoch1, epoch2)
+
+ def testSameOrderForSameSeeds(self):
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ self.assertEqual(output1, output2)
+
+ def testDifferentOrderForDifferentSeeds(self):
+ output1 = self._gen_outputs(lambda: self._build_ds(10), 100)
+ output2 = self._gen_outputs(lambda: self._build_ds(20), 100)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountNone(self):
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=None), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=None), 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountMinusOne(self):
+ output1 = self._gen_outputs(
+ lambda: self._build_ds(10, count=-1), 100, verify_exhausted=False)
+ output2 = self._gen_outputs(
+ lambda: self._build_ds(20, count=-1), 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testInfiniteOutputs(self):
+ # Asserting the iterator is exhausted after producing 100 items should fail.
+ with self.assertRaises(AssertionError):
+ self._gen_outputs(lambda: self._build_ds(10, count=None), 100)
+ with self.assertRaises(AssertionError):
+ self._gen_outputs(lambda: self._build_ds(10, count=-1), 100)
+
+ def testInfiniteEmpty(self):
+ with self.assertRaises(errors.OutOfRangeError):
+ self._gen_outputs(lambda: self._build_ds(10, count=None, num_elements=0),
+ 100)
+ with self.assertRaises(errors.OutOfRangeError):
+ self._gen_outputs(lambda: self._build_ds(10, count=-1, num_elements=0),
+ 100)
+
+ def testLargeBufferSize(self):
+ with ops.Graph().as_default() as g:
+ ds = dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=21))
+ get_next_op = ds.make_one_shot_iterator().get_next()
+ with self.session(graph=g) as sess:
+ sess.run(get_next_op)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
new file mode 100644
index 0000000000..a2c1169638
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test.py
@@ -0,0 +1,590 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.SqlDataset`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class SqlDatasetTest(sql_dataset_test_base.SqlDatasetTestBase):
+
+ # Test that SqlDataset can read from a database table.
+ def testReadResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string), 2)
+ with self.cached_session() as sess:
+ for _ in range(2): # Run twice to verify statelessness of db operations.
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ for _ in range(2): # Dataset is repeated. See setUp.
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset works on a join query.
+ def testReadResultSetJoinQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT students.first_name, state, motto FROM students "
+ "INNER JOIN people "
+ "ON students.first_name = people.first_name "
+ "AND students.last_name = people.last_name"
+ })
+ self.assertEqual((b"John", b"California", b"Hi!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset can read a database entry with a null-terminator
+ # in the middle of the text and place the entry in a `string` tensor.
+ def testReadResultSetNullTerminator(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, favorite_nonsense_word "
+ "FROM students ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"n\0nsense"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"nonsense\0"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that SqlDataset works when used on two different queries.
+ # Because the output types of the dataset must be determined at graph-creation
+ # time, the two queries must have the same number and types of columns.
+ def testReadResultSetReuseSqlDataset(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"Hi!"), sess.run(get_next))
+ self.assertEqual((b"Jane", b"Moe", b"Hi again!"), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, state FROM people "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", b"Doe", b"California"), sess.run(get_next))
+ self.assertEqual((b"Benjamin", b"Franklin", b"Pennsylvania"),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that an `OutOfRangeError` is raised on the first call to
+ # `get_next_str_only` if result set is empty.
+ def testReadEmptyResultSet(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "WHERE first_name = 'Nonexistent'"
+ })
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that an error is raised when `driver_name` is invalid.
+ def testReadResultSetWithInvalidDriverName(self):
+ init_op = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))[0]
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(
+ init_op,
+ feed_dict={
+ self.driver_name: "sqlfake",
+ self.query: "SELECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+
+ # Test that an error is raised when a column name in `query` is nonexistent
+ def testReadResultSetWithInvalidColumnName(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, fake_column FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.UnknownError):
+ sess.run(get_next)
+
+ # Test that an error is raised when there is a syntax error in `query`.
+ def testReadResultSetOfQueryWithSyntaxError(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELEmispellECT first_name, last_name, motto FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.UnknownError):
+ sess.run(get_next)
+
+ # Test that an error is raised when the number of columns in `query`
+ # does not match the length of `output_types`.
+ def testReadResultSetWithMismatchBetweenColumnsAndOutputTypes(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, last_name FROM students "
+ "ORDER BY first_name DESC"
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ # Test that no results are returned when `query` is an insert query rather
+ # than a select query. In particular, the error refers to the number of
+ # output types passed to the op not matching the number of columns in the
+ # result set of the query (namely, 0 for an insert statement.)
+ def testReadResultSetOfInsertQuery(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.string))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "INSERT INTO students (first_name, last_name, motto) "
+ "VALUES ('Foo', 'Bar', 'Baz'), ('Fizz', 'Buzz', 'Fizzbuzz')"
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int8` tensor.
+ def testReadResultSetInt8(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int8` tensor.
+ def testReadResultSetInt8NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int8,
+ dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income, favorite_negative_number "
+ "FROM students "
+ "WHERE first_name = 'John' ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0, -2), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int8` tensor.
+ def testReadResultSetInt8MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.int8, dtypes.int8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT desk_number, favorite_negative_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((9, -2), sess.run(get_next))
+ # Max and min values of int8
+ self.assertEqual((127, -128), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int16` tensor.
+ def testReadResultSetInt16(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int16` tensor.
+ def testReadResultSetInt16NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16,
+ dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income, favorite_negative_number "
+ "FROM students "
+ "WHERE first_name = 'John' ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0, -2), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int16` tensor.
+ def testReadResultSetInt16MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_medium_sized_number "
+ "FROM students ORDER BY first_name DESC"
+ })
+ # Max value of int16
+ self.assertEqual((b"John", 32767), sess.run(get_next))
+ # Min value of int16
+ self.assertEqual((b"Jane", -32768), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in an `int32` tensor.
+ def testReadResultSetInt32(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int32` tensor.
+ def testReadResultSetInt32NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ self.assertEqual((b"Jane", -20000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int32` tensor.
+ def testReadResultSetInt32MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Max value of int32
+ self.assertEqual((b"John", 2147483647), sess.run(get_next))
+ # Min value of int32
+ self.assertEqual((b"Jane", -2147483648), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a numeric `varchar` from a SQLite database
+ # table and place it in an `int32` tensor.
+ def testReadResultSetInt32VarCharColumnAsInt(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int32))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, school_id FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 123), sess.run(get_next))
+ self.assertEqual((b"Jane", 1000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table
+ # and place it in an `int64` tensor.
+ def testReadResultSetInt64(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a negative or 0-valued integer from a
+ # SQLite database table and place it in an `int64` tensor.
+ def testReadResultSetInt64NegativeAndZero(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, income FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ self.assertEqual((b"Jane", -20000), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a large (positive or negative) integer from
+ # a SQLite database table and place it in an `int64` tensor.
+ def testReadResultSetInt64MaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.int64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, favorite_big_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Max value of int64
+ self.assertEqual((b"John", 9223372036854775807), sess.run(get_next))
+ # Min value of int64
+ self.assertEqual((b"Jane", -9223372036854775808), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table and
+ # place it in a `uint8` tensor.
+ def testReadResultSetUInt8(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read the minimum and maximum uint8 values from a
+ # SQLite database table and place them in `uint8` tensors.
+ def testReadResultSetUInt8MinAndMaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint8))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, brownie_points FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Min value of uint8
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ # Max value of uint8
+ self.assertEqual((b"Jane", 255), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer from a SQLite database table
+ # and place it in a `uint16` tensor.
+ def testReadResultSetUInt16(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, desk_number FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", 9), sess.run(get_next))
+ self.assertEqual((b"Jane", 127), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read the minimum and maximum uint16 values from a
+ # SQLite database table and place them in `uint16` tensors.
+ def testReadResultSetUInt16MinAndMaxValues(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.uint16))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, account_balance FROM students "
+ "ORDER BY first_name DESC"
+ })
+ # Min value of uint16
+ self.assertEqual((b"John", 0), sess.run(get_next))
+ # Max value of uint16
+ self.assertEqual((b"Jane", 65535), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a 0-valued and 1-valued integer from a
+ # SQLite database table and place them as `True` and `False` respectively
+ # in `bool` tensors.
+ def testReadResultSetBool(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, registration_complete FROM students "
+ "ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", True), sess.run(get_next))
+ self.assertEqual((b"Jane", False), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read an integer that is not 0-valued or 1-valued
+ # from a SQLite database table and place it as `True` in a `bool` tensor.
+ def testReadResultSetBoolNotZeroOrOne(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.bool))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query: "SELECT first_name, favorite_medium_sized_number "
+ "FROM students ORDER BY first_name DESC"
+ })
+ self.assertEqual((b"John", True), sess.run(get_next))
+ self.assertEqual((b"Jane", True), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a float from a SQLite database table
+ # and place it in a `float64` tensor.
+ def testReadResultSetFloat64(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, victories FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertEqual((b"George", b"Washington", 20.0), sess.run(get_next))
+ self.assertEqual((b"John", b"Adams", -19.95), sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a float from a SQLite database table beyond
+ # the precision of 64-bit IEEE, without throwing an error. Test that
+ # `SqlDataset` identifies such a value as equal to itself.
+ def testReadResultSetFloat64OverlyPrecise(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, accolades FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertEqual(
+ (b"George", b"Washington",
+ 1331241.321342132321324589798264627463827647382647382643874),
+ sess.run(get_next))
+ self.assertEqual(
+ (b"John", b"Adams",
+ 1331241321342132321324589798264627463827647382647382643874.0),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ # Test that `SqlDataset` can read a float from a SQLite database table,
+ # representing the largest integer representable as a 64-bit IEEE float
+ # such that the previous integer is also representable as a 64-bit IEEE float.
+ # Test that `SqlDataset` can distinguish these two numbers.
+ def testReadResultSetFloat64LargestConsecutiveWholeNumbersNotEqual(self):
+ init_op, get_next = self._createSqlDataset((dtypes.string, dtypes.string,
+ dtypes.float64))
+ with self.cached_session() as sess:
+ sess.run(
+ init_op,
+ feed_dict={
+ self.query:
+ "SELECT first_name, last_name, triumphs FROM townspeople "
+ "ORDER BY first_name"
+ })
+ self.assertNotEqual((b"George", b"Washington", 9007199254740992.0),
+ sess.run(get_next))
+ self.assertNotEqual((b"John", b"Adams", 9007199254740991.0),
+ sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
new file mode 100644
index 0000000000..6aaaa90c65
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_test_base.py
@@ -0,0 +1,94 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing `tf.data.experimental.SqlDataset`."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import sqlite3
+
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class SqlDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for setting up and testing SqlDataset."""
+
+ def _createSqlDataset(self, output_types, num_repeats=1):
+ dataset = readers.SqlDataset(self.driver_name, self.data_source_name,
+ self.query, output_types).repeat(num_repeats)
+ iterator = dataset.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return init_op, get_next
+
+ def setUp(self):
+ self.data_source_name = os.path.join(test.get_temp_dir(), "tftest.sqlite")
+ self.driver_name = array_ops.placeholder_with_default(
+ array_ops.constant("sqlite", dtypes.string), shape=[])
+ self.query = array_ops.placeholder(dtypes.string, shape=[])
+
+ conn = sqlite3.connect(self.data_source_name)
+ c = conn.cursor()
+ c.execute("DROP TABLE IF EXISTS students")
+ c.execute("DROP TABLE IF EXISTS people")
+ c.execute("DROP TABLE IF EXISTS townspeople")
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS students (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), motto VARCHAR(100), "
+ "school_id VARCHAR(100), favorite_nonsense_word VARCHAR(100), "
+ "desk_number INTEGER, income INTEGER, favorite_number INTEGER, "
+ "favorite_big_number INTEGER, favorite_negative_number INTEGER, "
+ "favorite_medium_sized_number INTEGER, brownie_points INTEGER, "
+ "account_balance INTEGER, registration_complete INTEGER)")
+ c.executemany(
+ "INSERT INTO students (first_name, last_name, motto, school_id, "
+ "favorite_nonsense_word, desk_number, income, favorite_number, "
+ "favorite_big_number, favorite_negative_number, "
+ "favorite_medium_sized_number, brownie_points, account_balance, "
+ "registration_complete) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ [("John", "Doe", "Hi!", "123", "n\0nsense", 9, 0, 2147483647,
+ 9223372036854775807, -2, 32767, 0, 0, 1),
+ ("Jane", "Moe", "Hi again!", "1000", "nonsense\0", 127, -20000,
+ -2147483648, -9223372036854775808, -128, -32768, 255, 65535, 0)])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS people (id INTEGER NOT NULL PRIMARY KEY, "
+ "first_name VARCHAR(100), last_name VARCHAR(100), state VARCHAR(100))")
+ c.executemany(
+ "INSERT INTO PEOPLE (first_name, last_name, state) VALUES (?, ?, ?)",
+ [("Benjamin", "Franklin", "Pennsylvania"), ("John", "Doe",
+ "California")])
+ c.execute(
+ "CREATE TABLE IF NOT EXISTS townspeople (id INTEGER NOT NULL PRIMARY "
+ "KEY, first_name VARCHAR(100), last_name VARCHAR(100), victories "
+ "FLOAT, accolades FLOAT, triumphs FLOAT)")
+ c.executemany(
+ "INSERT INTO townspeople (first_name, last_name, victories, "
+ "accolades, triumphs) VALUES (?, ?, ?, ?, ?)",
+ [("George", "Washington", 20.00,
+ 1331241.321342132321324589798264627463827647382647382643874,
+ 9007199254740991.0),
+ ("John", "Adams", -19.95,
+ 1331241321342132321324589798264627463827647382647382643874.0,
+ 9007199254740992.0)])
+ conn.commit()
+ conn.close()
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
new file mode 100644
index 0000000000..427654cd76
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -0,0 +1,322 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class StatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
+
+ def testBytesProduced(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).apply(
+ stats_ops.bytes_produced_stats("bytes_produced")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ expected_sum = 0.0
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "bytes_produced", float(i + 1))
+ expected_sum += i * 8.0
+ self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "bytes_produced", 100.0)
+ self._assertSummaryHasSum(summary_str, "bytes_produced", expected_sum)
+
+ def testLatencyStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+
+ def testPrefetchBufferUtilization(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ -1).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ float(i + 1))
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_capacity")
+ self._assertSummaryContains(summary_str, "Prefetch::buffer_size")
+ self._assertSummaryHasRange(summary_str, "Prefetch::buffer_utilization",
+ 0, 1)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasCount(summary_str, "Prefetch::buffer_utilization",
+ 100)
+
+ def testPrefetchBufferScalars(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(10).map(
+ lambda x: array_ops.tile([x], ops.convert_to_tensor([x]))).prefetch(
+ 0).apply(stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(10):
+ self.assertAllEqual(
+ np.array([i] * i, dtype=np.int64), sess.run(next_element))
+ summary_str = sess.run(summary_t)
+ self._assertSummaryHasScalarValue(summary_str,
+ "Prefetch::buffer_capacity", 0)
+ self._assertSummaryHasScalarValue(summary_str, "Prefetch::buffer_size",
+ 0)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testFilteredElementsStats(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(101).filter(
+ lambda x: math_ops.equal(math_ops.mod(x, 3), 0)).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(34):
+ self.assertEqual(i * 3, sess.run(next_element))
+ if i is not 0:
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", float(i * 2))
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::dropped_elements", 67.0)
+ self._assertSummaryHasScalarValue(
+ sess.run(summary_t), "Filter::filtered_elements", 34.0)
+
+ def testReinitialize(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ for j in range(5):
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float((j * 100) + i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", (j + 1) * 100.0)
+
+ def testNoAggregatorRegistered(self):
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testMultipleTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.latency_stats("record_latency_2")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(i + 1))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency_2", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 100.0)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency_2", 100.0)
+
+ def testRepeatedTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for i in range(100):
+ self.assertEqual(i, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+
+ def testMultipleIteratorsSameAggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ iterator_0 = dataset.make_initializable_iterator()
+ iterator_1 = dataset.make_initializable_iterator()
+ next_element = iterator_0.get_next() + iterator_1.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.cached_session() as sess:
+ sess.run([iterator_0.initializer, iterator_1.initializer])
+ for i in range(100):
+ self.assertEqual(i * 2, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_latency", float(2 * (i + 1)))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(sess.run(summary_t), "record_latency", 200.0)
+
+ def testMultipleDatasetWithTags(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset1"))
+ dataset2 = dataset_ops.Dataset.range(100).apply(
+ stats_ops.latency_stats("record_latency")).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "dataset2"))
+ iterator_0 = dataset.make_initializable_iterator()
+ iterator_1 = dataset2.make_initializable_iterator()
+ next_element = iterator_0.get_next() + iterator_1.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run([iterator_0.initializer, iterator_1.initializer])
+ for i in range(100):
+ self.assertEqual(i * 2, sess.run(next_element))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", float(i + 1))
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", float(i + 1))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset1_record_latency", 100.0)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "dataset2_record_latency", 100.0)
+
+
+class FeatureStatsDatasetTest(
+ stats_dataset_test_base.StatsDatasetTestBase,
+ reader_dataset_ops_test_base.MakeBatchedFeaturesDatasetTestBase):
+
+ def testFeaturesStats(self):
+ num_epochs = 5
+ total_records = num_epochs * self._num_records
+ batch_size = 2
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = self.make_batch_feature(
+ filenames=self.test_filenames[0],
+ num_epochs=num_epochs,
+ batch_size=batch_size,
+ shuffle=True,
+ shuffle_seed=5,
+ drop_final_batch=False).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator, "record_stats"))
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+ summary_t = stats_aggregator.get_summary()
+
+ with self.test_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(total_records // batch_size + 1 if total_records %
+ batch_size else total_records // batch_size):
+ sess.run(next_element)
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_features", total_records)
+ self._assertSummaryHasCount(
+ sess.run(summary_t), "record_stats_feature-values", total_records)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_features", total_records * 4)
+ self._assertSummaryHasSum(
+ sess.run(summary_t), "record_stats_feature-values",
+ self._sum_keywords(1) * num_epochs + 3 * total_records)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
new file mode 100644
index 0000000000..80f2625927
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
@@ -0,0 +1,71 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing the input pipeline statistics gathering ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+
+from tensorflow.core.framework import summary_pb2
+from tensorflow.python.data.kernel_tests import test_base
+
+
+class StatsDatasetTestBase(test_base.DatasetTestBase):
+ """Base class for testing statistics gathered in `StatsAggregator`."""
+
+ def _assertSummaryContains(self, summary_str, tag):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasCount(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.num)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasRange(self, summary_str, tag, min_value, max_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertLessEqual(min_value, value.histo.min)
+ self.assertGreaterEqual(max_value, value.histo.max)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasSum(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.histo.sum)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
+
+ def _assertSummaryHasScalarValue(self, summary_str, tag, expected_value):
+ summary_proto = summary_pb2.Summary()
+ summary_proto.ParseFromString(summary_str)
+ for value in summary_proto.value:
+ if tag == value.tag:
+ self.assertEqual(expected_value, value.simple_value)
+ return
+ self.fail("Expected tag %r not found in summary %r" % (tag, summary_proto))
diff --git a/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
new file mode 100644
index 0000000000..8fd0ad50c4
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/tf_record_writer_test.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.TFRecordWriter`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.framework import dtypes
+from tensorflow.python.lib.io import python_io
+from tensorflow.python.lib.io import tf_record
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class TFRecordWriterTest(test_base.DatasetTestBase):
+
+ def setUp(self):
+ super(TFRecordWriterTest, self).setUp()
+ self._num_records = 7
+ self.filename = array_ops.placeholder(dtypes.string, shape=[])
+ self.compression_type = array_ops.placeholder_with_default("", shape=[])
+
+ input_dataset = readers.TFRecordDataset([self.filename],
+ self.compression_type)
+ self.writer = writers.TFRecordWriter(
+ self._outputFilename(), self.compression_type).write(input_dataset)
+
+ def _record(self, i):
+ return compat.as_bytes("Record %d" % (i))
+
+ def _createFile(self, options=None):
+ filename = self._inputFilename()
+ writer = python_io.TFRecordWriter(filename, options)
+ for i in range(self._num_records):
+ writer.write(self._record(i))
+ writer.close()
+ return filename
+
+ def _inputFilename(self):
+ return os.path.join(self.get_temp_dir(), "tf_record.in.txt")
+
+ def _outputFilename(self):
+ return os.path.join(self.get_temp_dir(), "tf_record.out.txt")
+
+ def testWrite(self):
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer, feed_dict={
+ self.filename: self._createFile(),
+ })
+ for i, r in enumerate(tf_record.tf_record_iterator(self._outputFilename())):
+ self.assertAllEqual(self._record(i), r)
+
+ def testWriteZLIB(self):
+ options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.ZLIB)
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer,
+ feed_dict={
+ self.filename: self._createFile(options),
+ self.compression_type: "ZLIB",
+ })
+ for i, r in enumerate(
+ tf_record.tf_record_iterator(self._outputFilename(), options=options)):
+ self.assertAllEqual(self._record(i), r)
+
+ def testWriteGZIP(self):
+ options = tf_record.TFRecordOptions(tf_record.TFRecordCompressionType.GZIP)
+ with self.cached_session() as sess:
+ sess.run(
+ self.writer,
+ feed_dict={
+ self.filename: self._createFile(options),
+ self.compression_type: "GZIP",
+ })
+ for i, r in enumerate(
+ tf_record.tf_record_iterator(self._outputFilename(), options=options)):
+ self.assertAllEqual(self._record(i), r)
+
+ def testFailDataset(self):
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write("whoops")
+
+ def testFailDType(self):
+ input_dataset = dataset_ops.Dataset.from_tensors(10)
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write(input_dataset)
+
+ def testFailShape(self):
+ input_dataset = dataset_ops.Dataset.from_tensors([["hello"], ["world"]])
+ with self.assertRaises(TypeError):
+ writers.TFRecordWriter(self._outputFilename(),
+ self.compression_type).write(input_dataset)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
new file mode 100644
index 0000000000..0278a208cb
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unbatch_test.py
@@ -0,0 +1,300 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.unbatch()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class UnbatchTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testUnbatchWithUnknownRankInput(self):
+ placeholder = array_ops.placeholder(dtypes.int32)
+ dataset = dataset_ops.Dataset.from_tensors(placeholder).apply(
+ batching.unbatch())
+ iterator = dataset.make_initializable_iterator()
+ next_elem = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={placeholder: [0, 1, 2, 3]})
+ for i in range(4):
+ self.assertEqual(i, sess.run(next_elem))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_elem)
+
+ def testUnbatchScalarDataset(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = (dtypes.int32,) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i,) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithStrings(self):
+ data = tuple([math_ops.range(10) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ data = data.map(lambda x, y, z: (x, string_ops.as_string(y), z))
+ expected_types = (dtypes.int32, dtypes.string, dtypes.int32)
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual((i, compat.as_bytes(str(i)), i), sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchDatasetWithSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors(st)
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ st_row = sess.run(next_element)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchDatasetWithDenseAndSparseTensor(self):
+ st = sparse_tensor.SparseTensorValue(
+ indices=[[i, i] for i in range(10)],
+ values=list(range(10)),
+ dense_shape=[10, 10])
+ data = dataset_ops.Dataset.from_tensors((list(range(10)), st))
+ data = data.apply(batching.unbatch())
+ data = data.batch(5)
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ dense_elem, st_row = sess.run(next_element)
+ self.assertEqual(i, dense_elem)
+ self.assertEqual([i], st_row.indices)
+ self.assertEqual([i], st_row.values)
+ self.assertEqual([10], st_row.dense_shape)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchSingleElementTupleDataset(self):
+ data = tuple([(math_ops.range(10),) for _ in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32,),) * 3
+ data = data.batch(2)
+ self.assertEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i,),) * 3, sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchMultiElementTupleDataset(self):
+ data = tuple([(math_ops.range(10 * i, 10 * i + 10),
+ array_ops.fill([10], "hi")) for i in range(3)])
+ data = dataset_ops.Dataset.from_tensor_slices(data)
+ expected_types = ((dtypes.int32, dtypes.string),) * 3
+ data = data.batch(2)
+ self.assertAllEqual(expected_types, data.output_types)
+ data = data.apply(batching.unbatch())
+ self.assertAllEqual(expected_types, data.output_types)
+
+ iterator = data.make_one_shot_iterator()
+ op = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(((i, b"hi"), (10 + i, b"hi"), (20 + i, b"hi")),
+ sess.run(op))
+
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(op)
+
+ def testUnbatchEmpty(self):
+ data = dataset_ops.Dataset.from_tensors(
+ (constant_op.constant([]), constant_op.constant([], shape=[0, 4]),
+ constant_op.constant([], shape=[0, 4, 0])))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testUnbatchStaticShapeMismatch(self):
+ data = dataset_ops.Dataset.from_tensors((np.arange(7), np.arange(8),
+ np.arange(9)))
+ with self.assertRaises(ValueError):
+ data.apply(batching.unbatch())
+
+ def testUnbatchDynamicShapeMismatch(self):
+ ph1 = array_ops.placeholder(dtypes.int32, shape=[None])
+ ph2 = array_ops.placeholder(dtypes.int32, shape=None)
+ data = dataset_ops.Dataset.from_tensors((ph1, ph2))
+ data = data.apply(batching.unbatch())
+ iterator = data.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ # Mismatch in the 0th dimension.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: np.arange(8).astype(np.int32)
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+ # No 0th dimension (i.e. scalar value) for one component.
+ sess.run(
+ iterator.initializer,
+ feed_dict={
+ ph1: np.arange(7).astype(np.int32),
+ ph2: 7
+ })
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(next_element)
+
+
+class UnbatchBenchmark(test.Benchmark):
+
+ def benchmarkNativeUnbatch(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.apply(batching.unbatch())
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (native) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_native_batch_size_%d" %
+ batch_size)
+
+ # Include a benchmark of the previous `unbatch()` implementation that uses
+ # a composition of more primitive ops. Eventually we'd hope to generate code
+ # that is as good in both cases.
+ def benchmarkOldUnbatchImplementation(self):
+ batch_sizes = [1, 2, 5, 10, 20, 50]
+ elems_per_trial = 10000
+ with ops.Graph().as_default():
+ dataset = dataset_ops.Dataset.from_tensors("element").repeat(None)
+ batch_size_placeholder = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset.batch(batch_size_placeholder)
+ dataset = dataset.flat_map(dataset_ops.Dataset.from_tensor_slices)
+ dataset = dataset.skip(elems_per_trial)
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with session.Session() as sess:
+ for batch_size in batch_sizes:
+ deltas = []
+ for _ in range(5):
+ sess.run(
+ iterator.initializer,
+ feed_dict={batch_size_placeholder: batch_size})
+ start = time.time()
+ sess.run(next_element.op)
+ end = time.time()
+ deltas.append((end - start) / elems_per_trial)
+
+ median_wall_time = np.median(deltas)
+ print("Unbatch (unfused) batch size: %d Median wall time per element:"
+ " %f microseconds" % (batch_size, median_wall_time * 1e6))
+ self.report_benchmark(
+ iters=10000,
+ wall_time=median_wall_time,
+ name="benchmark_unbatch_dataset_unfused_batch_size_%d" %
+ batch_size)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/kernel_tests/unique_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
new file mode 100644
index 0000000000..847cff26b0
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_test.py
@@ -0,0 +1,83 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for `tf.data.experimental.unique()`."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+from tensorflow.python.util import compat
+
+
+class UniqueTest(test_base.DatasetTestBase):
+
+ def _testSimpleHelper(self, dtype, test_cases):
+ """Test the `unique()` transformation on a list of test cases.
+
+ Args:
+ dtype: The `dtype` of the elements in each test case.
+ test_cases: A list of pairs of lists. The first component is the test
+ input that will be passed to the transformation; the second component
+ is the expected sequence of outputs from the transformation.
+ """
+
+ # The `current_test_case` will be updated when we loop over `test_cases`
+ # below; declare it here so that the generator can capture it once.
+ current_test_case = []
+ dataset = dataset_ops.Dataset.from_generator(lambda: current_test_case,
+ dtype).apply(unique.unique())
+ iterator = dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for test_case, expected in test_cases:
+ current_test_case = test_case
+ sess.run(iterator.initializer)
+ for element in expected:
+ if dtype == dtypes.string:
+ element = compat.as_bytes(element)
+ self.assertAllEqual(element, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testSimpleInt(self):
+ for dtype in [dtypes.int32, dtypes.int64]:
+ self._testSimpleHelper(dtype, [
+ ([], []),
+ ([1], [1]),
+ ([1, 1, 1, 1, 1, 1, 1], [1]),
+ ([1, 2, 3, 4], [1, 2, 3, 4]),
+ ([1, 2, 4, 3, 2, 1, 2, 3, 4], [1, 2, 4, 3]),
+ ([[1], [1, 1], [1, 1, 1]], [[1], [1, 1], [1, 1, 1]]),
+ ([[1, 1], [1, 1], [2, 2], [3, 3], [1, 1]], [[1, 1], [2, 2], [3, 3]]),
+ ])
+
+ def testSimpleString(self):
+ self._testSimpleHelper(dtypes.string, [
+ ([], []),
+ (["hello"], ["hello"]),
+ (["hello", "hello", "hello"], ["hello"]),
+ (["hello", "world"], ["hello", "world"]),
+ (["foo", "bar", "baz", "baz", "bar", "foo"], ["foo", "bar", "baz"]),
+ ])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
new file mode 100644
index 0000000000..915d399f1b
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -0,0 +1,377 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
+py_library(
+ name = "counter",
+ srcs = ["counter.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":scan_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "get_single_element",
+ srcs = ["get_single_element.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "iterator_ops",
+ srcs = [
+ "iterator_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:basic_session_run_hooks",
+ "//tensorflow/python:checkpoint_management",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session_run_hook",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ ],
+)
+
+py_library(
+ name = "random_ops",
+ srcs = [
+ "random_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "readers",
+ srcs = [
+ "readers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":optimization",
+ ":parsing_ops",
+ ":shuffle_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "shuffle_ops",
+ srcs = [
+ "shuffle_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "batching",
+ srcs = ["batching.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":get_single_element",
+ ":grouping",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "enumerate_ops",
+ srcs = ["enumerate_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "error_ops",
+ srcs = ["error_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "grouping",
+ srcs = ["grouping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "interleave_ops",
+ srcs = ["interleave_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
+ name = "map_defun",
+ srcs = ["map_defun.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
+ name = "resampling",
+ srcs = ["resampling.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":scan_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:logging_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "scan_ops",
+ srcs = ["scan_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "stats_ops",
+ srcs = ["stats_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "threadpool",
+ srcs = ["threadpool.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
+ name = "unique",
+ srcs = [
+ "unique.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "writers",
+ srcs = [
+ "writers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "prefetching_ops",
+ srcs = ["prefetching_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "dataset_ops",
+ deps = [
+ ":batching",
+ ":counter",
+ ":enumerate_ops",
+ ":error_ops",
+ ":get_single_element",
+ ":grouping",
+ ":indexed_dataset_ops",
+ ":interleave_ops",
+ ":map_defun",
+ ":optimization",
+ ":prefetching_ops",
+ ":readers",
+ ":resampling",
+ ":scan_ops",
+ ":shuffle_ops",
+ ":stats_ops",
+ ":threadpool",
+ ":unique",
+ ":writers",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
new file mode 100644
index 0000000000..d42af9e7e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -0,0 +1,669 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batching dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def batch_window(dataset):
+ """Batches a window of tensors.
+
+ Args:
+ dataset: the input dataset.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+ """
+ if isinstance(dataset.output_classes, tuple):
+ raise TypeError("Input dataset expected to have a single component")
+ if dataset.output_classes is ops.Tensor:
+ return _batch_dense_window(dataset)
+ elif dataset.output_classes is sparse_tensor.SparseTensor:
+ return _batch_sparse_window(dataset)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _batch_dense_window(dataset):
+ """Batches a window of dense tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return array_ops.shape(first_element)
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, array_ops.shape(value))
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat([[0], shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _batch_sparse_window(dataset):
+ """Batches a window of sparse tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return first_element.dense_shape
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, value.dense_shape)
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64),
+ math_ops.cast(shape, dtypes.int64)], 0))
+
+ def batch_reduce_fn(state, value):
+ return sparse_ops.sparse_concat(0, [state, value])
+
+ def reshape_fn(value):
+ return sparse_ops.sparse_reshape(
+ value,
+ array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(reshape_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+@tf_export("data.experimental.dense_to_sparse_batch")
+def dense_to_sparse_batch(batch_size, row_shape):
+ """A transformation that batches ragged elements into `tf.SparseTensor`s.
+
+ Like `Dataset.padded_batch()`, this transformation combines multiple
+ consecutive elements of the dataset, which might have different
+ shapes, into a single element. The resulting element has three
+ components (`indices`, `values`, and `dense_shape`), which
+ comprise a `tf.SparseTensor` that represents the same data. The
+ `row_shape` represents the dense shape of each row in the
+ resulting `tf.SparseTensor`, to which the effective batch size is
+ prepended. For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.dense_to_sparse_batch(
+ batch_size=2, row_shape=[6])) ==
+ {
+ ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices
+ ['a', 'b', 'c', 'a', 'b'], # values
+ [2, 6]), # dense_shape
+ ([[0, 0], [0, 1], [0, 2], [0, 3]],
+ ['a', 'b', 'c', 'd'],
+ [1, 6])
+ }
+ ```
+
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ number of consecutive elements of this dataset to combine in a
+ single batch.
+ row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the equivalent dense shape of a row in the
+ resulting `tf.SparseTensor`. Each element of this dataset must
+ have the same rank as `row_shape`, and must have size less
+ than or equal to `row_shape` in each dimension.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+
+ return _apply_fn
+
+
+def padded_batch_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of tensors with padding.
+
+ Args:
+ dataset: the input dataset.
+ padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the shape to which the input elements should be padded
+ prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
+ `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
+ maximum size of that dimension in each batch.
+ padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
+ padding value to use. Defaults are `0` for numeric types and the empty
+ string for string types. If `dataset` contains `tf.SparseTensor`, this
+ value is ignored.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
+ if not issubclass(dataset.output_classes,
+ (ops.Tensor, sparse_tensor.SparseTensor)):
+ raise TypeError("Input dataset expected to have a single tensor component")
+ if issubclass(dataset.output_classes, (ops.Tensor)):
+ return _padded_batch_dense_window(dataset, padded_shape, padding_value)
+ elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
+ if padding_value is not None:
+ raise ValueError("Padding value not allowed for sparse tensors")
+ return _padded_batch_sparse_window(dataset, padded_shape)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of dense tensors with padding."""
+
+ padded_shape = math_ops.cast(
+ convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return padded_shape
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(array_ops.shape(value), padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ",
+ array_ops.shape(value), padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, array_ops.shape(value))
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ if padding_value is None:
+ if dataset.output_types == dtypes.string:
+ padding_value = ""
+ elif dataset.output_types == dtypes.bool:
+ padding_value = False
+ elif dataset.output_types == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ padding_value = 0
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ def pad_fn(value):
+ shape = array_ops.shape(value)
+ left = array_ops.zeros_like(shape)
+ right = padded_shape - shape
+ return array_ops.pad(
+ value, array_ops.stack([left, right], 1), constant_values=padding_value)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(pad_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _padded_batch_sparse_window(dataset, padded_shape):
+ """Batches a window of sparse tensors with padding."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return convert.partial_shape_to_tensor(padded_shape)
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(value.dense_shape, padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ", value.dense_shape,
+ padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, value.dense_shape)
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
+ 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64), padded_shape], 0))
+
+ def batch_reduce_fn(state, value):
+ padded_value = sparse_tensor.SparseTensor(
+ indices=value.indices, values=value.values, dense_shape=padded_shape)
+ reshaped_value = sparse_ops.sparse_reshape(
+ padded_value,
+ array_ops.concat(
+ [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
+ return sparse_ops.sparse_concat(0, [state, reshaped_value])
+
+ reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
+
+
+class _UnbatchDataset(dataset_ops.UnaryDataset):
+ """A dataset that splits the elements of its input into multiple elements."""
+
+ def __init__(self, input_dataset):
+ """See `unbatch()` for more details."""
+ super(_UnbatchDataset, self).__init__(input_dataset)
+ flat_shapes = nest.flatten(input_dataset.output_shapes)
+ if any(s.ndims == 0 for s in flat_shapes):
+ raise ValueError("Cannot unbatch an input with scalar components.")
+ known_batch_dim = tensor_shape.Dimension(None)
+ for s in flat_shapes:
+ try:
+ known_batch_dim = known_batch_dim.merge_with(s[0])
+ except ValueError:
+ raise ValueError("Cannot unbatch an input whose components have "
+ "different batch sizes.")
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.unbatch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda s: s[1:],
+ self._input_dataset.output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.unbatch")
+def unbatch():
+ """Splits elements of a dataset into multiple elements on the batch dimension.
+
+ For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+ where `B` may vary for each input element, then for each element in the
+ dataset, the unbatched dataset will contain `B` consecutive elements
+ of shape `[a0, a1, ...]`.
+
+ ```python
+ # NOTE: The following example uses `{ ... }` to represent the contents
+ # of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.unbatch()) == {
+ 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ if not sparse.any_sparse(dataset.output_classes):
+ return _UnbatchDataset(dataset)
+
+ # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+ # are normalized to the rank-1 dense representation, so that the
+ # sparse-oblivious unbatching logic will slice them
+ # appropriately. This leads to a somewhat inefficient re-encoding step
+ # for all SparseTensor components.
+ # TODO(mrry): Consider optimizing this in future
+ # if it turns out to be a bottleneck.
+ def normalize(arg, *rest):
+ if rest:
+ return sparse.serialize_many_sparse_tensors((arg,) + rest)
+ else:
+ return sparse.serialize_many_sparse_tensors(arg)
+
+ normalized_dataset = dataset.map(normalize)
+
+ # NOTE(mrry): Our `map()` has lost information about the sparseness
+ # of any SparseTensor components, so re-apply the structure of the
+ # original dataset.
+ restructured_dataset = _RestructuredDataset(
+ normalized_dataset,
+ dataset.output_types,
+ dataset.output_shapes,
+ dataset.output_classes,
+ allow_unsafe_cast=True)
+ return _UnbatchDataset(restructured_dataset)
+
+ return _apply_fn
+
+
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
+
+ def __init__(self, input_dataset, batch_size, row_shape):
+ """See `Dataset.dense_to_sparse_batch()` for more details."""
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
+ if not isinstance(input_dataset.output_types, dtypes.DType):
+ raise TypeError("DenseToSparseDataset requires an input whose elements "
+ "have a single component, whereas the input has %r." %
+ input_dataset.output_types)
+ self._input_dataset = input_dataset
+ self._batch_size = batch_size
+ self._row_shape = row_shape
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.dense_to_sparse_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._batch_size,
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return sparse_tensor.SparseTensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.vector(None).concatenate(self._row_shape)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _RestructuredDataset(dataset_ops.UnaryDataset):
+ """An internal helper for changing the structure and shape of a dataset."""
+
+ def __init__(self,
+ dataset,
+ output_types,
+ output_shapes=None,
+ output_classes=None,
+ allow_unsafe_cast=False):
+ """Creates a new dataset with the given output types and shapes.
+
+ The given `dataset` must have a structure that is convertible:
+ * `dataset.output_types` must be the same as `output_types` module nesting.
+ * Each shape in `dataset.output_shapes` must be compatible with each shape
+ in `output_shapes` (if given).
+
+ Note: This helper permits "unsafe casts" for shapes, equivalent to using
+ `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+ Args:
+ dataset: A `Dataset` object.
+ output_types: A nested structure of `tf.DType` objects.
+ output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+ If omitted, the shapes will be inherited from `dataset`.
+ output_classes: (Optional.) A nested structure of class types.
+ If omitted, the class types will be inherited from `dataset`.
+ allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+ reported output types and shapes of the restructured dataset, e.g. to
+ switch a sparse tensor represented as `tf.variant` to its user-visible
+ type and shape.
+
+ Raises:
+ ValueError: If either `output_types` or `output_shapes` is not compatible
+ with the structure of `dataset`.
+ """
+ super(_RestructuredDataset, self).__init__(dataset)
+ self._input_dataset = dataset
+
+ if not allow_unsafe_cast:
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have "
+ "output types %r" % (dataset.output_types, output_types))
+
+ self._output_types = output_types
+
+ if output_shapes is None:
+ # Inherit shapes from the original `dataset`.
+ self._output_shapes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_shapes))
+ else:
+ if not allow_unsafe_cast:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r" % (dataset.output_shapes,
+ output_shapes))
+ self._output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+ if output_classes is None:
+ # Inherit class types from the original `dataset`.
+ self._output_classes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_classes))
+ else:
+ self._output_classes = output_classes
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+
+class _MapAndBatchDataset(dataset_ops.MapDataset):
+ """A `Dataset` that maps a function over a batch of elements."""
+
+ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
+ drop_remainder):
+ """See `Dataset.map()` for details."""
+ super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
+ self._batch_size_t = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ self._num_parallel_calls_t = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+ self._drop_remainder_t = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+
+ self._batch_size = batch_size
+ self._drop_remainder = drop_remainder
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.map_and_batch_dataset_v2(
+ input_resource,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ batch_size=self._batch_size_t,
+ num_parallel_calls=self._num_parallel_calls_t,
+ drop_remainder=self._drop_remainder_t,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_shapes(self):
+ dim = self._batch_size if self._drop_remainder else None
+ return nest.pack_sequence_as(self._output_shapes, [
+ tensor_shape.vector(dim).concatenate(s)
+ for s in nest.flatten(self._output_shapes)
+ ])
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.map_and_batch")
+def map_and_batch(map_func,
+ batch_size,
+ num_parallel_batches=None,
+ drop_remainder=False,
+ num_parallel_calls=None):
+ """Fused implementation of `map` and `batch`.
+
+ Maps `map_func` across `batch_size` consecutive elements of this dataset
+ and then combines them into a batch. Functionally, it is equivalent to `map`
+ followed by `batch`. However, by fusing the two transformations together, the
+ implementation can be more efficient. Surfacing this transformation in the API
+ is temporary. Once automatic input pipeline optimization is implemented,
+ the fusing of `map` and `batch` will happen automatically and this API will be
+ deprecated.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to another
+ nested structure of tensors.
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+ num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+ representing the number of batches to create in parallel. On one hand,
+ higher values can help mitigate the effect of stragglers. On the other
+ hand, higher values can increase contention if CPU is scarce.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in case its size is smaller than
+ desired; the default behavior is not to drop the smaller batch.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of elements to process in parallel. If not
+ specified, `batch_size * num_parallel_batches` elements will be
+ processed in parallel.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
+ specified.
+ """
+
+ if num_parallel_batches is None and num_parallel_calls is None:
+ num_parallel_calls = batch_size
+ elif num_parallel_batches is not None and num_parallel_calls is None:
+ num_parallel_calls = batch_size * num_parallel_batches
+ elif num_parallel_batches is not None and num_parallel_calls is not None:
+ raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
+ "arguments are mutually exclusive.")
+
+ def _apply_fn(dataset):
+ return _MapAndBatchDataset(dataset, map_func, batch_size,
+ num_parallel_calls, drop_remainder)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py
new file mode 100644
index 0000000000..42200eaef9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/counter.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Counter Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import scan_ops
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.Counter")
+def Counter(start=0, step=1, dtype=dtypes.int64):
+ """Creates a `Dataset` that counts from `start` in steps of size `step`.
+
+ For example:
+
+ ```python
+ Dataset.count() == [0, 1, 2, ...)
+ Dataset.count(2) == [2, 3, ...)
+ Dataset.count(2, 5) == [2, 7, 12, ...)
+ Dataset.count(0, -1) == [0, -1, -2, ...)
+ Dataset.count(10, -1) == [10, 9, ...)
+ ```
+
+ Args:
+ start: (Optional.) The starting value for the counter. Defaults to 0.
+ step: (Optional.) The step size for the counter. Defaults to 1.
+ dtype: (Optional.) The data type for counter elements. Defaults to
+ `tf.int64`.
+
+ Returns:
+ A `Dataset` of scalar `dtype` elements.
+ """
+ with ops.name_scope("counter"):
+ start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+ step = ops.convert_to_tensor(step, dtype=dtype, name="step")
+ return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
+ scan_ops.scan(start, lambda state, _: (state + step, state)))
diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py
new file mode 100644
index 0000000000..a1af98f552
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enumerate dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.enumerate_dataset")
+def enumerate_dataset(start=0):
+ """A transformation that enumerate the elements of a dataset.
+
+ It is Similar to python's `enumerate`.
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3 }
+ b = { (7, 8), (9, 10) }
+
+ # The nested structure of the `datasets` argument determines the
+ # structure of elements in the resulting dataset.
+ a.apply(tf.data.experimental.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
+ b.apply(tf.data.experimental.enumerate()) == { (0, (7, 8)), (1, (9, 10)) }
+ ```
+
+ Args:
+ start: A `tf.int64` scalar `tf.Tensor`, representing the start
+ value for enumeration.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
+ return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
+ dataset))
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py
new file mode 100644
index 0000000000..82e274b70c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/error_ops.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignore_errors dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ignore_errors")
+def ignore_errors():
+ """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
+
+ Use this transformation to produce a dataset that contains the same elements
+ as the input, but silently drops any elements that caused an error. For
+ example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
+
+ # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
+ dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
+
+ # Using `ignore_errors()` will drop the element that causes an error.
+ dataset =
+ dataset.apply(tf.data.experimental.ignore_errors()) # ==> {1., 0.5, 0.2}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _IgnoreErrorsDataset(dataset)
+
+ return _apply_fn
+
+
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that silently ignores errors when computing its input."""
+
+ def __init__(self, input_dataset):
+ """See `Dataset.ignore_errors()` for details."""
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py
new file mode 100644
index 0000000000..132526166c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/get_single_element.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.get_single_element")
+def get_single_element(dataset):
+ """Returns the single element in `dataset` as a nested structure of tensors.
+
+ This function enables you to use a `tf.data.Dataset` in a stateless
+ "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`.
+ This can be useful when your preprocessing transformations are expressed
+ as a `Dataset`, and you want to use the transformation at serving time.
+ For example:
+
+ ```python
+ input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
+
+ def preprocessing_fn(input_str):
+ # ...
+ return image, label
+
+ dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
+ .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
+ .batch(BATCH_SIZE))
+
+ image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
+ ```
+
+ Args:
+ dataset: A `tf.data.Dataset` object containing a single element.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the single
+ element of `dataset`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ InvalidArgumentError (at runtime): if `dataset` does not contain exactly
+ one element.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ nested_ret = nest.pack_sequence_as(
+ dataset.output_types, gen_dataset_ops.dataset_to_single_element(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(dataset)))
+ return sparse.deserialize_sparse_tensors(
+ nested_ret, dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
new file mode 100644
index 0000000000..18ba583220
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -0,0 +1,551 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Grouping dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.group_by_reducer")
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+ This transformation maps element of a dataset to a key using `key_func` and
+ groups the elements by key. The `reducer` is used to process each group; its
+ `init_func` is used to initialize state for each group when it is created, the
+ `reduce_func` is used to update the state every time an element is mapped to
+ the matching group, and the `finalize_func` is used to map the final state to
+ an output value.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.group_by_window")
+def group_by_window(key_func,
+ reduce_func,
+ window_size=None,
+ window_size_func=None):
+ """A transformation that groups windows of elements by key and reduces them.
+
+ This transformation maps each consecutive element in a dataset to a key
+ using `key_func` and groups the elements by key. It then applies
+ `reduce_func` to at most `window_size_func(key)` elements matching the same
+ key. All except the final window for each key will contain
+ `window_size_func(key)` elements; the final window may be smaller.
+
+ You may provide either a constant `window_size` or a window size determined by
+ the key through `window_size_func`.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reduce_func: A function mapping a key and a dataset of up to `window_size`
+ consecutive elements matching that key to another dataset.
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements matching the same key to combine in a single
+ batch, which will be passed to `reduce_func`. Mutually exclusive with
+ `window_size_func`.
+ window_size_func: A function mapping a key to a `tf.int64` scalar
+ `tf.Tensor`, representing the number of consecutive elements matching
+ the same key to combine in a single batch, which will be passed to
+ `reduce_func`. Mutually exclusive with `window_size`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if neither or both of {`window_size`, `window_size_func`} are
+ passed.
+ """
+ if (window_size is not None and window_size_func or
+ not (window_size is not None or window_size_func)):
+ raise ValueError("Must pass either window_size or window_size_func.")
+
+ if window_size is not None:
+
+ def constant_window_func(unused_key):
+ return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
+
+ window_size_func = constant_window_func
+
+ assert window_size_func is not None
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.bucket_by_sequence_length")
+def bucket_by_sequence_length(element_length_func,
+ bucket_boundaries,
+ bucket_batch_sizes,
+ padded_shapes=None,
+ padding_values=None,
+ pad_to_bucket_boundary=False,
+ no_padding=False):
+ """A transformation that buckets elements in a `Dataset` by length.
+
+ Elements of the `Dataset` are grouped together by length and then are padded
+ and batched.
+
+ This is useful for sequence tasks in which the elements have variable length.
+ Grouping together elements that have similar lengths reduces the total
+ fraction of padding in a batch which increases training step efficiency.
+
+ Args:
+ element_length_func: function from element in `Dataset` to `tf.int32`,
+ determines the length of the element, which will determine the bucket it
+ goes into.
+ bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
+ bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
+ `len(bucket_boundaries) + 1`.
+ padded_shapes: Nested structure of `tf.TensorShape` to pass to
+ `tf.data.Dataset.padded_batch`. If not provided, will use
+ `dataset.output_shapes`, which will result in variable length dimensions
+ being padded out to the maximum length in each batch.
+ padding_values: Values to pad with, passed to
+ `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
+ pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
+ size to maximum length in batch. If `True`, will pad dimensions with
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
+ no_padding: `bool`, indicates whether to pad the batch features (features
+ need to be either of type `tf.SparseTensor` or of same shape).
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
+ """
+ with ops.name_scope("bucket_by_seq_length"):
+ if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
+ raise ValueError(
+ "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
+
+ batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
+
+ def element_to_bucket_id(*args):
+ """Return int64 id of the length bucket for this element."""
+ seq_length = element_length_func(*args)
+
+ boundaries = list(bucket_boundaries)
+ buckets_min = [np.iinfo(np.int32).min] + boundaries
+ buckets_max = boundaries + [np.iinfo(np.int32).max]
+ conditions_c = math_ops.logical_and(
+ math_ops.less_equal(buckets_min, seq_length),
+ math_ops.less(seq_length, buckets_max))
+ bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
+
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ # The window size is set to the batch size for this bucket
+ window_size = batch_sizes[bucket_id]
+ return window_size
+
+ def make_padded_shapes(shapes, none_filler=None):
+ padded = []
+ for shape in nest.flatten(shapes):
+ shape = tensor_shape.TensorShape(shape)
+ shape = [
+ none_filler if d.value is None else d
+ for d in shape
+ ]
+ padded.append(shape)
+ return nest.pack_sequence_as(shapes, padded)
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch elements in dataset."""
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
+ none_filler = None
+ if pad_to_bucket_boundary:
+ err_msg = ("When pad_to_bucket_boundary=True, elements must have "
+ "length < max(bucket_boundaries).")
+ check = check_ops.assert_less(
+ bucket_id,
+ constant_op.constant(len(bucket_batch_sizes) - 1,
+ dtype=dtypes.int64),
+ message=err_msg)
+ with ops.control_dependencies([check]):
+ boundaries = constant_op.constant(bucket_boundaries,
+ dtype=dtypes.int64)
+ bucket_boundary = boundaries[bucket_id]
+ none_filler = bucket_boundary - 1
+ shapes = make_padded_shapes(
+ padded_shapes or grouped_dataset.output_shapes,
+ none_filler=none_filler)
+ return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
+
+ def _apply_fn(dataset):
+ return dataset.apply(
+ group_by_window(element_to_bucket_id, batching_fn,
+ window_size_func=window_size_fn))
+
+ return _apply_fn
+
+
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
+
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
+
+ Returns:
+ Dataset: A `Dataset`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
+
+ return _apply_fn
+
+
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = wrapped_func.function
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=self._state_classes,
+ input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ **dataset_ops.flat_structure(self))
+
+
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a windowed reduction."""
+
+ def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
+ """See `group_by_window()` for details."""
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_reduce_func(reduce_func, input_dataset)
+ self._make_window_size_func(window_size_func)
+
+ def _make_window_size_func(self, window_size_func):
+ """Make wrapping Defun for window_size_func."""
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper,
+ "tf.data.experimental.group_by_window()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.data.experimental.group_by_window()",
+ input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+ nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.reduce_by_window()",
+ input_classes=(ops.Tensor, nested_dataset),
+ input_shapes=(tensor_shape.scalar(), nested_dataset),
+ input_types=(dtypes.int64, nested_dataset),
+ experimental_nested_dataset_support=True)
+ if not isinstance(
+ wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ raise TypeError("`reduce_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._reduce_func = wrapped_func.function
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._window_size_func.captured_inputs,
+ key_func=self._key_func,
+ reduce_func=self._reduce_func,
+ window_size_func=self._window_size_func,
+ **dataset_ops.flat_structure(self))
+
+
+@tf_export("data.experimental.Reducer")
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+ A reducer is represented as a tuple of the three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.data.experimental.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
new file mode 100644
index 0000000000..9c06474a2f
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for indexed datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+
+
+class MaterializedIndexedDataset(object):
+ """MaterializedIndexedDataset is highly experimental!
+ """
+
+ def __init__(self, materialized_resource, materializer, output_classes,
+ output_types, output_shapes):
+ self._materialized_resource = materialized_resource
+ self._materializer = materializer
+ self._output_classes = output_classes
+ self._output_types = output_types
+ self._output_shapes = output_shapes
+
+ @property
+ def initializer(self):
+ if self._materializer is not None:
+ return self._materializer
+ raise ValueError("MaterializedDataset does not have a materializer")
+
+ def get(self, index):
+ """Get retrieves a value (or set of values) from the IndexedDataset.
+
+ Args:
+ index: A uint64 scalar or vector tensor with the indices to retrieve.
+
+ Returns:
+ A tensor containing the values corresponding to `index`.
+ """
+ # TODO(saeta): nest.pack_sequence_as(...)
+ return ged_ops.experimental_indexed_dataset_get(
+ self._materialized_resource,
+ index,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._output_types, self._output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self._output_shapes, self._output_classes)))
+
+
+class IndexedDataset(dataset_ops.Dataset):
+ """IndexedDataset is highly experimental!
+ """
+
+ def __init__(self):
+ pass
+
+ def materialize(self, shared_name=None, container=None):
+ """Materialize creates a MaterializedIndexedDataset.
+
+ IndexedDatasets can be combined through operations such as TBD. Therefore,
+ they are only materialized when absolutely required.
+
+ Args:
+ shared_name: a string for the shared name to use for the resource.
+ container: a string for the container to store the resource.
+
+ Returns:
+ A MaterializedIndexedDataset.
+ """
+ if container is None:
+ container = ""
+ if shared_name is None:
+ shared_name = ""
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
+
+ with ops.colocate_with(materialized_resource):
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
+ self._as_variant_tensor(), materialized_resource)
+ return MaterializedIndexedDataset(materialized_resource, materializer,
+ self.output_classes, self.output_types,
+ self.output_shapes)
+
+ @abc.abstractproperty
+ def output_types(self):
+ """Returns the type of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.DType` objects corresponding to each component
+ of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_types")
+
+ @abc.abstractproperty
+ def output_classes(self):
+ """Returns the class of each component of an element of this IndexedDataset.
+
+ The expected values are `tf.Tensor` and `tf.SparseTensor`.
+
+ Returns:
+ A nested structure of Python `type` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_classes")
+
+ @abc.abstractproperty
+ def output_shapes(self):
+ """Returns the shape of each component of an element of this IndexedDataset.
+
+ Returns:
+ A nested structure of `tf.TensorShape` objects corresponding to each
+ component of an element of this IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset.output_shapes")
+
+ @abc.abstractmethod
+ def _as_variant_tensor(self):
+ """Creates a `tf.variant` `tf.Tensor` representing this IndexedDataset.
+
+ Returns:
+ A scalar `tf.Tensor` of `tf.variant` type, which represents this
+ IndexedDataset.
+ """
+ raise NotImplementedError("IndexedDataset._as_variant_tensor")
+
+
+class IdentityIndexedDataset(IndexedDataset):
+ """IdentityIndexedDataset is a trivial indexed dataset used for testing.
+ """
+
+ def __init__(self, size):
+ super(IdentityIndexedDataset, self).__init__()
+ # TODO(saeta): Verify _size is a scalar!
+ self._size = ops.convert_to_tensor(size, dtype=dtypes.uint64, name="size")
+
+ @property
+ def output_types(self):
+ return dtypes.uint64
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
+
+ def _inputs(self):
+ return []
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
new file mode 100644
index 0000000000..a3c094859e
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -0,0 +1,262 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.ops import gen_stateless_random_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.parallel_interleave")
+def parallel_interleave(map_func,
+ cycle_length,
+ block_length=1,
+ sloppy=False,
+ buffer_output_elements=None,
+ prefetch_input_elements=None):
+ """A parallel version of the `Dataset.interleave()` transformation.
+
+ `parallel_interleave()` maps `map_func` across its input to produce nested
+ datasets, and outputs their elements interleaved. Unlike
+ `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
+ datasets in parallel, which increases the throughput, especially in the
+ presence of stragglers. Furthermore, the `sloppy` argument can be used to
+ improve performance, by relaxing the requirement that the outputs are produced
+ in a deterministic order, and allowing the implementation to skip over nested
+ datasets whose elements are not readily available when requested.
+
+ Example usage:
+
+ ```python
+ # Preprocess 4 files concurrently.
+ filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
+ dataset = filenames.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: tf.data.TFRecordDataset(filename),
+ cycle_length=4))
+ ```
+
+ WARNING: If `sloppy` is `True`, the order of produced elements is not
+ deterministic.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to a `Dataset`.
+ cycle_length: The number of input `Dataset`s to interleave from in parallel.
+ block_length: The number of consecutive elements to pull from an input
+ `Dataset` before advancing to the next input `Dataset`.
+ sloppy: If false, elements are produced in deterministic order. Otherwise,
+ the implementation is allowed, for the sake of expediency, to produce
+ elements in a non-deterministic order.
+ buffer_output_elements: The number of elements each iterator being
+ interleaved should buffer (similar to the `.prefetch()` transformation for
+ each interleaved iterator).
+ prefetch_input_elements: The number of input elements to transform to
+ iterators before they are needed for interleaving.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return readers.ParallelInterleaveDataset(
+ dataset, map_func, cycle_length, block_length, sloppy,
+ buffer_output_elements, prefetch_input_elements)
+
+ return _apply_fn
+
+
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
+ """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
+
+ def __init__(self, selector_input, data_inputs):
+ self._selector_input = selector_input
+ self._data_inputs = list(data_inputs)
+
+ for data_input in data_inputs[1:]:
+ if (data_input.output_types != data_inputs[0].output_types or
+ data_input.output_classes != data_inputs[0].output_classes):
+ raise TypeError("All datasets must have the same type and class.")
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
+ # pylint: enable=protected-access
+
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
+ @property
+ def output_classes(self):
+ return self._data_inputs[0].output_classes
+
+ @property
+ def output_shapes(self):
+ ret = self._data_inputs[0].output_shapes
+ for data_input in self._data_inputs[1:]:
+ ret = nest.pack_sequence_as(ret, [
+ ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
+ nest.flatten(ret), nest.flatten(data_input.output_shapes))
+ ])
+ return ret
+
+ @property
+ def output_types(self):
+ return self._data_inputs[0].output_types
+
+
+@tf_export("data.experimental.sample_from_datasets")
+def sample_from_datasets(datasets, weights=None, seed=None):
+ """Samples elements at random from the datasets in `datasets`.
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ weights: (Optional.) A list of `len(datasets)` floating-point values where
+ `weights[i]` represents the probability with which an element should be
+ sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
+ element is such a list. Defaults to a uniform distribution across
+ `datasets`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` at random, according to
+ `weights` if provided, otherwise with uniform probability.
+
+ Raises:
+ TypeError: If the `datasets` or `weights` arguments have the wrong type.
+ ValueError: If the `weights` argument is specified and does not match the
+ length of the `datasets` element.
+ """
+ num_datasets = len(datasets)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
+
+ return _DirectedInterleaveDataset(selector_input, datasets)
+
+
+@tf_export("data.experimental.choose_from_datasets")
+def choose_from_datasets(datasets, choice_dataset):
+ """Creates a dataset that deterministically chooses elements from `datasets`.
+
+ For example, given the following datasets:
+
+ ```python
+ datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+ tf.data.Dataset.from_tensors("bar").repeat(),
+ tf.data.Dataset.from_tensors("baz").repeat()]
+
+ # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
+ choice_dataset = tf.data.Dataset.range(3).repeat(3)
+
+ result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+ ```
+
+ The elements of `result` will be:
+
+ ```
+ "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
+ ```
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
+ `0` and `len(datasets) - 1`.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` according to the values
+ of `choice_dataset`.
+
+ Raises:
+ TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
+ type.
+ """
+ if not (choice_dataset.output_types == dtypes.int64
+ and choice_dataset.output_shapes.is_compatible_with(
+ tensor_shape.scalar())
+ and choice_dataset.output_classes == ops.Tensor):
+ raise TypeError("`choice_dataset` must be a dataset of scalar "
+ "`tf.int64` tensors.")
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/python/data/experimental/ops/iterator_ops.py b/tensorflow/python/data/experimental/ops/iterator_ops.py
new file mode 100644
index 0000000000..72d7d58f06
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/iterator_ops.py
@@ -0,0 +1,268 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Iterator ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.make_saveable_from_iterator")
+def make_saveable_from_iterator(iterator):
+ """Returns a SaveableObject for saving/restore iterator state using Saver.
+
+ Args:
+ iterator: Iterator.
+
+ For example:
+
+ ```python
+ with tf.Graph().as_default():
+ ds = tf.data.Dataset.range(10)
+ iterator = ds.make_initializable_iterator()
+ # Build the iterator SaveableObject.
+ saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
+ # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
+ # it can be automatically saved using Saver.
+ tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
+ saver = tf.train.Saver()
+
+ while continue_training:
+ ... Perform training ...
+ if should_save_checkpoint:
+ saver.save()
+ ```
+
+ Note: When restoring the iterator, the existing iterator state is completely
+ discarded. This means that any changes you may have made to the Dataset
+ graph will be discarded as well! This includes the new Dataset graph
+ that you may have built during validation. So, while running validation,
+ make sure to run the initializer for the validation input pipeline after
+ restoring the checkpoint.
+
+ Note: Not all iterators support checkpointing yet. Attempting to save the
+ state of an unsupported iterator will throw an error.
+ """
+ return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
+
+
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+ iterator_resource.name + "-state")
+ ]
+ super(_Saveable, self).__init__(iterator_resource, specs,
+ iterator_resource.name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+@tf_export("data.experimental.CheckpointInputPipelineHook")
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+ """Checkpoints input pipeline state every N steps or seconds.
+
+ This hook saves the state of the iterators in the `Graph` so that when
+ training is resumed the input pipeline continues from where it left off.
+ This could potentially avoid overfitting in certain pipelines where the
+ number of training steps per eval are small compared to the dataset
+ size or if the training pipeline is pre-empted.
+
+ Differences from `CheckpointSaverHook`:
+ 1. Saves only the input pipelines in the "iterators" collection and not the
+ global variables or other saveable objects.
+ 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+ Example of checkpointing the training pipeline:
+
+ ```python
+ est = tf.estimator.Estimator(model_fn)
+ while True:
+ est.train(
+ train_input_fn,
+ hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
+ steps=train_steps_per_eval)
+ # Note: We do not pass the hook here.
+ metrics = est.evaluate(eval_input_fn)
+ if should_stop_the_training(metrics):
+ break
+ ```
+
+ This hook should be used if the input pipeline state needs to be saved
+ separate from the model checkpoint. Doing so may be useful for a few reasons:
+ 1. The input pipeline checkpoint may be large, if there are large shuffle
+ or prefetch buffers for instance, and may bloat the checkpoint size.
+ 2. If the input pipeline is shared between training and validation, restoring
+ the checkpoint during validation may override the validation input
+ pipeline.
+
+ For saving the input pipeline checkpoint alongside the model weights use
+ `tf.data.experimental.make_saveable_from_iterator` directly to create a
+ `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
+ that you will need to be careful not to restore the training iterator during
+ eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+ collector when building the eval graph.
+ """
+
+ def __init__(self, estimator):
+ """Initializes a `CheckpointInputPipelineHook`.
+
+ Args:
+ estimator: Estimator.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of saver or scaffold should be set.
+ """
+ # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+ # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+ # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+ # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+ # to be different to avoid conflicts with the model checkpoint.
+
+ # pylint: disable=protected-access
+ checkpoint_prefix = "input"
+ if estimator._config.num_worker_replicas > 1:
+ # Distributed setting.
+ suffix = "_{}_{}".format(estimator._config.task_type,
+ estimator._config.task_id)
+ checkpoint_prefix += suffix
+ # pylint: enable=protected-access
+
+ # We use a composition paradigm instead of inheriting from
+ # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+ # to check whether a `CheckpointSaverHook` is already present in the list
+ # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+ # would thwart this behavior. This hook checkpoints *only the iterators*
+ # and not the graph variables.
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+ estimator.model_dir,
+ save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
+ save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
+ checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+ # Name for the protocol buffer file that will contain the list of most
+ # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is
+ # automatically managed by the `Saver` to keep track of recent checkpoints.
+ # The default name used by the `Saver` for this file is "checkpoint". Here
+ # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+ # `checkpoint_dir` is the same as the model checkpoint directory, there are
+ # no conflicts during restore.
+ self._latest_filename = "checkpoint_" + checkpoint_prefix
+ self._first_run = True
+
+ def begin(self):
+ # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+ # collection if no `Saver` or `Scaffold` is provided.
+ # pylint: disable=protected-access
+ if (self._checkpoint_saver_hook._saver is None and
+ self._checkpoint_saver_hook._scaffold is None):
+ iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+ saveables = [_Saveable(i) for i in iterators]
+ self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+ self._latest_filename)
+ # pylint: enable=protected-access
+ self._checkpoint_saver_hook.begin()
+
+ def _restore_or_save_initial_ckpt(self, session):
+ # Ideally this should be run in after_create_session but is not for the
+ # following reason:
+ # Currently there is no way of enforcing an order of running the
+ # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+ # is run *after* this hook. That is troublesome because
+ # 1. If a checkpoint exists and this hook restores it, the initializer hook
+ # will override it.
+ # 2. If no checkpoint exists, this hook will try to save an initialized
+ # iterator which will result in an exception.
+ #
+ # As a temporary fix we enter the following implicit contract between this
+ # hook and the _DatasetInitializerHook.
+ # 1. The _DatasetInitializerHook initializes the iterator in the call to
+ # after_create_session.
+ # 2. This hook saves the iterator on the first call to `before_run()`, which
+ # is guaranteed to happen after `after_create_session()` of all hooks
+ # have been run.
+
+ # Check if there is an existing checkpoint. If so, restore from it.
+ # pylint: disable=protected-access
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._checkpoint_saver_hook._checkpoint_dir,
+ latest_filename=self._latest_filename)
+ if latest_checkpoint_path:
+ self._checkpoint_saver_hook._get_saver().restore(session,
+ latest_checkpoint_path)
+ else:
+ # The checkpoint saved here is the state at step "global_step".
+ # Note: We do not save the GraphDef or MetaGraphDef here.
+ global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+ self._checkpoint_saver_hook._save(session, global_step)
+ self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+ # pylint: enable=protected-access
+
+ def before_run(self, run_context):
+ if self._first_run:
+ self._restore_or_save_initial_ckpt(run_context.session)
+ self._first_run = False
+ return self._checkpoint_saver_hook.before_run(run_context)
+
+ def after_run(self, run_context, run_values):
+ self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+ def end(self, session):
+ self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+ """`Saver` with a different default `latest_filename`.
+
+ This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+ the model ckpt saved by the `CheckpointSaverHook`.
+ """
+
+ def __init__(self, var_list, latest_filename):
+ super(_CustomSaver, self).__init__(var_list)
+ self._latest_filename = latest_filename
+
+ def save(self,
+ sess,
+ save_path,
+ global_step=None,
+ latest_filename=None,
+ meta_graph_suffix="meta",
+ write_meta_graph=True,
+ write_state=True,
+ strip_default_attrs=False):
+ return super(_CustomSaver, self).save(
+ sess, save_path, global_step, latest_filename or self._latest_filename,
+ meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+
+
+tf_export("data.experimental.Optional")(optional_ops.Optional)
+tf_export("data.experimental.get_next_as_optional")(
+ iterator_ops.get_next_as_optional)
diff --git a/tensorflow/python/data/experimental/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
new file mode 100644
index 0000000000..3ac1158d8b
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+
+
+def map_defun(fn, elems, output_dtypes, output_shapes):
+ """Map a function on the list of tensors unpacked from `elems` on dimension 0.
+
+ Args:
+ fn: A function (`function.Defun`) that takes a list of tensors and returns
+ another list of tensors. The output list has the same types as
+ output_dtypes. The elements of the output list have the same dimension 0
+ as `elems`, and the remaining dimensions correspond to those of
+ `fn_output_shapes`.
+ elems: A list of tensors.
+ output_dtypes: A list of dtypes corresponding to the output types of the
+ function.
+ output_shapes: A list of `TensorShape`s corresponding to the output
+ shapes from each invocation of the function on slices of inputs.
+
+ Raises:
+ ValueError: if any of the inputs are malformed.
+
+ Returns:
+ A list of `Tensor` objects with the same types as `output_dtypes`.
+ """
+ if not isinstance(elems, list):
+ raise ValueError("`elems` must be a list of tensors.")
+ if not isinstance(output_dtypes, list):
+ raise ValueError("`output_dtypes` must be a list of `tf.DType` objects.")
+ if not isinstance(output_shapes, list):
+ raise ValueError("`output_shapes` must be a list of `tf.TensorShape` "
+ "objects.")
+
+ elems = [ops.convert_to_tensor(e) for e in elems]
+ output_shapes = [tensor_shape.TensorShape(s) for s in output_shapes]
+ return gen_dataset_ops.map_defun(elems, fn.captured_inputs, output_dtypes,
+ output_shapes, fn)
diff --git a/tensorflow/python/data/experimental/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
new file mode 100644
index 0000000000..276dde8383
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/optimization.py
@@ -0,0 +1,114 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for optimizing `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+
+# A constant that can be used to enable auto-tuning.
+AUTOTUNE = -1
+
+
+# TODO(jsimsa): Support RE matching for both individual transformation (e.g. to
+# account for indexing) and transformation sequence.
+def assert_next(transformations):
+ """A transformation that asserts which transformations happen next.
+
+ Args:
+ transformations: A `tf.string` vector `tf.Tensor` identifying the
+ transformations that are expected to happen next.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _AssertNextDataset(dataset, transformations)
+
+ return _apply_fn
+
+
+def model():
+ """A transformation that models performance.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return dataset_ops._ModelDataset(dataset) # pylint: disable=protected-access
+
+ return _apply_fn
+
+
+def optimize(optimizations=None):
+ """A transformation that applies optimizations.
+
+ Args:
+ optimizations: (Optional.) A `tf.string` vector `tf.Tensor` identifying
+ optimizations to use. If not specified, the default set of optimizations
+ is applied.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return dataset_ops._OptimizeDataset(dataset, optimizations) # pylint: disable=protected-access
+
+ return _apply_fn
+
+
+class _AssertNextDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that asserts which transformations happen next."""
+
+ def __init__(self, input_dataset, transformations):
+ """See `assert_next()` for details."""
+ super(_AssertNextDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if transformations is None:
+ raise ValueError("At least one transformation should be specified")
+ self._transformations = ops.convert_to_tensor(
+ transformations, dtype=dtypes.string, name="transformations")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._transformations,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
new file mode 100644
index 0000000000..6615b9022a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+@tf_export("data.experimental.parse_example_dataset")
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized `Example` protos given in `serialized`. We refer
+ to `serialized` as a batch with `batch_size` many entries of individual
+ `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any([
+ isinstance(feature, parsing_ops.SparseFeature)
+ for _, feature in features.items()
+ ]):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py
new file mode 100644
index 0000000000..48d7136f95
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py
@@ -0,0 +1,531 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def function_buffering_resource(string_arg,
+ target_device,
+ f,
+ buffer_size,
+ output_types,
+ container="",
+ shared_name=None,
+ name=None):
+ """Creates a FunctionBufferingResource.
+
+ A FunctionBufferingResource fills up a buffer by calling a function `f` on
+ `target_device`. `f` should take in only a single string argument as input.
+
+ Args:
+ string_arg: The single string argument to the function.
+ target_device: The device to run `f` on.
+ f: The function to be executed.
+ buffer_size: Size of the buffer to be populated.
+ output_types: The output types generated by the function.
+ container: (Optional) string. Defaults to "".
+ shared_name: (Optional) string.
+ name: (Optional) string to name the op.
+
+ Returns:
+ Handle to a FunctionBufferingResource.
+ """
+ if shared_name is None:
+ shared_name = ""
+ return ged_ops.experimental_function_buffering_resource(
+ string_arg=string_arg,
+ target_device=target_device,
+ shared_name=shared_name,
+ f=f,
+ buffer_size=buffer_size,
+ container=container,
+ name=name,
+ output_types=output_types)
+
+
+def function_buffering_resource_get_next(function_buffer_resource,
+ output_types,
+ name=None):
+ return ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=function_buffer_resource,
+ output_types=output_types,
+ name=name)
+
+
+def function_buffering_resource_reset(function_buffer_resource, name=None):
+ return ged_ops.experimental_function_buffering_resource_reset(
+ function_buffer_resource=function_buffer_resource, name=name)
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ device,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ iterator_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ target_device=iterator_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)))
+
+ if not self._one_shot:
+ reset_op = function_buffering_resource_reset(self._buffering_resource)
+ with ops.control_dependencies([reset_op]):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ self._buffering_resource,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ nest.flatten(ret), nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+
+ return ret
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ device,
+ buffer_size):
+ with ops.device("/device:CPU:0"):
+ super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ self._device = device
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self.output_types, self.output_shapes, self.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ _prefetch_fn.add_to_graph(None)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=self._flat_output_types,
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=iterator_ops._generate_shared_name(
+ "function_buffer_resource"))
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=self._buffering_resource,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to another device."""
+
+ def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._device = device
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ # The static analysis cannot tell that the eager iterator's superclass has
+ # a `next()` method.
+ # pylint: disable=non-iterator-returned
+ def __iter__(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+ # pylint: enable=non-iterator-returned
+
+ def make_one_shot_iterator(self):
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ device=self._device,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called.
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_device()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.prefetch_to_device")
+def prefetch_to_device(device, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `device`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ device: A string. The name of a device to which elements will be prefetched.
+ buffer_size: (Optional.) The number of elements to buffer on `device`.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, device, buffer_size)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.copy_to_device")
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied
+ target_device: The name of the device to which elements would be copied.
+ source_device: Device where input_dataset would be placed.
+ """
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec().from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ Tensor constant 0
+ """
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ g = ops.get_default_graph()
+ _remote_init_func.add_to_graph(g)
+ _remote_next_func.add_to_graph(g)
+ _remote_finalize_func.add_to_graph(g)
+ # pylint: enable=protected-scope
+
+ # The one_shot_iterator implementation needs a 0 arg _make_dataset function
+ # that thereby captures all the inputs required to create the dataset. Since
+ # there are strings that are inputs to the GeneratorDataset which can't be
+ # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
+ # GPU
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.data.experimental.copy_to_device()` on GPU. Please "
+ "use `Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
new file mode 100644
index 0000000000..e3a2aeab31
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Datasets for random number generators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.RandomDataset")
+class RandomDataset(dataset_ops.DatasetSource):
+ """A `Dataset` of pseudorandom values."""
+
+ def __init__(self, seed=None):
+ """A `Dataset` of pseudorandom values."""
+ super(RandomDataset, self).__init__()
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.random_dataset(
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.int64
diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py
new file mode 100644
index 0000000000..3b2d094514
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/readers.py
@@ -0,0 +1,904 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for reader Datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.util.tf_export import tf_export
+
+_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64, dtypes.string)
+
+
+def _is_valid_int32(str_val):
+ try:
+ # Checks equality to prevent int32 overflow
+ return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
+ str_val)
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_int64(str_val):
+ try:
+ dtypes.int64.as_numpy_dtype(str_val)
+ return True
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_float(str_val, float_dtype):
+ try:
+ return float_dtype.as_numpy_dtype(str_val) < np.inf
+ except ValueError:
+ return False
+
+
+def _infer_type(str_val, na_value, prev_type):
+ """Given a string, infers its tensor type.
+
+ Infers the type of a value by picking the least 'permissive' type possible,
+ while still allowing the previous type inference for this column to be valid.
+
+ Args:
+ str_val: String value to infer the type of.
+ na_value: Additional string to recognize as a NA/NaN CSV value.
+ prev_type: Type previously inferred based on values of this column that
+ we've seen up till now.
+ Returns:
+ Inferred dtype.
+ """
+ if str_val in ("", na_value):
+ # If the field is null, it gives no extra information about its type
+ return prev_type
+
+ type_list = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
+ ] # list of types to try, ordered from least permissive to most
+
+ type_functions = [
+ _is_valid_int32,
+ _is_valid_int64,
+ lambda str_val: _is_valid_float(str_val, dtypes.float32),
+ lambda str_val: _is_valid_float(str_val, dtypes.float64),
+ lambda str_val: True,
+ ] # Corresponding list of validation functions
+
+ for i in range(len(type_list)):
+ validation_fn = type_functions[i]
+ if validation_fn(str_val) and (prev_type is None or
+ prev_type in type_list[:i + 1]):
+ return type_list[i]
+
+
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
+ """Generator that yields rows of CSV file(s) in order."""
+ for fn in filenames:
+ with file_io.FileIO(fn, "r") as f:
+ rdr = csv.reader(
+ f,
+ delimiter=field_delim,
+ quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
+ if header:
+ next(rdr) # Skip header lines
+
+ for csv_row in rdr:
+ if len(csv_row) != num_cols:
+ raise ValueError(
+ "Problem inferring types: CSV row has different number of fields "
+ "than expected.")
+ yield csv_row
+
+
+def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
+ na_value, header, num_rows_for_inference,
+ select_columns):
+ """Infers column types from the first N valid CSV records of files."""
+ if select_columns is None:
+ select_columns = range(num_cols)
+ inferred_types = [None] * len(select_columns)
+
+ for i, csv_row in enumerate(
+ _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
+ if num_rows_for_inference is not None and i >= num_rows_for_inference:
+ break
+
+ for j, col_index in enumerate(select_columns):
+ inferred_types[j] = _infer_type(csv_row[col_index], na_value,
+ inferred_types[j])
+
+ # Replace None's with a default type
+ inferred_types = [t or dtypes.string for t in inferred_types]
+ # Default to 0 or '' for null values
+ return [
+ constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
+ for t in inferred_types
+ ]
+
+
+def _infer_column_names(filenames, field_delim, use_quote_delim):
+ """Infers column names from first rows of files."""
+ csv_kwargs = {
+ "delimiter": field_delim,
+ "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
+ }
+ with file_io.FileIO(filenames[0], "r") as f:
+ try:
+ column_names = next(csv.reader(f, **csv_kwargs))
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+
+ for name in filenames[1:]:
+ with file_io.FileIO(name, "r") as f:
+ try:
+ if next(csv.reader(f, **csv_kwargs)) != column_names:
+ raise ValueError(
+ "Files have different column names in the header row.")
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+ return column_names
+
+
+def _get_sorted_col_indices(select_columns, column_names):
+ """Transforms select_columns argument into sorted column indices."""
+ names_to_indices = {n: i for i, n in enumerate(column_names)}
+ num_cols = len(column_names)
+ for i, v in enumerate(select_columns):
+ if isinstance(v, int):
+ if v < 0 or v >= num_cols:
+ raise ValueError(
+ "Column index %d specified in select_columns out of valid range." %
+ v)
+ continue
+ if v not in names_to_indices:
+ raise ValueError(
+ "Value '%s' specified in select_columns not a valid column index or "
+ "name." % v)
+ select_columns[i] = names_to_indices[v]
+
+ # Sort and ensure there are no duplicates
+ result = sorted(set(select_columns))
+ if len(result) != len(select_columns):
+ raise ValueError("select_columns contains duplicate columns")
+ return result
+
+
+def _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
+ """Optionally shuffle and repeat dataset, as requested."""
+ if num_epochs != 1 and shuffle:
+ # Use shuffle_and_repeat for perf
+ return dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif shuffle:
+ return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
+ elif num_epochs != 1:
+ return dataset.repeat(num_epochs)
+ return dataset
+
+
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
+ """Reads and optionally parses TFRecord files into a dataset.
+
+ Provides common functionality such as batching, optional parsing, shuffling,
+ and performant defaults.
+
+ Args:
+ file_pattern: List of files or patterns of TFRecord file paths.
+ See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ parser_fn: (Optional.) A function accepting string input to parse
+ and process the record contents. This function must map records
+ to components of a fixed shape, so they may be batched. By
+ default, uses the record contents unmodified.
+ num_epochs: (Optional.) An int specifying the number of times this
+ dataset is repeated. If None (the default), cycles through the
+ dataset forever.
+ shuffle: (Optional.) A bool that indicates whether the input
+ should be shuffled. Defaults to `True`.
+ shuffle_buffer_size: (Optional.) Buffer size to use for
+ shuffling. A large buffer size ensures better shuffling, but
+ increases memory usage and startup time.
+ shuffle_seed: (Optional.) Randomization seed to use for shuffling.
+ prefetch_buffer_size: (Optional.) An int specifying the number of
+ feature batches to prefetch for performance improvement.
+ Defaults to auto-tune. Set to 0 to disable prefetching.
+ num_parallel_reads: (Optional.) Number of threads used to read
+ records from files. By default or if set to a value >1, the
+ results will be interleaved.
+ num_parallel_parser_calls: (Optional.) Number of parallel
+ records to parse in parallel. Defaults to an automatic selection.
+ drop_final_batch: (Optional.) Whether the last batch should be
+ dropped in case its size is smaller than `batch_size`; the
+ default behavior is not to drop the smaller batch.
+
+ Returns:
+ A dataset, where each element matches the output of `parser_fn`
+ except it will have an additional leading `batch-size` dimension,
+ or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
+ unspecified.
+ """
+ files = dataset_ops.Dataset.list_files(
+ file_pattern, shuffle=shuffle, seed=shuffle_seed)
+
+ if num_parallel_reads is None:
+ # Note: We considered auto-tuning this value, but there is a concern
+ # that this affects the mixing of records from different files, which
+ # could affect training convergence/accuracy, so we are defaulting to
+ # a constant for now.
+ num_parallel_reads = 24
+ dataset = core_readers.TFRecordDataset(
+ files, num_parallel_reads=num_parallel_reads)
+
+ if shuffle_buffer_size is None:
+ # TODO(josh11b): Auto-tune this value when not specified
+ shuffle_buffer_size = 10000
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ drop_final_batch = drop_final_batch or num_epochs is None
+
+ if parser_fn is None:
+ dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
+ else:
+ # TODO(josh11b): if num_parallel_parser_calls is None, use some function
+ # of num cores instead of map_and_batch's default behavior of one batch.
+ dataset = dataset.apply(batching.map_and_batch(
+ parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
+ drop_remainder=drop_final_batch))
+
+ if prefetch_buffer_size == 0:
+ return dataset
+ else:
+ return dataset.prefetch(buffer_size=prefetch_buffer_size)
+
+
+@tf_export("data.experimental.make_csv_dataset")
+def make_csv_dataset(
+ file_pattern,
+ batch_size,
+ column_names=None,
+ column_defaults=None,
+ label_name=None,
+ select_columns=None,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ header=True,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=1,
+ sloppy=False,
+ num_rows_for_inference=100,
+ compression_type=None,
+):
+ """Reads CSV files into a dataset.
+
+ Reads CSV files into a dataset, where each element is a (features, labels)
+ tuple that corresponds to a batch of CSV rows. The features dictionary
+ maps feature column names to `Tensor`s containing the corresponding
+ feature data, and labels is a `Tensor` containing the batch's label data.
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing CSV
+ records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ column_names: An optional list of strings that corresponds to the CSV
+ columns, in order. One per column of the input record. If this is not
+ provided, infers the column names from the first row of the records.
+ These names will be the keys of the features dict of each dataset element.
+ column_defaults: A optional list of default values for the CSV fields. One
+ item per selected column of the input record. Each item in the list is
+ either a valid CSV dtype (float32, float64, int32, int64, or string), or a
+ `Tensor` with one of the aforementioned types. The tensor can either be
+ a scalar default value (if the column is optional), or an empty tensor (if
+ the column is required). If a dtype is provided instead of a tensor, the
+ column is also treated as required. If this list is not provided, tries
+ to infer types based on reading the first num_rows_for_inference rows of
+ files specified, and assumes all columns are optional, defaulting to `0`
+ for numeric values and `""` for string values. If both this and
+ `select_columns` are specified, these must have the same lengths, and
+ `column_defaults` is assumed to be sorted in order of increasing column
+ index.
+ label_name: A optional string corresponding to the label column. If
+ provided, the data for this column is returned as a separate `Tensor` from
+ the features dictionary, so that the dataset complies with the format
+ expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
+ function.
+ select_columns: An optional list of integer indices or string column
+ names, that specifies a subset of columns of CSV data to select. If
+ column names are provided, these must correspond to names provided in
+ `column_names` or inferred from the file header lines. When this argument
+ is specified, only a subset of CSV columns will be parsed and returned,
+ corresponding to the columns specified. Using this results in faster
+ parsing and lower memory usage. If both this and `column_defaults` are
+ specified, these must have the same lengths, and `column_defaults` is
+ assumed to be sorted in order of increasing column index.
+ field_delim: An optional `string`. Defaults to `","`. Char delimiter to
+ separate fields in a record.
+ use_quote_delim: An optional bool. Defaults to `True`. If false, treats
+ double quotation marks as regular characters inside of the string fields.
+ na_value: Additional string to recognize as NA/NaN.
+ header: A bool that indicates whether the first rows of provided CSV files
+ correspond to header lines with column names, and should not be included
+ in the data.
+ num_epochs: An int specifying the number of times this dataset is repeated.
+ If None, cycles through the dataset forever.
+ shuffle: A bool that indicates whether the input should be shuffled.
+ shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
+ ensures better shuffling, but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
+ num_parallel_reads: Number of threads used to read CSV records from files.
+ If >1, the results will be interleaved.
+ sloppy: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ num_rows_for_inference: Number of rows of a file to use for type inference
+ if record_defaults is not provided. If None, reads all the rows of all
+ the files. Defaults to 100.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
+
+ Returns:
+ A dataset, where each element is a (features, labels) tuple that corresponds
+ to a batch of `batch_size` CSV rows. The features dictionary maps feature
+ column names to `Tensor`s containing the corresponding column data, and
+ labels is a `Tensor` containing the column data for the label column
+ specified by `label_name`.
+
+ Raises:
+ ValueError: If any of the arguments is malformed.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Clean arguments; figure out column names and defaults
+
+ if column_names is None:
+ if not header:
+ raise ValueError("Cannot infer column names without a header line.")
+ # If column names are not provided, infer from the header lines
+ column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
+ if len(column_names) != len(set(column_names)):
+ raise ValueError("Cannot have duplicate column names.")
+
+ if select_columns is not None:
+ select_columns = _get_sorted_col_indices(select_columns, column_names)
+
+ if column_defaults is not None:
+ column_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in column_defaults
+ ]
+ else:
+ # If column defaults are not provided, infer from records at graph
+ # construction time
+ column_defaults = _infer_column_defaults(
+ filenames, len(column_names), field_delim, use_quote_delim, na_value,
+ header, num_rows_for_inference, select_columns)
+
+ if select_columns is not None and len(column_defaults) != len(select_columns):
+ raise ValueError(
+ "If specified, column_defaults and select_columns must have same "
+ "length."
+ )
+ if select_columns is not None and len(column_names) > len(select_columns):
+ # Pick the relevant subset of column names
+ column_names = [column_names[i] for i in select_columns]
+
+ if label_name is not None and label_name not in column_names:
+ raise ValueError("`label_name` provided must be one of the columns.")
+
+ def filename_to_dataset(filename):
+ return CsvDataset(
+ filename,
+ record_defaults=column_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ na_value=na_value,
+ select_cols=select_columns,
+ header=header,
+ compression_type=compression_type,
+ )
+
+ def map_fn(*columns):
+ """Organizes columns into a features dictionary.
+
+ Args:
+ *columns: list of `Tensor`s corresponding to one csv record.
+ Returns:
+ An OrderedDict of feature names to values for that particular record. If
+ label_name is provided, extracts the label feature to be returned as the
+ second element of the tuple.
+ """
+ features = collections.OrderedDict(zip(column_names, columns))
+ if label_name is not None:
+ label = features.pop(label_name)
+ return features, label
+ return features
+
+ # Read files sequentially (if num_parallel_reads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # Apply batch before map for perf, because map has high overhead relative
+ # to the size of the computation in each map.
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(batch_size=batch_size,
+ drop_remainder=num_epochs is None)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ return dataset
+
+
+_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
+
+
+@tf_export("data.experimental.CsvDataset")
+class CsvDataset(dataset_ops.DatasetSource):
+ """A Dataset comprising lines from one or more CSV files."""
+
+ def __init__(self,
+ filenames,
+ record_defaults,
+ compression_type=None,
+ buffer_size=None,
+ header=False,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None):
+ """Creates a `CsvDataset` by reading and decoding CSV files.
+
+ The elements of this dataset correspond to records from the file(s).
+ RFC 4180 format is expected for CSV files
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces with int or float field.
+
+
+ For example, suppose we have a file 'my_file0.csv' with four CSV columns of
+ different data types:
+ ```
+ abcdefg,4.28E10,5.55E6,12
+ hijklmn,-5.3E14,,2
+ ```
+
+ We can construct a CsvDataset from it as follows:
+ ```python
+ dataset = tf.data.experimental.CsvDataset(
+ "my_file*.csv",
+ [tf.float32, # Required field, use dtype or empty tensor
+ tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
+ tf.int32, # Required field, use dtype or empty tensor
+ ],
+ select_cols=[1,2,3] # Only parse last three columns
+ )
+ ```
+
+ The expected output of its iterations is:
+ ```python
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+
+ >> (4.28e10, 5.55e6, 12)
+ >> (-5.3e14, 0.0, 2)
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ record_defaults: A list of default values for the CSV fields. Each item in
+ the list is either a valid CSV `DType` (float32, float64, int32, int64,
+ string), or a `Tensor` object with one of the above types. One per
+ column of CSV data, with either a scalar `Tensor` default value for the
+ column if it is optional, or `DType` or empty `Tensor` if required. If
+ both this and `select_columns` are specified, these must have the same
+ lengths, and `column_defaults` is assumed to be sorted in order of
+ increasing column index.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
+ compression.
+ buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
+ to buffer while reading files. Defaults to 4MB.
+ header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
+ have header line(s) that should be skipped when parsing. Defaults to
+ `False`.
+ field_delim: (Optional.) A `tf.string` scalar containing the delimiter
+ character that separates fields in a record. Defaults to `","`.
+ use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
+ double quotation marks as regular characters inside of string fields
+ (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
+ na_value: (Optional.) A `tf.string` scalar indicating a value that will
+ be treated as NA/NaN.
+ select_cols: (Optional.) A sorted list of column indices to select from
+ the input data. If specified, only this subset of columns will be
+ parsed. Defaults to parsing all columns.
+ """
+ super(CsvDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+ record_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in record_defaults
+ ]
+ self._record_defaults = ops.convert_n_to_tensor(
+ record_defaults, name="record_defaults")
+ self._buffer_size = convert.optional_param_to_tensor(
+ "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
+ self._header = ops.convert_to_tensor(
+ header, dtype=dtypes.bool, name="header")
+ self._field_delim = ops.convert_to_tensor(
+ field_delim, dtype=dtypes.string, name="field_delim")
+ self._use_quote_delim = ops.convert_to_tensor(
+ use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
+ self._na_value = ops.convert_to_tensor(
+ na_value, dtype=dtypes.string, name="na_value")
+ self._select_cols = convert.optional_param_to_tensor(
+ "select_cols",
+ select_cols,
+ argument_default=[],
+ argument_dtype=dtypes.int64,
+ )
+ self._output_shapes = tuple(
+ tensor_shape.scalar() for _ in range(len(record_defaults)))
+ self._output_types = tuple(d.dtype for d in self._record_defaults)
+ self._output_classes = tuple(
+ ops.Tensor for _ in range(len(record_defaults)))
+
+ def _as_variant_tensor(self):
+ # Constructs graph node for the dataset op.
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
+ filenames=self._filenames,
+ record_defaults=self._record_defaults,
+ buffer_size=self._buffer_size,
+ header=self._header,
+ output_shapes=self._output_shapes,
+ field_delim=self._field_delim,
+ use_quote_delim=self._use_quote_delim,
+ na_value=self._na_value,
+ select_cols=self._select_cols,
+ compression_type=self._compression_type,
+ )
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+@tf_export("data.experimental.make_batched_features_dataset")
+def make_batched_features_dataset(file_pattern,
+ batch_size,
+ features,
+ reader=core_readers.TFRecordDataset,
+ label_key=None,
+ reader_args=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ reader_num_threads=1,
+ parser_num_threads=2,
+ sloppy_ordering=False,
+ drop_final_batch=False):
+ """Returns a `Dataset` of feature dictionaries from `Example` protos.
+
+ If label_key argument is provided, returns a `Dataset` of tuple
+ comprising of feature dictionaries and label.
+
+ Example:
+
+ ```
+ serialized_examples = [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
+ }
+ ]
+ ```
+
+ We can use arguments:
+
+ ```
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ "kws": VarLenFeature(dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ "kws": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["code", "art", "sports"]
+ dense_shape=[2, 2]),
+ }
+ ```
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing
+ `Example` records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. See `tf.parse_example`.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key labels are stored in
+ `tf.Examples`. If provided, it must be one of the `features` key,
+ otherwise results in `ValueError`.
+ reader_args: Additional arguments to pass to the reader class.
+ num_epochs: Integer specifying the number of times to read through the
+ dataset. If None, cycles through the dataset forever. Defaults to `None`.
+ shuffle: A boolean, indicates whether the input should be shuffled. Defaults
+ to `True`.
+ shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
+ ensures better shuffling but would increase memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: Number of feature batches to prefetch in order to
+ improve performance. Recommended value is the number of batches consumed
+ per training step. Defaults to auto-tune.
+ reader_num_threads: Number of threads used to read `Example` records. If >1,
+ the results will be interleaved.
+ parser_num_threads: Number of threads to use for parsing `Example` tensors
+ into a dictionary of `Feature` tensors.
+ sloppy_ordering: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
+
+ Returns:
+ A dataset of `dict` elements, (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Read `Example` records from files as tensor objects.
+ if reader_args is None:
+ reader_args = []
+
+ # Read files sequentially (if reader_num_threads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ lambda filename: reader(filename, *reader_args),
+ cycle_length=reader_num_threads,
+ sloppy=sloppy_ordering))
+
+ # Extract values if the `Example` tensors are stored as key-value tuples.
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
+
+ # Apply dataset repeat and shuffle transformations.
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(
+ batch_size, drop_remainder=drop_final_batch or num_epochs is None)
+
+ # Parse `Example` tensors to a dictionary of `Feature` tensors.
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ return dataset
+
+
+def _get_file_names(file_pattern, shuffle):
+ """Parse list of file names from pattern, optionally shuffled.
+
+ Args:
+ file_pattern: File glob pattern, or list of glob patterns.
+ shuffle: Whether to shuffle the order of file names.
+
+ Returns:
+ List of file names matching `file_pattern`.
+
+ Raises:
+ ValueError: If `file_pattern` is empty, or pattern matches no files.
+ """
+ if isinstance(file_pattern, list):
+ if not file_pattern:
+ raise ValueError("File pattern is empty.")
+ file_names = []
+ for entry in file_pattern:
+ file_names.extend(gfile.Glob(entry))
+ else:
+ file_names = list(gfile.Glob(file_pattern))
+
+ if not file_names:
+ raise ValueError("No files match %s." % file_pattern)
+
+ # Sort files so it will be deterministic for unit tests.
+ if not shuffle:
+ file_names = sorted(file_names)
+ return file_names
+
+
+@tf_export("data.experimental.SqlDataset")
+class SqlDataset(dataset_ops.DatasetSource):
+ """A `Dataset` consisting of the results from a SQL query."""
+
+ def __init__(self, driver_name, data_source_name, query, output_types):
+ """Creates a `SqlDataset`.
+
+ `SqlDataset` allows a user to read data from the result set of a SQL query.
+ For example:
+
+ ```python
+ dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
+ "SELECT name, age FROM people",
+ (tf.string, tf.int32))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the rows of the result set of the above query.
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Args:
+ driver_name: A 0-D `tf.string` tensor containing the database type.
+ Currently, the only supported value is 'sqlite'.
+ data_source_name: A 0-D `tf.string` tensor containing a connection string
+ to connect to the database.
+ query: A 0-D `tf.string` tensor containing the SQL query to execute.
+ output_types: A tuple of `tf.DType` objects representing the types of the
+ columns returned by `query`.
+ """
+ super(SqlDataset, self).__init__()
+ self._driver_name = ops.convert_to_tensor(
+ driver_name, dtype=dtypes.string, name="driver_name")
+ self._data_source_name = ops.convert_to_tensor(
+ data_source_name, dtype=dtypes.string, name="data_source_name")
+ self._query = ops.convert_to_tensor(
+ query, dtype=dtypes.string, name="query")
+ self._output_types = output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.sql_dataset(self._driver_name,
+ self._data_source_name, self._query,
+ nest.flatten(self.output_types),
+ nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
+ self._output_types)
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py
new file mode 100644
index 0000000000..3a3040ae9a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/resampling.py
@@ -0,0 +1,296 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resampling dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.rejection_resample")
+def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
+ """A transformation that resamples a dataset to achieve a target distribution.
+
+ **NOTE** Resampling is performed via rejection sampling; some fraction
+ of the input values will be dropped.
+
+ Args:
+ class_func: A function mapping an element of the input dataset to a scalar
+ `tf.int32` tensor. Values should be in `[0, num_classes)`.
+ target_dist: A floating point type tensor, shaped `[num_classes]`.
+ initial_dist: (Optional.) A floating point type tensor, shaped
+ `[num_classes]`. If not provided, the true class distribution is
+ estimated live in a streaming fashion.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
+ class_values_ds = dataset.map(class_func)
+
+ # Get initial distribution.
+ if initial_dist is not None:
+ initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
+ acceptance_dist, prob_of_original = (
+ _calculate_acceptance_probs_with_mixing(initial_dist_t,
+ target_dist_t))
+ initial_dist_ds = dataset_ops.Dataset.from_tensors(
+ initial_dist_t).repeat()
+ acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
+ acceptance_dist).repeat()
+ prob_of_original_ds = dataset_ops.Dataset.from_tensors(
+ prob_of_original).repeat()
+ else:
+ initial_dist_ds = _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds)
+ acceptance_and_original_prob_ds = initial_dist_ds.map(
+ lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
+ initial, target_dist_t))
+ acceptance_dist_ds = acceptance_and_original_prob_ds.map(
+ lambda accept_prob, _: accept_prob)
+ prob_of_original_ds = acceptance_and_original_prob_ds.map(
+ lambda _, prob_original: prob_original)
+ filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
+ class_values_ds, seed)
+ # Prefetch filtered dataset for speed.
+ filtered_ds = filtered_ds.prefetch(3)
+
+ prob_original_static = _get_prob_original_static(
+ initial_dist_t, target_dist_t) if initial_dist is not None else None
+ if prob_original_static == 1:
+ return dataset_ops.Dataset.zip((class_values_ds, dataset))
+ elif prob_original_static == 0:
+ return filtered_ds
+ else:
+ return interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
+ weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
+ seed=seed)
+
+ return _apply_fn
+
+
+def _get_prob_original_static(initial_dist_t, target_dist_t):
+ """Returns the static probability of sampling from the original.
+
+ `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
+ an Op that it isn't defined for. We have some custom logic to avoid this.
+
+ Args:
+ initial_dist_t: A tensor of the initial distribution.
+ target_dist_t: A tensor of the target distribution.
+
+ Returns:
+ The probability of sampling from the original distribution as a constant,
+ if it is a constant, or `None`.
+ """
+ init_static = tensor_util.constant_value(initial_dist_t)
+ target_static = tensor_util.constant_value(target_dist_t)
+
+ if init_static is None or target_static is None:
+ return None
+ else:
+ return np.min(target_static / init_static)
+
+
+def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
+ seed):
+ """Filters a dataset based on per-class acceptance probabilities.
+
+ Args:
+ dataset: The dataset to be filtered.
+ acceptance_dist_ds: A dataset of acceptance probabilities.
+ initial_dist_ds: A dataset of the initial probability distribution, given or
+ estimated.
+ class_values_ds: A dataset of the corresponding classes.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A dataset of (class value, data) after filtering.
+ """
+ def maybe_warn_on_large_rejection(accept_dist, initial_dist):
+ proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
+ return control_flow_ops.cond(
+ math_ops.less(proportion_rejected, .5),
+ lambda: accept_dist,
+ lambda: logging_ops.Print( # pylint: disable=g-long-lambda
+ accept_dist, [proportion_rejected, initial_dist, accept_dist],
+ message="Proportion of examples rejected by sampler is high: ",
+ summarize=100,
+ first_n=10))
+
+ acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
+ initial_dist_ds))
+ .map(maybe_warn_on_large_rejection))
+
+ def _gather_and_copy(class_val, acceptance_prob, data):
+ return class_val, array_ops.gather(acceptance_prob, class_val), data
+
+ current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
+ (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
+ filtered_ds = (
+ current_probabilities_and_class_and_data_ds
+ .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
+ return filtered_ds.map(lambda class_value, _, data: (class_value, data))
+
+
+def _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds, dist_estimation_batch_size=32,
+ smoothing_constant=10):
+ num_classes = (target_dist_t.shape[0].value or
+ array_ops.shape(target_dist_t)[0])
+ initial_examples_per_class_seen = array_ops.fill(
+ [num_classes], np.int64(smoothing_constant))
+
+ def update_estimate_and_tile(num_examples_per_class_seen, c):
+ updated_examples_per_class_seen, dist = _estimate_data_distribution(
+ c, num_examples_per_class_seen)
+ tiled_dist = array_ops.tile(
+ array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
+ return updated_examples_per_class_seen, tiled_dist
+
+ initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+ .apply(scan_ops.scan(initial_examples_per_class_seen,
+ update_estimate_and_tile))
+ .apply(batching.unbatch()))
+
+ return initial_dist_ds
+
+
+def _get_target_to_initial_ratio(initial_probs, target_probs):
+ # Add tiny to initial_probs to avoid divide by zero.
+ denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
+ return target_probs / denom
+
+
+def _estimate_data_distribution(c, num_examples_per_class_seen):
+ """Estimate data distribution as labels are seen.
+
+ Args:
+ c: The class labels. Type `int32`, shape `[batch_size]`.
+ num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
+ containing counts.
+
+ Returns:
+ num_examples_per_lass_seen: Updated counts. Type `int64`, shape
+ `[num_classes]`.
+ dist: The updated distribution. Type `float32`, shape `[num_classes]`.
+ """
+ num_classes = num_examples_per_class_seen.get_shape()[0].value
+ # Update the class-count based on what labels are seen in batch.
+ num_examples_per_class_seen = math_ops.add(
+ num_examples_per_class_seen, math_ops.reduce_sum(
+ array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
+ init_prob_estimate = math_ops.truediv(
+ num_examples_per_class_seen,
+ math_ops.reduce_sum(num_examples_per_class_seen))
+ dist = math_ops.cast(init_prob_estimate, dtypes.float32)
+ return num_examples_per_class_seen, dist
+
+
+def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
+ """Calculates the acceptance probabilities and mixing ratio.
+
+ In this case, we assume that we can *either* sample from the original data
+ distribution with probability `m`, or sample from a reshaped distribution
+ that comes from rejection sampling on the original distribution. This
+ rejection sampling is done on a per-class basis, with `a_i` representing the
+ probability of accepting data from class `i`.
+
+ This method is based on solving the following analysis for the reshaped
+ distribution:
+
+ Let F be the probability of a rejection (on any example).
+ Let p_i be the proportion of examples in the data in class i (init_probs)
+ Let a_i is the rate the rejection sampler should *accept* class i
+ Let t_i is the target proportion in the minibatches for class i (target_probs)
+
+ ```
+ F = sum_i(p_i * (1-a_i))
+ = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
+ ```
+
+ An example with class `i` will be accepted if `k` rejections occur, then an
+ example with class `i` is seen by the rejector, and it is accepted. This can
+ be written as follows:
+
+ ```
+ t_i = sum_k=0^inf(F^k * p_i * a_i)
+ = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1
+ = p_i * a_i / sum_j(p_j * a_j) using F from above
+ ```
+
+ Note that the following constraints hold:
+ ```
+ 0 <= p_i <= 1, sum_i(p_i) = 1
+ 0 <= a_i <= 1
+ 0 <= t_i <= 1, sum_i(t_i) = 1
+ ```
+
+ A solution for a_i in terms of the other variables is the following:
+ ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
+
+ If we try to minimize the amount of data rejected, we get the following:
+
+ M_max = max_i [ t_i / p_i ]
+ M_min = min_i [ t_i / p_i ]
+
+ The desired probability of accepting data if it comes from class `i`:
+
+ a_i = (t_i/p_i - m) / (M_max - m)
+
+ The desired probability of pulling a data element from the original dataset,
+ rather than the filtered one:
+
+ m = M_min
+
+ Args:
+ initial_probs: A Tensor of the initial probability distribution, given or
+ estimated.
+ target_probs: A Tensor of the corresponding classes.
+
+ Returns:
+ (A 1D Tensor with the per-class acceptance probabilities, the desired
+ probability of pull from the original distribution.)
+ """
+ ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
+ max_ratio = math_ops.reduce_max(ratio_l)
+ min_ratio = math_ops.reduce_min(ratio_l)
+
+ # Target prob to sample from original distribution.
+ m = min_ratio
+
+ # TODO(joelshor): Simplify fraction, if possible.
+ a_i = (ratio_l - m) / (max_ratio - m)
+ return a_i, m
diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py
new file mode 100644
index 0000000000..e05e7c5a18
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/scan_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Scan dataset transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ScanDataset(dataset_ops.UnaryDataset):
+ """A dataset that scans a function across its input."""
+
+ def __init__(self, input_dataset, initial_state, scan_func):
+ """See `scan()` for details."""
+ super(_ScanDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ self._initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state. The shapes may be refined by running `tf_scan_func` one
+ # or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
+ self._state_shapes = nest.pack_sequence_as(
+ self._initial_state,
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
+ self._state_types = nest.pack_sequence_as(
+ self._initial_state,
+ [t.dtype for t in nest.flatten(self._initial_state)])
+
+ # Will be populated by calling `tf_scan_func`.
+ self._output_classes = None
+ self._output_shapes = None
+ self._output_types = None
+
+ # Iteratively rerun the scan function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func,
+ "tf.data.experimental.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._scan_func = wrapped_func.function
+ self._scan_func.add_to_graph(ops.get_default_graph())
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.scan_dataset(
+ input_t,
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
+ self._scan_func.captured_inputs,
+ f=self._scan_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.scan")
+def scan(initial_state, scan_func):
+ """A transformation that scans a function across an input dataset.
+
+ This transformation is a stateful relative of `tf.data.Dataset.map`.
+ In addition to mapping `scan_func` across the elements of the input dataset,
+ `scan()` accumulates one or more state tensors, whose initial values are
+ `initial_state`.
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial state
+ of the accumulator.
+ scan_func: A function that maps `(old_state, input_element)` to
+ `(new_state, output_element). It must take two arguments and return a
+ pair of nested structures of tensors. The `new_state` must match the
+ structure of `initial_state`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _ScanDataset(dataset, initial_state, scan_func)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py
new file mode 100644
index 0000000000..a4307212da
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental shuffle ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that fuses `shuffle` and `repeat`."""
+
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._buffer_size = ops.convert_to_tensor(
+ buffer_size, dtype=dtypes.int64, name="buffer_size")
+ if count is None:
+ self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
+ else:
+ self._count = ops.convert_to_tensor(
+ count, dtype=dtypes.int64, name="count")
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.shuffle_and_repeat_dataset(
+ input_resource,
+ buffer_size=self._buffer_size,
+ count=self._count,
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.shuffle_and_repeat")
+def shuffle_and_repeat(buffer_size, count=None, seed=None):
+ """Shuffles and repeats a Dataset returning a new permutation for each epoch.
+
+ `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))`
+
+ is equivalent to
+
+ `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
+
+ The difference is that the latter dataset is not serializable. So,
+ if you need to checkpoint an input pipeline with reshuffling you must use
+ this implementation.
+
+ Args:
+ buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ maximum number elements that will be buffered when prefetching.
+ count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ number of times the dataset should be repeated. The default behavior
+ (if `count` is `None` or `-1`) is for the dataset be repeated
+ indefinitely.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset): # pylint: disable=missing-docstring
+ return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
new file mode 100644
index 0000000000..54ef6fc3e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -0,0 +1,214 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for gathering statistics from `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.StatsAggregator")
+class StatsAggregator(object):
+ """A stateful resource that aggregates statistics from one or more iterators.
+
+ To record statistics, use one of the custom transformation functions defined
+ in this module when defining your `tf.data.Dataset`. All statistics will be
+ aggregated by the `StatsAggregator` that is associated with a particular
+ iterator (see below). For example, to record the latency of producing each
+ element by iterating over a dataset:
+
+ ```python
+ dataset = ...
+ dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
+ ```
+
+ To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
+ the following pattern:
+
+ ```python
+ stats_aggregator = stats_ops.StatsAggregator()
+ dataset = ...
+
+ # Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
+ dataset = dataset.apply(
+ tf.data.experimental.set_stats_aggregator(stats_aggregator))
+ iterator = dataset.make_one_shot_iterator()
+ ```
+
+ To get a protocol buffer summary of the currently aggregated statistics,
+ use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
+ is to add the returned tensor to the `tf.GraphKeys.SUMMARIES` collection,
+ so that the summaries will be included with any existing summaries.
+
+ ```python
+ stats_aggregator = stats_ops.StatsAggregator()
+ # ...
+ stats_summary = stats_aggregator.get_summary()
+ tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
+ ```
+
+ Note: This interface is experimental and expected to change. In particular,
+ we expect to add other implementations of `StatsAggregator` that provide
+ different ways of exporting statistics, and add more types of statistics.
+ """
+
+ def __init__(self):
+ """Creates a `StatsAggregator`."""
+ self._resource = gen_dataset_ops.stats_aggregator_handle()
+
+ # TODO(b/116314787): Update this/add support for V2 summary API.
+ def get_summary(self):
+ """Returns a string `tf.Tensor` that summarizes the aggregated statistics.
+
+ The returned tensor will contain a serialized `tf.summary.Summary` protocol
+ buffer, which can be used with the standard TensorBoard logging facilities.
+
+ Returns:
+ A scalar string `tf.Tensor` that summarizes the aggregated statistics.
+ """
+ return gen_dataset_ops.stats_aggregator_summary(self._resource)
+
+
+class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets given stats_aggregator."""
+
+ def __init__(self, input_dataset, stats_aggregator, tag, prefix):
+ super(_SetStatsAggregatorDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._stats_aggregator = stats_aggregator
+ self._tag = tag
+ self._prefix = prefix
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.set_stats_aggregator_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._stats_aggregator._resource, # pylint: disable=protected-access
+ self._tag,
+ self._prefix,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.set_stats_aggregator")
+def set_stats_aggregator(stats_aggregator, tag="", counter_prefix=""):
+ """Set the given `stats_aggregator` for aggregating the input dataset stats.
+
+ Args:
+ stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
+ tag: (Optional) String, all statistics recorded for the input `dataset`
+ will have given `tag` prepend with the name.
+ counter_prefix: (Optional) String, all statistics recorded as `counters`
+ will have the given `prefix` for the counter. Defaults to "/tesorflow".
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _SetStatsAggregatorDataset(dataset, stats_aggregator, tag,
+ counter_prefix)
+
+ return _apply_fn
+
+
+# TODO(b/38416882): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def bytes_produced_stats(tag):
+ """Records the number of bytes produced by each element of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
+ Args:
+ tag: String. All statistics recorded by the returned transformation will
+ be associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
+ tag)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.latency_stats")
+def latency_stats(tag):
+ """Records the latency of producing each element of the input dataset.
+
+ To consume the statistics, associate a `StatsAggregator` with the output
+ dataset.
+
+ Args:
+ tag: String. All statistics recorded by the returned transformation will
+ be associated with the given `tag`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)
+
+ return _apply_fn
+
+
+class _StatsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and also records statistics."""
+
+ def __init__(self, input_dataset, op_function, tag):
+ super(_StatsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._op_function = op_function
+ self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
+
+ def _as_variant_tensor(self):
+ return self._op_function(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._tag,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py
new file mode 100644
index 0000000000..3ea017c6e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/threadpool.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for controlling threading in `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+ with _uid_lock:
+ global _uid_counter
+ uid = _uid_counter
+ _uid_counter += 1
+ return "{}{}".format(prefix, uid)
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+class PrivateThreadPool(object):
+ """A stateful resource that represents a private thread pool."""
+
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
+ """Creates a `PrivateThreadPool` with the given number of threads."""
+ if context.executing_eagerly():
+ shared_name = _generate_shared_name("privatethreadpool")
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name,
+ shared_name=shared_name)
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=context.context().device_name)
+ else:
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
+
+
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets a custom threadpool."""
+
+ def __init__(self, input_dataset, thread_pool):
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._thread_pool = thread_pool
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_thread_pool_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._thread_pool._resource, # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def override_threadpool(dataset, thread_pool):
+ """Returns a new dataset that uses the given thread pool for its operations.
+
+ Args:
+ dataset: A `tf.data.Dataset` object.
+ thread_pool: A `PrivateThreadPool` object.
+
+ Returns:
+ A dataset containing the same values as `dataset`, but which uses
+ `thread_pool` to compute any of its parallel operations (such as
+ `tf.data.Dataset.map`).
+ """
+ return _ThreadPoolDataset(dataset, thread_pool)
diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py
new file mode 100644
index 0000000000..2a7775c456
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/unique.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unique element dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.unique")
+def unique():
+ """Creates a `Dataset` from another `Dataset`, discarding duplicates.
+
+ Use this transformation to produce a dataset that contains one instance of
+ each unique element in the input. For example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
+
+ # Using `unique()` will drop the duplicate elements.
+ dataset = dataset.apply(tf.data.experimental.unique()) # ==> { 1, 37, 2 }
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _UniqueDataset(dataset)
+
+ return _apply_fn
+
+
+class _UniqueDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` contains the unique elements from its input."""
+
+ def __init__(self, input_dataset):
+ """See `unique()` for details."""
+ super(_UniqueDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
+ dtypes.string):
+ raise TypeError(
+ "`tf.data.experimental.unique()` only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component.")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
new file mode 100644
index 0000000000..994447cb4d
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for tf.data writers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.TFRecordWriter")
+class TFRecordWriter(object):
+ """Writes data to a TFRecord file."""
+
+ def __init__(self, filename, compression_type=None):
+ self._filename = ops.convert_to_tensor(
+ filename, dtypes.string, name="filename")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+
+ def write(self, dataset):
+ """Returns a `tf.Operation` to write a dataset to a file.
+
+ Args:
+ dataset: a `tf.data.Dataset` whose elements are to be written to a file
+
+ Returns:
+ A `tf.Operation` that, when run, writes contents of `dataset` to a file.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+ if (dataset.output_types != dtypes.string or
+ dataset.output_shapes != tensor_shape.scalar()):
+ raise TypeError(
+ "`dataset` must produce scalar `DT_STRING` tensors whereas it "
+ "produces shape {0} and types {1}".format(dataset.output_shapes,
+ dataset.output_types))
+ return gen_dataset_ops.dataset_to_tf_record(
+ dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 7a6f03d4d3..c7295d6e69 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -15,6 +15,7 @@ tf_py_test(
size = "small",
srcs = ["batch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -31,10 +32,44 @@ tf_py_test(
)
tf_py_test(
+ name = "cache_dataset_op_test",
+ size = "small",
+ srcs = ["cache_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+tf_py_test(
+ name = "concatenate_dataset_op_test",
+ size = "small",
+ srcs = ["concatenate_dataset_op_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//third_party/py/numpy",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_py_test(
name = "dataset_constructor_op_test",
size = "small",
srcs = ["dataset_constructor_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@@ -63,6 +98,7 @@ tf_py_test(
size = "medium",
srcs = ["dataset_from_generator_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -78,8 +114,11 @@ tf_py_test(
size = "small",
srcs = ["dataset_ops_test.py"],
additional_deps = [
- "//tensorflow/core:protos_all_py",
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -89,6 +128,7 @@ tf_py_test(
size = "small",
srcs = ["filter_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -106,6 +146,7 @@ tf_py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
@@ -123,6 +164,7 @@ tf_py_test(
size = "small",
srcs = ["list_files_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
@@ -137,6 +179,7 @@ tf_py_test(
size = "small",
srcs = ["interleave_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -151,11 +194,80 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ additional_deps = [
+ "//third_party/py/numpy",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/training/checkpointable:util",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:training",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python:util",
+ "//tensorflow/python:variables",
+ ],
+ grpc_enabled = True,
+)
+
+tf_py_test(
+ name = "iterator_ops_cluster_test",
+ size = "small",
+ srcs = ["iterator_ops_cluster_test.py"],
+ additional_deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:lookup_ops",
+ ],
+ grpc_enabled = True,
+ tags = [
+ "no_oss", # Test flaky due to port collisions.
+ "no_windows",
+ ],
+)
+
tf_py_test(
name = "map_dataset_op_test",
size = "small",
srcs = ["map_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -177,11 +289,54 @@ tf_py_test(
],
)
+cuda_py_test(
+ name = "multi_device_iterator_test",
+ size = "medium",
+ srcs = ["multi_device_iterator_test.py"],
+ additional_deps = [
+ ":test_base",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:multi_device_iterator_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ ],
+ tags = [
+ "no_windows_gpu",
+ ],
+)
+
+cuda_py_test(
+ name = "optional_ops_test",
+ size = "small",
+ srcs = ["optional_ops_test.py"],
+ additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
tf_py_test(
name = "prefetch_dataset_op_test",
size = "small",
srcs = ["prefetch_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -197,6 +352,7 @@ tf_py_test(
size = "small",
srcs = ["range_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dataset_ops_gen",
@@ -218,6 +374,7 @@ tf_py_test(
size = "small",
srcs = ["reader_dataset_ops_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -236,32 +393,35 @@ tf_py_test(
)
tf_py_test(
- name = "sequence_dataset_op_test",
+ name = "reduce_dataset_op_test",
size = "small",
- srcs = ["sequence_dataset_op_test.py"],
+ srcs = ["reduce_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
+ "@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
tf_py_test(
- name = "shuffle_dataset_op_test",
+ name = "sequence_dataset_op_test",
size = "small",
- srcs = ["shuffle_dataset_op_test.py"],
+ srcs = ["sequence_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
],
)
@@ -270,6 +430,7 @@ tf_py_test(
size = "small",
srcs = ["shard_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
@@ -277,155 +438,30 @@ tf_py_test(
)
tf_py_test(
- name = "cache_dataset_op_test",
+ name = "shuffle_dataset_op_test",
size = "small",
- srcs = ["cache_dataset_op_test.py"],
+ srcs = ["shuffle_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:variables",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
],
)
-tf_py_test(
- name = "zip_dataset_op_test",
- size = "small",
- srcs = ["zip_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
- name = "concatenate_dataset_op_test",
- size = "small",
- srcs = ["concatenate_dataset_op_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-cuda_py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- additional_deps = [
- "//third_party/py/numpy",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
- "//tensorflow/python/training/checkpointable:util",
- "//tensorflow/python:array_ops",
+py_library(
+ name = "test_base",
+ srcs = ["test_base.py"],
+ deps = [
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
"//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:gradients",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:session",
"//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:training",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python:util",
- "//tensorflow/python:variables",
- ],
- grpc_enabled = True,
-)
-
-tf_py_test(
- name = "iterator_ops_cluster_test",
- size = "small",
- srcs = ["iterator_ops_cluster_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:session",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:lookup_ops",
- ],
- grpc_enabled = True,
- tags = [
- "no_oss", # Test flaky due to port collisions.
- "no_windows",
- ],
-)
-
-cuda_py_test(
- name = "optional_ops_test",
- size = "small",
- srcs = ["optional_ops_test.py"],
- additional_deps = [
- "@absl_py//absl/testing:parameterized",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:optional_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:tensor_shape",
- ],
-)
-
-cuda_py_test(
- name = "multi_device_iterator_test",
- size = "small",
- srcs = ["multi_device_iterator_test.py"],
- additional_deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- ],
- tags = [
- "no_windows_gpu",
+ "//tensorflow/python/data/util:nest",
],
)
@@ -434,6 +470,7 @@ tf_py_test(
size = "small",
srcs = ["window_dataset_op_test.py"],
additional_deps = [
+ ":test_base",
"@absl_py//absl/testing:parameterized",
"//third_party/py/numpy",
"//tensorflow/python:array_ops",
@@ -447,14 +484,16 @@ tf_py_test(
)
tf_py_test(
- name = "inputs_test",
+ name = "zip_dataset_op_test",
size = "small",
- srcs = ["inputs_test.py"],
+ srcs = ["zip_dataset_op_test.py"],
additional_deps = [
- "@absl_py//absl/testing:parameterized",
+ ":test_base",
"//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
"//tensorflow/python/data/ops:dataset_ops",
],
)
diff --git a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
index c48708a2b9..9cb4daf284 100644
--- a/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/batch_dataset_op_test.py
@@ -24,6 +24,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -37,7 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('even', 28, 14, False),
@@ -115,11 +116,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testBatchSparse(self):
def _sparse(i):
@@ -227,7 +223,7 @@ def _random_seq_lens(count):
return np.random.randint(20, size=(count,)).astype(np.int32)
-class PaddedBatchDatasetTest(test.TestCase, parameterized.TestCase):
+class PaddedBatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
('default_padding', _random_seq_lens(32), 4, [-1], False),
diff --git a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
index d5f5b2fe05..63625fac03 100644
--- a/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/cache_dataset_op_test.py
@@ -23,6 +23,7 @@ import tempfile
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -34,7 +35,7 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import test
-class FileCacheDatasetTest(test.TestCase):
+class FileCacheDatasetTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
@@ -200,7 +201,7 @@ class FileCacheDatasetTest(test.TestCase):
self.assertAllEqual(elements, elements_itr2)
-class MemoryCacheDatasetTest(test.TestCase):
+class MemoryCacheDatasetTest(test_base.DatasetTestBase):
def testCacheDatasetPassthrough(self):
with ops.device("cpu:0"):
diff --git a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
index 5dfb84f28e..83af31f380 100644
--- a/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/concatenate_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class ConcatenateDatasetTest(test.TestCase):
+class ConcatenateDatasetTest(test_base.DatasetTestBase):
def testConcatenateDataset(self):
input_components = (
diff --git a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
index e43564a2eb..bc6b36285a 100644
--- a/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_constructor_op_test.py
@@ -23,6 +23,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -36,7 +37,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testFromTensors(self):
"""Test a dataset that represents a single tuple of tensors."""
@@ -58,11 +59,6 @@ class DatasetConstructorTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testFromTensorsSparse(self):
"""Test a dataset that represents a single tuple of tensors."""
components = (sparse_tensor.SparseTensorValue(
diff --git a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
index cd0c1ddf1e..cb8cb9a77d 100644
--- a/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_from_generator_op_test.py
@@ -22,6 +22,7 @@ import threading
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def _testFromGenerator(self, generator, elem_sequence, num_repeats,
output_types=None):
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index 239aa85175..b9f8875b9f 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -18,12 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.core.framework import graph_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class DatasetOpsTest(test.TestCase):
+class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
@@ -32,6 +40,155 @@ class DatasetOpsTest(test.TestCase):
sess.run(dataset._as_serialized_graph()))
self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+ def testOptionsDefault(self):
+ ds = dataset_ops.Dataset.range(0)
+ self.assertEqual(dataset_ops.Options(), ds.options())
+
+ def testOptionsOnce(self):
+ options = dataset_ops.Options()
+ ds = dataset_ops.Dataset.range(0).with_options(options).cache()
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceSame(self):
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
+ options)
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceDifferent(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_filter_fusion = False
+ ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
+ options2)
+ self.assertTrue(ds.options().experimental_autotune)
+ self.assertFalse(ds.options().experimental_filter_fusion)
+
+ def testOptionsTwiceDifferentError(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_autotune = False
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot merge incompatible values of option"):
+ dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
index 19944d389f..a0c6b37a6d 100644
--- a/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/filter_dataset_op_test.py
@@ -22,6 +22,7 @@ import time
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class FilterDatasetTest(test.TestCase):
+class FilterDatasetTest(test_base.DatasetTestBase):
def testFilterDataset(self):
components = (
@@ -129,11 +130,6 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _map_fn(i):
@@ -160,7 +156,7 @@ class FilterDatasetTest(test.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def testReturnComponent(self):
+ def testShortCircuit(self):
iterator = (
dataset_ops.Dataset.zip(
(dataset_ops.Dataset.range(10),
diff --git a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
index 1123cbff62..68038f9cfc 100644
--- a/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/flat_map_dataset_op_test.py
@@ -22,6 +22,7 @@ import random
import numpy as np
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.training import server_lib
-class FlatMapDatasetTest(test.TestCase):
+class FlatMapDatasetTest(test_base.DatasetTestBase):
# pylint: disable=g-long-lambda
def testFlatMapDataset(self):
diff --git a/tensorflow/python/data/kernel_tests/inputs_test.py b/tensorflow/python/data/kernel_tests/inputs_test.py
index 4c9279dd95..d089b49bcc 100644
--- a/tensorflow/python/data/kernel_tests/inputs_test.py
+++ b/tensorflow/python/data/kernel_tests/inputs_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.data.util import nest
@@ -27,7 +28,7 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class InputsTest(test.TestCase, parameterized.TestCase):
+class InputsTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def make_apply_fn(dataset):
diff --git a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
index e7e51df65e..92bb67b6ff 100644
--- a/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/interleave_dataset_op_test.py
@@ -22,6 +22,7 @@ import itertools
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import sparse_tensor
@@ -30,7 +31,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class InterleaveDatasetTest(test.TestCase, parameterized.TestCase):
+class InterleaveDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _interleave(self, lists, cycle_length, block_length):
num_open = 0
diff --git a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
index c4b338a58f..8eb13815d4 100644
--- a/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/list_files_dataset_op_test.py
@@ -22,6 +22,7 @@ from os import path
import shutil
import tempfile
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class ListFilesDatasetOpTest(test.TestCase):
+class ListFilesDatasetOpTest(test_base.DatasetTestBase):
def setUp(self):
self.tmp_dir = tempfile.mkdtemp()
diff --git a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
index ae04995436..4683b1db91 100644
--- a/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/map_dataset_op_test.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
+"""Tests for `tf.data.Dataset.map()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -47,7 +48,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
-class MapDatasetTest(test.TestCase, parameterized.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def _buildMapDataset(self, components, count):
def _map_fn(x, y, z):
@@ -266,6 +267,35 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
+ def testCaptureIterator(self):
+
+ def _build_ds(iterator):
+
+ def _map_fn(x):
+ get_next = iterator.get_next()
+ return x * get_next
+
+ return dataset_ops.Dataset.range(10).map(_map_fn)
+
+ def _build_graph():
+ captured_iterator = dataset_ops.Dataset.range(
+ 10).make_initializable_iterator()
+ ds = _build_ds(captured_iterator)
+ iterator = ds.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ return captured_iterator.initializer, init_op, get_next
+
+ with ops.Graph().as_default() as g:
+ captured_init_op, init_op, get_next = _build_graph()
+ with self.session(graph=g) as sess:
+ sess.run(captured_init_op)
+ sess.run(init_op)
+ for i in range(10):
+ self.assertEqual(i * i, sess.run(get_next))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
def testCaptureHashTable(self):
# NOTE(mrry): We must use the V2 variants of `HashTable`
# etc. because these produce a `tf.resource`-typed output that is
@@ -574,11 +604,6 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSparse(self):
def _sparse(i):
@@ -597,7 +622,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _sparse(i))
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -624,7 +649,7 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
sess.run(init_op)
for i in range(10):
actual = sess.run(get_next)
- self.assertTrue(isinstance(actual, sparse_tensor.SparseTensorValue))
+ self.assertIsInstance(actual, sparse_tensor.SparseTensorValue)
self.assertSparseValuesEqual(actual, _check(_sparse(i)).eval())
with self.assertRaises(errors.OutOfRangeError):
sess.run(get_next)
@@ -758,19 +783,72 @@ class MapDatasetTest(test.TestCase, parameterized.TestCase):
self.assertTrue(all(tids[0] == tid for tid in tids))
# pylint: enable=g-long-lambda
+ @parameterized.named_parameters(
+ ("SequentialIdentity", None, lambda x: x, None),
+ ("SequentialReplicate", None, lambda x: (x, x), None),
+ ("SequentialSwap", (None, None), lambda x, y: (y, x), None),
+ ("SequentialProject", (None, None), lambda x, y: x, None),
+ ("ParallelIdentity", None, lambda x: x, 10),
+ ("ParallelReplicate", None, lambda x: (x, x), 10),
+ ("ParallelSwap", (None, None), lambda x, y: (y, x), 10),
+ ("ParallelProject", (None, None), lambda x, y: x, 10),
+ )
+ def testShortCircuit(self, structure, map_fn, num_parallel_calls):
+ dataset = self.structuredDataset(structure).repeat().map(
+ map_fn, num_parallel_calls=num_parallel_calls)
+ get_next = dataset.make_one_shot_iterator().get_next()
+
+ with self.cached_session() as sess:
+ if isinstance(structure, tuple):
+ expected = map_fn(*sess.run(self.structuredElement(structure)))
+ else:
+ expected = map_fn(sess.run(self.structuredElement(structure)))
+ self.assertEqual(expected, sess.run(get_next))
+
+ @parameterized.named_parameters(
+ ("Sequential", None),
+ ("Parallel", 10),
+ )
+ def testShortCircuitCapturedInput(self, num_parallel_calls):
+ captured_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = self.structuredDataset(None).repeat().map(
+ lambda x: captured_t, num_parallel_calls=num_parallel_calls)
+ iterator = dataset.make_initializable_iterator()
+ get_next = iterator.get_next()
+
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer, feed_dict={captured_t: 42})
+ self.assertEqual(42, sess.run(get_next))
+
class MapDatasetBenchmark(test.Benchmark):
def benchmarkChainOfMaps(self):
chain_lengths = [0, 1, 2, 5, 10, 20, 50]
for chain_length in chain_lengths:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda x: x + 1
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda x: x
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(0).repeat(None)
for _ in range(chain_length):
dataset = dataset_ops.MapDataset(
dataset,
- lambda x: x,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -788,25 +866,39 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset chain length%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", chain_length, median_wall_time))
+ (print_label, chain_length, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
name="benchmark_map_dataset_chain_latency_%d%s" %
- (chain_length, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ (chain_length, benchmark_label))
def benchmarkMapFanOut(self):
fan_outs = [1, 2, 5, 10, 20, 50, 100]
for fan_out in fan_outs:
- for use_inter_op_parallelism in [False, True]:
+ for mode in ["general", "single-threaded", "short-circuit"]:
+ if mode == "general":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = True
+ print_label = ""
+ benchmark_label = ""
+ if mode == "single-threaded":
+ map_fn = lambda *xs: [x + 1 for x in xs]
+ use_inter_op_parallelism = False
+ print_label = " (single threaded mode)"
+ benchmark_label = "_single_threaded"
+ if mode == "short-circuit":
+ map_fn = lambda *xs: xs
+ use_inter_op_parallelism = True # should not have any significance
+ print_label = " (short circuit mode)"
+ benchmark_label = "_short_circuit"
+
with ops.Graph().as_default():
dataset = dataset_ops.Dataset.from_tensors(
tuple(0 for _ in range(fan_out))).repeat(None)
dataset = dataset_ops.MapDataset(
dataset,
- lambda *xs: xs,
+ map_fn,
use_inter_op_parallelism=use_inter_op_parallelism)
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
@@ -824,14 +916,12 @@ class MapDatasetBenchmark(test.Benchmark):
median_wall_time = np.median(deltas) / 100
print("Map dataset fan out%s: %d Median wall time: %f" %
- (" (single threaded mode)" if not use_inter_op_parallelism
- else "", fan_out, median_wall_time))
+ (print_label, fan_out, median_wall_time))
self.report_benchmark(
iters=1000,
wall_time=median_wall_time,
- name="benchmark_map_dataset_fan_out_%d%s" %
- (fan_out, "_single_threaded"
- if not use_inter_op_parallelism else ""))
+ name="benchmark_map_dataset_fan_out_%d%s" % (fan_out,
+ benchmark_label))
if __name__ == "__main__":
diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
index 056664b83b..1cf6dd1bea 100644
--- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
+++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import multi_device_iterator_ops
from tensorflow.python.framework import dtypes
@@ -29,7 +30,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class MultiDeviceIteratorTest(test.TestCase):
+class MultiDeviceIteratorTest(test_base.DatasetTestBase):
def testNoGetNext(self):
dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/python/data/kernel_tests/optional_ops_test.py b/tensorflow/python/data/kernel_tests/optional_ops_test.py
index 706a65fe55..604e3ad88e 100644
--- a/tensorflow/python/data/kernel_tests/optional_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/optional_ops_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import optional_ops
@@ -35,7 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class OptionalTest(test.TestCase, parameterized.TestCase):
+class OptionalTest(test_base.DatasetTestBase, parameterized.TestCase):
@test_util.run_in_graph_and_eager_modes
def testFromValue(self):
diff --git a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
index cc97bac609..76e2697b29 100644
--- a/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/prefetch_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
from absl.testing import parameterized
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class PrefetchDatasetTest(test.TestCase, parameterized.TestCase):
+class PrefetchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.parameters((-1), (0), (5))
def testBufferSize(self, buffer_size):
diff --git a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
index 51e90785e7..b7e2a5f615 100644
--- a/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/range_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import os
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
@@ -34,7 +35,7 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def tearDown(self):
# Remove all checkpoint files.
diff --git a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
index aa3636364d..aef2dd1d9c 100644
--- a/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/reader_dataset_ops_test.py
@@ -21,6 +21,7 @@ import gzip
import os
import zlib
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
@@ -46,7 +47,7 @@ except ImportError:
psutil_import_succeeded = False
-class TextLineDatasetTest(test.TestCase):
+class TextLineDatasetTest(test_base.DatasetTestBase):
def _lineText(self, f, l):
return compat.as_bytes("%d: %d" % (f, l))
@@ -199,7 +200,7 @@ class TextLineDatasetTest(test.TestCase):
self.assertNotIn(filename, [open_file.path for open_file in open_files])
-class FixedLengthRecordReaderTest(test.TestCase):
+class FixedLengthRecordReaderTest(test_base.DatasetTestBase):
def setUp(self):
super(FixedLengthRecordReaderTest, self).setUp()
@@ -621,7 +622,7 @@ class FixedLengthRecordReaderTest(test.TestCase):
sess.run(get_next_op)
-class TFRecordDatasetTest(test.TestCase):
+class TFRecordDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordDatasetTest, self).setUp()
diff --git a/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
new file mode 100644
index 0000000000..11e07300b9
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/reduce_dataset_op_test.py
@@ -0,0 +1,124 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ def testSum(self):
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), lambda x, y: x + y)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i) // 2, sess.run(result))
+
+ def testSumTuple(self):
+
+ def reduce_fn(state, value):
+ v1, v2 = value
+ return state + v1 + v2
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ ds = dataset_ops.Dataset.zip((ds, ds))
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertEqual(((i + 1) * i), sess.run(result))
+
+ def testSumAndCount(self):
+
+ def reduce_fn(state, value):
+ s, c = state
+ return s + value, c + 1
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce((np.int64(0), np.int64(0)), reduce_fn)
+ with self.cached_session() as sess:
+ s, c = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, s)
+ self.assertEqual(i, c)
+
+ def testSquareUsingPlaceholder(self):
+ delta = array_ops.placeholder(dtype=dtypes.int64)
+
+ def reduce_fn(state, _):
+ return state + delta
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1)
+ result = ds.reduce(np.int64(0), reduce_fn)
+ with self.cached_session() as sess:
+ square = sess.run(result, feed_dict={delta: i})
+ self.assertEqual(i * i, square)
+
+ def testSparse(self):
+
+ def reduce_fn(_, value):
+ return value
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.from_tensors(make_sparse_fn(i+1))
+ result = ds.reduce(make_sparse_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ self.assertSparseValuesEqual(make_sparse_fn(i+1), sess.run(result))
+
+ def testNested(self):
+
+ def reduce_fn(state, value):
+ state["dense"] += value["dense"]
+ state["sparse"] = value["sparse"]
+ return state
+
+ def make_sparse_fn(i):
+ return sparse_tensor.SparseTensorValue(
+ indices=np.array([[0, 0]]),
+ values=(i * np.array([1])),
+ dense_shape=np.array([1, 1]))
+
+ def map_fn(i):
+ return {"dense": math_ops.cast(i, dtype=dtypes.int64),
+ "sparse": make_sparse_fn(math_ops.cast(i, dtype=dtypes.int64))}
+
+ for i in range(10):
+ ds = dataset_ops.Dataset.range(1, i + 1).map(map_fn)
+ result = ds.reduce(map_fn(0), reduce_fn)
+ with self.cached_session() as sess:
+ result = sess.run(result)
+ self.assertEqual(((i + 1) * i) // 2, result["dense"])
+ self.assertSparseValuesEqual(make_sparse_fn(i), result["sparse"])
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
index 37e2333560..e86356dee7 100644
--- a/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/sequence_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SequenceDatasetTest(test.TestCase):
+class SequenceDatasetTest(test_base.DatasetTestBase):
def testRepeatTensorDataset(self):
"""Test a dataset that repeats its input multiple times."""
diff --git a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
index 137f6341ce..b9f3c79da5 100644
--- a/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shard_dataset_op_test.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class ShardDatasetOpTest(test.TestCase):
+class ShardDatasetOpTest(test_base.DatasetTestBase):
def testSimpleCase(self):
dataset = dataset_ops.Dataset.range(10).shard(5, 2)
diff --git a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
index f294840706..347af18576 100644
--- a/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/shuffle_dataset_op_test.py
@@ -21,6 +21,7 @@ import collections
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -30,7 +31,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ShuffleDatasetTest(test.TestCase):
+class ShuffleDatasetTest(test_base.DatasetTestBase):
def testShuffleDataset(self):
components = (
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
new file mode 100644
index 0000000000..b73a94e683
--- /dev/null
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -0,0 +1,138 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test utilities for tf.data functionality."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import re
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class DatasetTestBase(test.TestCase):
+ """Base class for dataset tests."""
+
+ def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
+ self.assertAllEqual(a.indices, b.indices)
+ self.assertAllEqual(a.values, b.values)
+ self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the first segment of the error messages are the same.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
+
+ def structuredDataset(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns a singleton dataset with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return dataset_ops.Dataset.from_tensors(
+ array_ops.zeros(shape, dtype=dtype))
+ else:
+ return dataset_ops.Dataset.zip(
+ tuple([
+ self.structuredDataset(substructure, shape, dtype)
+ for substructure in structure
+ ]))
+
+ def structuredElement(self, structure, shape=None, dtype=dtypes.int64):
+ """Returns an element with the given structure."""
+ if shape is None:
+ shape = []
+ if structure is None:
+ return array_ops.zeros(shape, dtype=dtype)
+ else:
+ return tuple([
+ self.structuredElement(substructure, shape, dtype)
+ for substructure in structure
+ ])
diff --git a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
index fd4348426d..9d06781094 100644
--- a/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/window_dataset_op_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
+class WindowDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -150,11 +151,6 @@ class WindowDatasetTest(test.TestCase, parameterized.TestCase):
stride_t: stride
})
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testWindowSparse(self):
def _sparse(i):
diff --git a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
index 3106effbd3..9d76387a34 100644
--- a/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
+++ b/tensorflow/python/data/kernel_tests/zip_dataset_op_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -26,7 +27,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ZipDatasetTest(test.TestCase):
+class ZipDatasetTest(test_base.DatasetTestBase):
def testZipDataset(self):
component_placeholders = [
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index ac87a451b1..b7e19055f2 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -86,6 +86,18 @@ class Dataset(object):
raise NotImplementedError("Dataset._inputs")
+ def options(self):
+ """Returns the options for this dataset.
+
+ Returns:
+ A `tf.data.Options` object representing the dataset options.
+ """
+ for input_dataset in self._inputs():
+ options = input_dataset.options()
+ if options is not None:
+ return options
+ return Options()
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -114,6 +126,13 @@ class Dataset(object):
raise RuntimeError(
"dataset.make_initializable_iterator is not supported when eager "
"execution is enabled.")
+ dataset = self
+ options = self.options()
+ static_optimizations = options._static_optimizations() # pylint: disable=protected-access
+ if static_optimizations:
+ dataset = _OptimizeDataset(dataset, static_optimizations)
+ if options.experimental_autotune:
+ dataset = _ModelDataset(dataset)
if shared_name is None:
shared_name = ""
if compat.forward_compatible(2018, 8, 3):
@@ -123,11 +142,12 @@ class Dataset(object):
iterator_resource = gen_dataset_ops.iterator(
container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
- initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
- iterator_resource)
+ initializer = gen_dataset_ops.make_iterator(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ iterator_resource)
return iterator_ops.Iterator(iterator_resource, initializer,
- self.output_types, self.output_shapes,
- self.output_classes)
+ dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
def __iter__(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -162,7 +182,14 @@ class Dataset(object):
# a 0-argument function.
@function.Defun(capture_by_value=True)
def _make_dataset():
- return self._as_variant_tensor() # pylint: disable=protected-access
+ dataset = self
+ options = self.options()
+ static_optimizations = options._static_optimizations() # pylint: disable=protected-access
+ if static_optimizations:
+ dataset = _OptimizeDataset(dataset, static_optimizations)
+ if options.experimental_autotune:
+ dataset = _ModelDataset(dataset)
+ return dataset._as_variant_tensor() # pylint: disable=protected-access
try:
_make_dataset.add_to_graph(ops.get_default_graph())
@@ -889,8 +916,8 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements
- that may have different shapes into a `tf.SparseTensor`.
+ See also `tf.data.experimental.dense_to_sparse_batch`, which combines
+ elements that may have different shapes into a `tf.SparseTensor`.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
@@ -1205,6 +1232,266 @@ class Dataset(object):
shift = size
return WindowDataset(self, size, shift, stride, drop_remainder)
+ def reduce(self, initial_state, reduce_func):
+ """Reduces the input dataset to a single element.
+
+ The transformation calls `reduce_func` successively on every element of
+ the input dataset until the dataset is exhausted, aggregating information in
+ its internal state. The `initial_state` argument is used for the initial
+ state and the final state is returned as the result.
+
+ For example:
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, _: x + 1)`
+ produces `5`
+ - `tf.data.Dataset.range(5).reduce(np.int64(0), lambda x, y: x + y)`
+ produces `10`
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial
+ state of the transformation.
+ reduce_func: A function that maps `(old_state, input_element)` to
+ `new_state`. It must take two arguments and return a nested structure
+ of tensors. The structure of `new_state` must match the structure of
+ `initial_state`.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the final
+ state of the transformation.
+
+ """
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor_lib.SparseTensor.from_value(t)
+ if sparse_tensor_lib.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state.
+ state_classes = sparse.get_classes(initial_state)
+ state_shapes = nest.pack_sequence_as(
+ initial_state, [t.get_shape() for t in nest.flatten(initial_state)])
+ state_types = nest.pack_sequence_as(
+ initial_state, [t.dtype for t in nest.flatten(initial_state)])
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = StructuredFunctionWrapper(
+ reduce_func,
+ "reduce()",
+ input_classes=(state_classes, self.output_classes),
+ input_shapes=(state_shapes, self.output_shapes),
+ input_types=(state_types, self.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ output_classes = wrapped_func.output_classes
+ for new_state_class, state_class in zip(
+ nest.flatten(output_classes), nest.flatten(state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_classes,
+ wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(output_types), nest.flatten(state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." % (state_types,
+ wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ output_shapes = wrapped_func.output_shapes
+ flat_state_shapes = nest.flatten(state_shapes)
+ flat_new_state_shapes = nest.flatten(output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ state_shapes = nest.pack_sequence_as(state_shapes,
+ weakened_state_shapes)
+
+ reduce_func = wrapped_func.function
+ reduce_func.add_to_graph(ops.get_default_graph())
+
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(
+ output_types,
+ gen_dataset_ops.reduce_dataset(
+ self._as_variant_tensor(), # pylint: disable=protected-access
+ nest.flatten(sparse.serialize_sparse_tensors(initial_state)),
+ reduce_func.captured_inputs,
+ f=reduce_func,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(output_shapes, output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(output_types, output_classes)))),
+ output_types,
+ output_shapes,
+ output_classes)
+
+ def with_options(self, options):
+ """Returns a new `tf.data.Dataset` with the given options set.
+
+ The options are "global" in the sense they apply to the entire input
+ pipeline in which the `with_options` transformation is used. If options are
+ set multiple times, they are merged if possible (see
+ `tf.data.Options.merge()` for details).
+
+ Args:
+ options: A `tf.data.Options` that identifies the options the use.
+
+ Returns:
+ Dataset: A `Dataset` with the given options.
+
+ Raises:
+ ValueError: if options are set more than once
+ """
+ return _OptionsDataset(self, options)
+
+
+@tf_export("data.Options")
+class Options(object):
+ """Represents options for tf.data.Dataset.
+
+ An `Options` object can be for instance used to control which static
+ optimizations to apply or whether to use performance modeling to dynamically
+ tune the parallelism of operations such as `tf.data.Dataset.map` or
+ `tf.data.Dataset.interleave`.
+ """
+ for _name, _ty, _docstring in [
+ ("experimental_autotune", bool,
+ "Whether to dynamically adjust the values of tunable parameters (e.g. "
+ "degrees of parallelism)."),
+ ("experimental_filter_fusion", bool,
+ "Whether to fuse filter transformations."),
+ ("experimental_hoist_random_uniform", bool,
+ "Whether to hoist `tf.random_uniform()` ops out of map transformations."
+ ),
+ ("experimental_latency_all_edges", bool,
+ "Whether to add latency measurements on all edges."),
+ ("experimental_map_and_batch_fusion", bool,
+ "Whether to fuse map and batch transformations."),
+ ("experimental_map_and_filter_fusion", bool,
+ "Whether to fuse map and filter transformations."),
+ ("experimental_map_fusion", bool, "Whether to fuse map transformations."),
+ ("experimental_map_parallelization", bool,
+ "Whether to parallelize stateless map transformations."),
+ ("experimental_map_vectorization", bool,
+ "Whether to vectorize map transformations."),
+ ("experimental_noop_elimination", bool,
+ "Whether to eliminate no-op transformations."),
+ ("experimental_shuffle_and_repeat_fusion", bool,
+ "Whether to fuse shuffle and repeat transformations."),
+ ]:
+
+ def _make_getter(name): # pylint: disable=no-self-argument
+
+ def getter(self):
+ return getattr(self, "_" + name)
+
+ return getter
+
+ def _make_setter(name, ty): # pylint: disable=no-self-argument
+
+ def setter(self, value):
+ if not isinstance(value, ty):
+ raise TypeError(
+ "Attempting to set the option %s to incompatible value: %r" %
+ (name, value))
+ setattr(self, "_" + name, value)
+
+ return setter
+
+ vars()["_" + _name] = None
+ vars()[_name] = property(
+ _make_getter(_name), _make_setter(_name, _ty), None, _docstring)
+
+ def __init__(self):
+ pass
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _static_optimizations(self):
+ """Produces the list of enabled static optimizations."""
+ experimental_optimizations = [
+ "filter_fusion", "hoist_random_uniform", "latency_all_edges",
+ "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
+ "map_parallelization", "map_vectorization", "noop_elimination",
+ "shuffle_and_repeat_fusion"
+ ]
+ result = []
+ for exp_opt in experimental_optimizations:
+ if getattr(self, "experimental_" + exp_opt):
+ result.append(exp_opt)
+ return result
+
+ def merge(self, options):
+ """Merges itself with the given `tf.data.Options`.
+
+ The given `tf.data.Options` can be merged as long as there does not exist an
+ attribute that is set to different values in `self` and `options`.
+
+ Args:
+ options: a `tf.data.Options` to merge with
+
+ Raises:
+ ValueError: if the given `tf.data.Options` cannot be merged
+
+ Returns:
+ New `tf.data.Options()` object which is the result of merging self with
+ the input `tf.data.Options`.
+ """
+ result = Options()
+ for other in [self, options]:
+ for name in [
+ "experimental_autotune", "experimental_filter_fusion",
+ "experimental_hoist_random_uniform", "experimental_latency_all_edges",
+ "experimental_map_and_batch_fusion",
+ "experimental_map_and_filter_fusion", "experimental_map_fusion",
+ "experimental_map_parallelization", "experimental_map_vectorization",
+ "experimental_noop_elimination",
+ "experimental_shuffle_and_repeat_fusion"
+ ]:
+ this = getattr(result, name)
+ that = getattr(other, name)
+ if that is not None:
+ if this is None:
+ setattr(result, name, that)
+ elif this != that:
+ raise ValueError(
+ "Cannot merge incompatible values of option: %s" % (name))
+ return result
+
class DatasetSource(Dataset):
"""Abstract class representing a dataset with no inputs."""
@@ -1544,6 +1831,10 @@ class StructuredFunctionWrapper(object):
flat_classes.append(component)
flat_shapes.append(component)
flat_types.append(component)
+ if t.options() != Options():
+ warnings.warn("Encountered a nested dataset with non-default "
+ "options. These options will not be propagated to "
+ "the outer dataset.")
else:
try:
t = ops.convert_to_tensor(t)
@@ -2583,3 +2874,91 @@ class WindowDataset(UnaryDataset):
@property
def output_types(self):
return self._output_types
+
+
+class _OptionsDataset(UnaryDataset):
+ """An identity `Dataset` that stores options."""
+
+ def __init__(self, input_dataset, options):
+ super(_OptionsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._options = input_dataset.options()
+ if self._options:
+ self._options = self._options.merge(options)
+ else:
+ self._options = options
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ def options(self):
+ return self._options
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _ModelDataset(UnaryDataset):
+ """A `Dataset` that acts as an identity, and models performance."""
+
+ def __init__(self, input_dataset):
+ """See `optimize()` for details."""
+ super(_ModelDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.model_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _OptimizeDataset(UnaryDataset):
+ """A `Dataset` that acts as an identity, and applies optimizations."""
+
+ def __init__(self, input_dataset, optimizations):
+ """See `optimize()` for details."""
+ super(_OptimizeDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if optimizations is None:
+ optimizations = []
+ self._optimizations = ops.convert_to_tensor(
+ optimizations, dtype=dtypes.string, name="optimizations")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.optimize_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._optimizations,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 3bbebd7878..aca989e03a 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -31,7 +31,7 @@ class Optional(object):
An `Optional` can represent the result of an operation that may fail as a
value, rather than raising an exception and halting execution. For example,
- `tf.contrib.data.get_next_as_optional` returns an `Optional` that either
+ `tf.data.experimental.get_next_as_optional` returns an `Optional` that either
contains the next value from a `tf.data.Iterator` if one exists, or a "none"
value that indicates the end of the sequence has been reached.
"""
@@ -111,7 +111,7 @@ class Optional(object):
class _OptionalImpl(Optional):
- """Concrete implementation of `tf.contrib.data.Optional`.
+ """Concrete implementation of `tf.data.experimental.Optional`.
NOTE(mrry): This implementation is kept private, to avoid defining
`Optional.__init__()` in the public API.
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index b0f26631f9..d08da6704c 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -129,7 +129,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
- """See `tf.contrib.data.parallel_interleave()` for details."""
+ """See `tf.data.experimental.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
cycle_length, block_length)
self._sloppy = ops.convert_to_tensor(
@@ -158,7 +158,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
# pylint: enable=protected-access
def _transformation_name(self):
- return "tf.contrib.data.parallel_interleave()"
+ return "tf.data.experimental.parallel_interleave()"
@tf_export("data.TFRecordDataset")
diff --git a/tensorflow/python/debug/BUILD b/tensorflow/python/debug/BUILD
index 849d165bfa..e84482d2b2 100644
--- a/tensorflow/python/debug/BUILD
+++ b/tensorflow/python/debug/BUILD
@@ -18,6 +18,7 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
load("//tensorflow:tensorflow.bzl", "if_not_windows")
py_library(
diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py
index 4630bda590..f197a9e4dc 100644
--- a/tensorflow/python/debug/cli/analyzer_cli_test.py
+++ b/tensorflow/python/debug/cli/analyzer_cli_test.py
@@ -599,11 +599,11 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
v_name = "simple_mul_add/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2], name="u_init")
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
cls._u_line_number = line_number_above()
v_init = constant_op.constant(v_init_val, shape=[2, 1], name="v_init")
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
cls._v_line_number = line_number_above()
w = math_ops.matmul(u, v, name="simple_mul_add/matmul")
@@ -612,7 +612,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
x = math_ops.add(w, w, name="simple_mul_add/add")
cls._x_line_number = line_number_above()
- a = variables.Variable([1, 3, 3, 7], name="a")
+ a = variables.VariableV1([1, 3, 3, 7], name="a")
u.initializer.run()
v.initializer.run()
@@ -1371,7 +1371,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1388,7 +1388,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates v.
index = self._findSourceLine(out, self._v_line_number)
self.assertEqual(
- ["L%d v = variables.Variable(v_init, name=v_name)" %
+ ["L%d v = variables.VariableV1(v_init, name=v_name)" %
self._v_line_number,
" simple_mul_add/v"],
out.lines[index : index + 2])
@@ -1425,7 +1425,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
# Verify the annotation of the line that creates u.
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u/read:0",
" simple_mul_add/u:0"],
@@ -1447,7 +1447,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" simple_mul_add/u/Assign",
@@ -1470,7 +1470,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
index = self._findSourceLine(out, self._u_line_number)
self.assertEqual(
- ["L%d u = variables.Variable(u_init, name=u_name)" %
+ ["L%d u = variables.VariableV1(u_init, name=u_name)" %
self._u_line_number,
" simple_mul_add/u",
" (... Omitted 2 of 3 op(s) ...) +5"],
@@ -1580,7 +1580,7 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase):
"""List an input tree containing tensors from non-:0 output slot."""
with session.Session(config=no_rewrite_session_config()) as sess:
- x = variables.Variable([1, 3, 3, 7], name="x")
+ x = variables.VariableV1([1, 3, 3, 7], name="x")
_, idx = array_ops.unique(x, name="x_unique")
idx_times_two = math_ops.multiply(idx, 2, name="idx_times_two")
sess.run(x.initializer)
@@ -1684,7 +1684,7 @@ class AnalyzerCLIControlDepTest(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
x_init_val = np.array([5.0, 3.0])
x_init = constant_op.constant(x_init_val, shape=[2])
- x = variables.Variable(x_init, name="control_deps/x")
+ x = variables.VariableV1(x_init, name="control_deps/x")
y = math_ops.add(x, x, name="control_deps/y")
y = control_flow_ops.with_dependencies(
diff --git a/tensorflow/python/debug/cli/stepper_cli_test.py b/tensorflow/python/debug/cli/stepper_cli_test.py
index ee8cabca0d..7b8a42c253 100644
--- a/tensorflow/python/debug/cli/stepper_cli_test.py
+++ b/tensorflow/python/debug/cli/stepper_cli_test.py
@@ -132,8 +132,8 @@ def _parse_updated(lines):
class NodeStepperSimpleGraphTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(20.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(20.0, name="b")
self.c = math_ops.add(self.a, self.b, name="c") # Should be 30.0.
self.d = math_ops.subtract(self.a, self.c, name="d") # Should be -20.0.
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 019f13c450..f9bb3148fb 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -94,13 +94,15 @@ def main(_):
"sepal_length", "sepal_width", "petal_length", "petal_width", "label"]
batch_size = 32
def training_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [training_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([training_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
def test_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [test_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([test_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
feature_columns = [tf.feature_column.numeric_column(feature)
for feature in column_names[:-1]]
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index f7d597c8c0..89dc918616 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -115,7 +115,7 @@ OUTPUT=$(${OFFLINE_ANALYZER_BIN} 2>&1)
set -e
EXPECTED_OUTPUT="ERROR: dump_dir flag is empty."
-if [[ "${OUTPUT}" != "${EXPECTED_OUTPUT}" ]]; then
+if ! echo "${OUTPUT}" | grep -q "${EXPECTED_OUTPUT}"; then
echo "ERROR: offline_analyzer output didn't match expectation: ${OUTPUT}" 1>&2
echo "Expected output: ${EXPECTED_OUTPUT}"
exit 1
diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py
index 5b1875e092..23ab98444c 100644
--- a/tensorflow/python/debug/lib/debug_utils_test.py
+++ b/tensorflow/python/debug/lib/debug_utils_test.py
@@ -46,8 +46,8 @@ class DebugUtilsTest(test_util.TensorFlowTestCase):
cls._b_init = constant_op.constant(
cls._b_init_val, shape=[2, 1], name="b_init")
- cls._a = variables.Variable(cls._a_init, name="a1")
- cls._b = variables.Variable(cls._b_init, name="b")
+ cls._a = variables.VariableV1(cls._a_init, name="a1")
+ cls._b = variables.VariableV1(cls._b_init, name="b")
cls._c = constant_op.constant(cls._c_val, shape=[2, 1], name="c")
# Matrix product of a and b.
diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
index 46a7be5808..74498c8ea3 100644
--- a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py
@@ -118,8 +118,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase):
"""
with ops.Graph().as_default() as graph:
with ops.device("/job:worker/task:0/cpu:0"):
- self.a = variables.Variable(10.0, name="a")
- self.b = variables.Variable(100.0, name="b")
+ self.a = variables.VariableV1(10.0, name="a")
+ self.b = variables.VariableV1(100.0, name="b")
self.inc_a = state_ops.assign_add(self.a, 2.0, name="inc_a")
self.dec_b = state_ops.assign_add(self.b, -5.0, name="dec_b")
self.p = math_ops.multiply(self.inc_a, self.dec_b, name="p")
diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py
index 5bc477a9ba..ccc21bcf94 100644
--- a/tensorflow/python/debug/lib/grpc_large_data_test.py
+++ b/tensorflow/python/debug/lib/grpc_large_data_test.py
@@ -61,7 +61,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
with self.test_session(
use_gpu=True,
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- u = variables.Variable(42.0, name="original_u")
+ u = variables.VariableV1(42.0, name="original_u")
for _ in xrange(50 * 1000):
u = array_ops.identity(u)
sess.run(variables.global_variables_initializer())
@@ -94,7 +94,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.float32, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds # Unused by this watch_fn.
@@ -117,7 +117,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
b"", b"spam", b"A" * 2500 * 1024, b"B" * 2500 * 1024, b"egg", b""]
u_init = constant_op.constant(
u_init_val, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -146,7 +146,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
u_init = constant_op.constant(
u_init_val_array, dtype=dtypes.string, name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -167,7 +167,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.float32, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
@@ -189,7 +189,7 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase):
config=session_debug_testlib.no_rewrite_session_config()) as sess:
u_init = constant_op.constant(
[], dtype=dtypes.string, shape=[0], name="u_init")
- u = variables.Variable(u_init, name="u")
+ u = variables.VariableV1(u_init, name="u")
def watch_fn(fetches, feeds):
del fetches, feeds
diff --git a/tensorflow/python/debug/lib/session_debug_file_test.py b/tensorflow/python/debug/lib/session_debug_file_test.py
index ba0f15b4e2..1874160dd6 100644
--- a/tensorflow/python/debug/lib/session_debug_file_test.py
+++ b/tensorflow/python/debug/lib/session_debug_file_test.py
@@ -58,9 +58,9 @@ class SessionDebugFileTest(session_debug_testlib.SessionDebugTestBase):
v_name = "diff_Watch/v"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name="diff_Watch/matmul")
diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py
index 91f21cb1f3..bfc9a3a382 100644
--- a/tensorflow/python/debug/lib/session_debug_grpc_test.py
+++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py
@@ -148,8 +148,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
sess, "localhost:%d" % self._server_port, watch_fn="foo")
def testGrpcDebugWrapperSessionWithoutWatchFnWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -175,8 +175,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
del feeds, fetch_keys
return ["DebugIdentity", "DebugNumericSummary"], r".*/read", None
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -209,8 +209,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
op_type_regex_whitelist=None,
tolerate_debug_op_creation_failures=True)
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -241,8 +241,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
14, len(dump.get_tensors("v/read", 0, "DebugNumericSummary")[0]))
def testTensorBoardDebugHookWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -286,8 +286,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase):
self._server.query_source_file_line(__file__, 1)
def testTensorBoardDebugHookDisablingTracebackSourceCodeSendingWorks(self):
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess = session.Session(
@@ -381,8 +381,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenDebugNodes(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -451,8 +451,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_1")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_1")
# These two nodes have names that match those in the
# toggle_watch_on_core_metadata argument used when calling
# start_server_on_separate_thread().
@@ -491,7 +491,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleEnableTwoDebugWatchesNoCrosstalkBetweenServers(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -534,8 +534,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -592,8 +592,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperToggleBreakpointsWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -665,8 +665,8 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testTensorBoardDebuggerWrapperDisablingTracebackSourceSendingWorks(self):
with session.Session(
config=session_debug_testlib.no_rewrite_session_config()) as sess:
- v_1 = variables.Variable(50.0, name="v_1")
- v_2 = variables.Variable(-50.0, name="v_2")
+ v_1 = variables.VariableV1(50.0, name="v_1")
+ v_2 = variables.VariableV1(-50.0, name="v_2")
delta_1 = constant_op.constant(5.0, name="delta_1")
delta_2 = constant_op.constant(-5.0, name="delta_2")
inc_v_1 = state_ops.assign_add(v_1, delta_1, name="inc_v_1")
@@ -699,7 +699,7 @@ class SessionDebugGrpcGatingTest(test_util.TensorFlowTestCase):
def testGetGrpcDebugWatchesReturnsCorrectAnswer(self):
with session.Session() as sess:
- v = variables.Variable(50.0, name="v")
+ v = variables.VariableV1(50.0, name="v")
delta = constant_op.constant(5.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
@@ -743,7 +743,7 @@ class DelayedDebugServerTest(test_util.TensorFlowTestCase):
with self.cached_session() as sess:
a_init = constant_op.constant(42.0, name="a_init")
- a = variables.Variable(a_init, name="a")
+ a = variables.VariableV1(a_init, name="a")
def watch_fn(fetches, feeds):
del fetches, feeds
diff --git a/tensorflow/python/debug/lib/session_debug_testlib.py b/tensorflow/python/debug/lib/session_debug_testlib.py
index 070d9c4cd7..25ef91b575 100644
--- a/tensorflow/python/debug/lib/session_debug_testlib.py
+++ b/tensorflow/python/debug/lib/session_debug_testlib.py
@@ -70,7 +70,7 @@ class _RNNCellForTest(rnn_cell_impl.RNNCell):
def __init__(self, input_output_size, state_size):
self._input_output_size = input_output_size
self._state_size = state_size
- self._w = variables.Variable(1.0, dtype=dtypes.float32, name="w")
+ self._w = variables.VariableV1(1.0, dtype=dtypes.float32, name="w")
@property
def output_size(self):
@@ -182,9 +182,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "w"
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant(v_init_val, shape=[2, 1])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
w = math_ops.matmul(u, v, name=w_name)
@@ -221,8 +221,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testCopyNodesHaveCorrectDebugOpsAndURLsAttributeValues(self):
with session.Session() as sess:
- u = variables.Variable(2.1, name="u")
- v = variables.Variable(20.0, name="v")
+ u = variables.VariableV1(2.1, name="u")
+ v = variables.VariableV1(20.0, name="v")
w = math_ops.multiply(u, v, name="w")
sess.run(variables.global_variables_initializer())
@@ -324,8 +324,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
str1_name = "str1"
str2_name = "str2"
- str1 = variables.Variable(str1_init, name=str1_name)
- str2 = variables.Variable(str2_init, name=str2_name)
+ str1 = variables.VariableV1(str1_init, name=str1_name)
+ str2 = variables.VariableV1(str2_init, name=str2_name)
# Concatenate str1 and str2
str_concat = math_ops.add(str1, str2, name="str_concat")
@@ -387,9 +387,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
s_name = "%s/s" % op_namespace
u_init = constant_op.constant(u_init_val, shape=[2, 2])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
s_init = constant_op.constant(s_init_val)
- s = variables.Variable(s_init, name=s_name)
+ s = variables.VariableV1(s_init, name=s_name)
run_options = config_pb2.RunOptions(output_partition_graphs=True)
debug_urls = self._debug_urls()
@@ -439,7 +439,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
u_init_val = np.array(11.0)
u_init = constant_op.constant(u_init_val)
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
# "v" is the increment.
v_name = "testDumpToFileWhileLoop/v"
@@ -447,7 +447,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
v_init_val = np.array(2.0)
v_init = constant_op.constant(v_init_val)
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
u.initializer.run()
v.initializer.run()
@@ -605,8 +605,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugCondWatchingWholeGraphWorks(self):
with session.Session() as sess:
- x = variables.Variable(10.0, name="x")
- y = variables.Variable(20.0, name="y")
+ x = variables.VariableV1(10.0, name="x")
+ y = variables.VariableV1(20.0, name="y")
cond = control_flow_ops.cond(
x > y, lambda: math_ops.add(x, 1), lambda: math_ops.add(y, 1))
@@ -628,9 +628,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindNodesWithBadTensorValues/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -679,9 +679,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
z_name = "testFindInfOrNanWithOpNameExclusion/z"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v_init = constant_op.constant([2.0, 1.0])
- v = variables.Variable(v_init, name=v_name)
+ v = variables.VariableV1(v_init, name=v_name)
# Expected output: [0.0, 3.0]
w = math_ops.subtract(u, v, name=w_name)
@@ -725,7 +725,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpGraphStructureLookup/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -859,9 +859,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingOnControlEdgesWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v1 = variables.Variable(1.0, name="v1")
- v2 = variables.Variable(2.0, name="v2")
- v3 = variables.Variable(3.0, name="v3")
+ v1 = variables.VariableV1(1.0, name="v1")
+ v2 = variables.VariableV1(2.0, name="v2")
+ v3 = variables.VariableV1(3.0, name="v3")
a = math_ops.add(v1, v2, name="a")
with ops.control_dependencies([a]):
c = math_ops.subtract(v3, v3, name="c")
@@ -875,8 +875,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testGraphPathFindingReverseRefEdgeWorks(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- v = variables.Variable(10.0, name="v")
- delta = variables.Variable(1.0, name="delta")
+ v = variables.VariableV1(10.0, name="v")
+ delta = variables.VariableV1(1.0, name="delta")
inc_v = state_ops.assign_add(v, delta, name="inc_v")
sess.run(variables.global_variables_initializer())
@@ -894,7 +894,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "testDumpCausalityCheck/w"
u_init = constant_op.constant([2.0, 4.0])
- u = variables.Variable(u_init, name=u_name)
+ u = variables.VariableV1(u_init, name=u_name)
v = math_ops.add(u, u, name=v_name)
w = math_ops.add(v, v, name=w_name)
@@ -980,7 +980,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
w_name = "oneOfTwoSlots/w"
y_name = "oneOfTwoSlots/y"
- x = variables.Variable([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
+ x = variables.VariableV1([1, 3, 3, 7], dtype=dtypes.int32, name=x_name)
sess.run(x.initializer)
unique_x, indices, _ = array_ops.unique_with_counts(x, name=u_name)
@@ -1039,9 +1039,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session(config=no_rewrite_session_config()) as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="gdo/u")
+ u = variables.VariableV1(u_init, name="gdo/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="gdo/v")
+ v = variables.VariableV1(v_init, name="gdo/v")
w = math_ops.multiply(u, v, name="gdo/w")
# gdo stands for GradientDescentOptimizer.
@@ -1085,7 +1085,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
with session.Session() as sess:
x_init = constant_op.constant([2, 2, 3, 5, 5])
- x = variables.Variable(x_init, name="unconnected/x")
+ x = variables.VariableV1(x_init, name="unconnected/x")
# The UniqueOp (tf.unique) has two output slots. Use only slot 0 in the
# graph. Let the debugger watch the unused slot 1.
@@ -1225,14 +1225,14 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnInitializedTensorGivesCorrectResult(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[
np.nan, np.nan, 0.0, 0.0, 0.0, -1.0, -3.0, 3.0, 7.0, -np.inf,
-np.inf, np.inf, np.inf, np.inf, np.inf, np.inf, np.nan, np.nan
],
dtype=np.float32,
name="numeric_summary/a")
- b = variables.Variable(
+ b = variables.VariableV1(
[0.0] * 18, dtype=np.float32, name="numeric_summary/b")
c = math_ops.add(a, b, name="numeric_summary/c")
@@ -1249,7 +1249,7 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryOnUninitializedTensorGivesCorrectResult(self):
with session.Session() as sess:
- a = variables.Variable(
+ a = variables.VariableV1(
[42], dtype=np.float32, name="numeric_summary_uninit/a")
_, dump = self._debug_run_and_get_dump(
@@ -1275,9 +1275,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryFailureIsToleratedWhenOrdered(self):
with session.Session() as sess:
- a = variables.Variable("1", name="a")
- b = variables.Variable("3", name="b")
- c = variables.Variable("2", name="c")
+ a = variables.VariableV1("1", name="a")
+ b = variables.VariableV1("3", name="b")
+ c = variables.VariableV1("2", name="c")
d = math_ops.add(a, b, name="d")
e = math_ops.add(d, c, name="e")
@@ -1313,9 +1313,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryInvalidAttributesStringAreCaught(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1361,9 +1361,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyMutesOnlyHealthyTensorDumps(self):
with session.Session(config=no_rewrite_session_config()) as sess:
- a = variables.Variable(10.0, name="a")
- b = variables.Variable(0.0, name="b")
- c = variables.Variable(0.0, name="c")
+ a = variables.VariableV1(10.0, name="a")
+ b = variables.VariableV1(0.0, name="b")
+ c = variables.VariableV1(0.0, name="c")
x = math_ops.divide(a, b, name="x")
y = math_ops.multiply(x, c, name="y")
@@ -1396,8 +1396,8 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testDebugNumericSummaryMuteOnHealthyAndCustomBoundsWork(self):
with session.Session() as sess:
- a = variables.Variable([10.0, 10.0], name="a")
- b = variables.Variable([10.0, 2.0], name="b")
+ a = variables.VariableV1([10.0, 10.0], name="a")
+ b = variables.VariableV1([10.0, 2.0], name="b")
x = math_ops.add(a, b, name="x") # [20.0, 12.0]
y = math_ops.divide(x, b, name="y") # [2.0, 6.0]
@@ -1436,9 +1436,9 @@ class SessionDebugTestBase(test_util.TensorFlowTestCase):
def testLookUpNodePythonTracebackWorks(self):
with session.Session() as sess:
u_init = constant_op.constant(10.0)
- u = variables.Variable(u_init, name="traceback/u")
+ u = variables.VariableV1(u_init, name="traceback/u")
v_init = constant_op.constant(20.0)
- v = variables.Variable(v_init, name="traceback/v")
+ v = variables.VariableV1(v_init, name="traceback/v")
w = math_ops.multiply(u, v, name="traceback/w")
@@ -1487,7 +1487,7 @@ class DebugConcurrentRunCallsTest(test_util.TensorFlowTestCase):
self.skipTest("No testing concurrent runs on a single GPU.")
with session.Session() as sess:
- v = variables.Variable(30.0, name="v")
+ v = variables.VariableV1(30.0, name="v")
constants = []
for i in xrange(self._num_concurrent_runs):
constants.append(constant_op.constant(1.0, name="c%d" % i))
diff --git a/tensorflow/python/debug/lib/stepper_test.py b/tensorflow/python/debug/lib/stepper_test.py
index 9a3d0efabf..3839c67198 100644
--- a/tensorflow/python/debug/lib/stepper_test.py
+++ b/tensorflow/python/debug/lib/stepper_test.py
@@ -36,8 +36,8 @@ from tensorflow.python.training import gradient_descent
class StepperTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.a = variables.Variable(2.0, name="a")
- self.b = variables.Variable(3.0, name="b")
+ self.a = variables.VariableV1(2.0, name="a")
+ self.b = variables.VariableV1(3.0, name="b")
self.c = math_ops.multiply(self.a, self.b, name="c") # Should be 6.0.
self.d = math_ops.multiply(self.a, self.a, name="d") # Should be 4.0.
@@ -49,7 +49,7 @@ class StepperTest(test_util.TensorFlowTestCase):
# The there nodes x, y and z form a graph with "cross-links" in. I.e., x
# and y are both direct inputs to z, but x is also a direct input to y.
- self.x = variables.Variable(2.0, name="x") # Should be 2.0
+ self.x = variables.VariableV1(2.0, name="x") # Should be 2.0
self.y = math_ops.negative(self.x, name="y") # Should be -2.0.
self.z = math_ops.multiply(self.x, self.y, name="z") # Should be -4.0.
@@ -580,7 +580,7 @@ class StepperTestWithPlaceHolders(test_util.TensorFlowTestCase):
class StepperAssignAddTest(test_util.TensorFlowTestCase):
def setUp(self):
- self.v = variables.Variable(10.0, name="v")
+ self.v = variables.VariableV1(10.0, name="v")
self.p = math_ops.add(self.v, self.v, name="p")
self.q = math_ops.multiply(self.p, self.p, name="q")
self.delta = constant_op.constant(2.0, name="delta")
@@ -711,9 +711,9 @@ class StepperBackwardRunTest(test_util.TensorFlowTestCase):
Construct a backward graph using the GradientDescentOptimizer.
"""
- self.a = variables.Variable(1.0, name="a")
- self.b = variables.Variable(2.0, name="b")
- self.c = variables.Variable(4.0, name="c")
+ self.a = variables.VariableV1(1.0, name="a")
+ self.b = variables.VariableV1(2.0, name="b")
+ self.c = variables.VariableV1(4.0, name="c")
self.d = math_ops.multiply(self.a, self.b, name="d")
self.e = math_ops.multiply(self.b, self.c, name="e")
self.f = math_ops.multiply(self.d, self.e, name="f")
diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
index 254201c393..11011a5c13 100644
--- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py
@@ -46,7 +46,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self.session_root = tempfile.mkdtemp()
- self.v = variables.Variable(10.0, dtype=dtypes.float32, name="v")
+ self.v = variables.VariableV1(10.0, dtype=dtypes.float32, name="v")
self.delta = constant_op.constant(1.0, dtype=dtypes.float32, name="delta")
self.eta = constant_op.constant(-1.4, dtype=dtypes.float32, name="eta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
index 05c9eaa4d2..149a7497df 100644
--- a/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
+++ b/tensorflow/python/debug/wrappers/local_cli_wrapper_test.py
@@ -132,8 +132,8 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def setUp(self):
self._tmp_dir = tempfile.mktemp()
- self.v = variables.Variable(10.0, name="v")
- self.w = variables.Variable(21.0, name="w")
+ self.v = variables.VariableV1(10.0, name="v")
+ self.w = variables.VariableV1(21.0, name="w")
self.delta = constant_op.constant(1.0, name="delta")
self.inc_v = state_ops.assign_add(self.v, self.delta, name="inc_v")
@@ -358,7 +358,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableTensorRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(42)
+ v = variables.VariableV1(42)
tensor_runner = wrapped_sess.make_callable(v)
self.sess.run(v.initializer)
@@ -382,7 +382,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testDebuggingMakeCallableOperationRunnerWorks(self):
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"], ["run"]], self.sess, dump_root=self._tmp_dir)
- v = variables.Variable(10.0)
+ v = variables.VariableV1(10.0)
inc_v = state_ops.assign_add(v, 1.0)
op_runner = wrapped_sess.make_callable(inc_v.op)
self.sess.run(v.initializer)
@@ -403,7 +403,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertEqual(1, len(wrapped_sess.observers["debug_dumps"]))
def testDebuggingMakeCallableFromOptionsWithZeroFeedWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -480,7 +480,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
self.assertItemsEqual(["callable_a", "callable_b"], node_names)
def testDebugMakeCallableFromOptionsWithCustomOptionsAndMetadataWorks(self):
- variable_1 = variables.Variable(
+ variable_1 = variables.VariableV1(
10.5, dtype=dtypes.float32, name="variable_1")
a = math_ops.add(variable_1, variable_1, "callable_a")
math_ops.add(a, a, "callable_b")
@@ -528,7 +528,7 @@ class LocalCLIDebugWrapperSessionTest(test_util.TensorFlowTestCase):
def testRuntimeErrorBeforeGraphExecutionIsRaised(self):
# Use an impossible device name to cause an error before graph execution.
with ops.device("/device:GPU:1337"):
- w = variables.Variable([1.0] * 10, name="w")
+ w = variables.VariableV1([1.0] * 10, name="w")
wrapped_sess = LocalCLIDebuggerWrapperSessionForTest(
[["run"]], self.sess, dump_root=self._tmp_dir)
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index bd3562f1ff..b9b77d4a5b 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -126,7 +126,7 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
- session_config: an optional @{tf.ConfigProto} object.
+ session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
@@ -685,7 +685,7 @@ def run_distribute_coordinator(worker_fn,
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- session_config: an optional @{tf.ConfigProto} object which will be passed
+ session_config: an optional `tf.ConfigProto` object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 8daa34c885..0289689134 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -62,7 +62,7 @@ def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
- # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ # `tf.estimator.RunConfig.global_id_in_cluster`.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d3d997e6df..cae809a7c3 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
@@ -250,6 +251,7 @@ py_library(
"//tensorflow/python:gradients_impl",
"//tensorflow/python:graph_to_function_def",
"//tensorflow/python:util",
+ "//tensorflow/python/autograph",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:core",
"//tensorflow/python/eager:execute",
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 78f3198011..deac29111f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
"""If x is ResourceVariable, return its handle, else x."""
- if isinstance(x, resource_variable_ops.ResourceVariable):
+ if resource_variable_ops.is_resource_variable(x):
x = x.handle
return x
diff --git a/tensorflow/python/eager/benchmarks_test.py b/tensorflow/python/eager/benchmarks_test.py
index 3fe79ef244..2b0118c07f 100644
--- a/tensorflow/python/eager/benchmarks_test.py
+++ b/tensorflow/python/eager/benchmarks_test.py
@@ -353,7 +353,7 @@ class MicroBenchmarks(test.Benchmark):
num_iters,
execution_mode=None):
f = function.defun(math_ops.matmul)
- func = lambda: f(m, m, transpose_b)
+ func = lambda: f(m, m, transpose_b=transpose_b)
self._run(func, num_iters, execution_mode=execution_mode)
def _benchmark_defun_matmul_forward_backward(self,
@@ -366,7 +366,7 @@ class MicroBenchmarks(test.Benchmark):
def func():
with backprop.GradientTape() as gt:
gt.watch(m)
- y = f(m, m, transpose_b)
+ y = f(m, m, transpose_b=transpose_b)
_ = gt.gradient(y, m)
self._run(func, num_iters, execution_mode=execution_mode)
diff --git a/tensorflow/python/eager/core_test.py b/tensorflow/python/eager/core_test.py
index fb5442b646..e601aa376f 100644
--- a/tensorflow/python/eager/core_test.py
+++ b/tensorflow/python/eager/core_test.py
@@ -631,6 +631,34 @@ class TFETest(test_util.TensorFlowTestCase):
for t in tensors:
self.assertIsInstance(t, ops.EagerTensor)
+ def testSmallIntegerOpsForcedToCPU(self):
+ if not context.context().num_gpus():
+ self.skipTest('No GPUs found')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.int64)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.int64)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op forced to CPU since all constants are integers and small.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:CPU:0')
+
+ a = array_ops.zeros((8, 10), dtype=dtypes.int64)
+ b = array_ops.ones((8, 10), dtype=dtypes.int64)
+
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the tensors are larger than 64 elements.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
+
+ a = constant_op.constant((1, 2, 3, 4, 5), dtype=dtypes.float32)
+ b = constant_op.constant((2, 3, 4, 5, 6), dtype=dtypes.float32)
+ with context.device('gpu:0'):
+ c = a + b
+
+ # Op not forced to CPU since the constants are not integers.
+ self.assertEqual(c.device, '/job:localhost/replica:0/task:0/device:GPU:0')
class SendRecvTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index b28befeb62..93168826b1 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import re
import sys
import threading
import weakref
@@ -30,6 +31,7 @@ import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import function_pb2
+from tensorflow.python import autograph
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
@@ -61,9 +63,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
# This is to avoid a circular dependency with gradients_impl
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
+BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
-WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [
+ "experimental_.*",
+ FORWARD_FUNCTION_ATTRIBUTE_NAME,
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME
+]
def _create_substitute_placeholder(value, name=None, dtype=None):
@@ -140,10 +148,11 @@ def _parse_func_attrs(attributes):
"""
attrs = {}
for key, value in attributes.items():
- if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ if not any([re.match(reg, key)
+ for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]):
raise ValueError("Attribute name is not whitelisted. "
"Whitelisted: prefix %s, got: %s" %
- (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+ (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))
if isinstance(value, attr_value_pb2.AttrValue):
attrs[key] = value
@@ -154,7 +163,7 @@ def _parse_func_attrs(attributes):
attrs[key] = attr_value_pb2.AttrValue(i=value)
elif isinstance(value, float):
attrs[key] = attr_value_pb2.AttrValue(f=value)
- elif isinstance(value, str):
+ elif isinstance(value, (str, bytes)):
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
else:
raise ValueError("Unsupported attribute type for %s with type %s" %
@@ -486,6 +495,9 @@ class _EagerDefinedFunction(object):
Returns:
The outputs of the function call.
+
+ Raises:
+ ValueError: if the number of arguments is incorrect.
"""
executing_eagerly = ctx.executing_eagerly()
@@ -519,6 +531,10 @@ class _EagerDefinedFunction(object):
# TODO(akshayka): Either remove this if the FunctionLibraryRuntime
# creates `PartitionedCallOp` kernels by default, or remove the previous
# branch if a TPU kernel is registered for `PartitionedCall`.
+ if len(args) != len(self.signature.input_arg):
+ raise ValueError(
+ "Arguments and signature arguments do not match: %s %s " %
+ (len(args), len(list(self.signature.input_arg))))
outputs = functional_ops.partitioned_call(
args=args,
f=self,
@@ -705,6 +721,7 @@ class Function(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ forward_function_name = _forward_name(self._func_graph.name)
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
@@ -715,11 +732,11 @@ class Function(object):
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)
- self._forward_function = _EagerDefinedFunction(
- _forward_name(
- self._func_graph.name), self._func_graph, self._func_graph.inputs,
- self._func_graph.outputs + list(backwards_graph.captures.keys()),
- self._attrs)
+ backwards_graph_captures = list(backwards_graph.captures.keys())
+
+ backward_function_attr = _parse_func_attrs(
+ {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+ backward_function_attr.update(self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
# `self._backward_graph_function` correspond to outputs of
@@ -732,7 +749,16 @@ class Function(object):
grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
self._backward_graph_function = Function(
- backwards_graph, attrs=self._attrs)
+ backwards_graph, attrs=backward_function_attr)
+
+ forward_function_attr = _parse_func_attrs({
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+ self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
+ forward_function_attr.update(self._attrs)
+ self._forward_function = _EagerDefinedFunction(
+ forward_function_name, self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + backwards_graph_captures,
+ forward_function_attr)
def _backprop_call(self, args):
"""Calls the forward function and records the result on a tape.
@@ -829,20 +855,12 @@ class Function(object):
return ret
-def _get_defun_inputs_from_signature(signature):
- """Maps a signature to graph-construction inputs."""
- function_inputs = [
- graph_placeholder(spec.dtype, spec.shape)
- for spec in nest.flatten(signature)
- ]
- return nest.pack_sequence_as(signature, function_inputs)
-
-
def _get_defun_inputs_from_args(args):
"""Maps python function args to graph-construction inputs."""
function_inputs = [
graph_placeholder(arg.dtype, arg.shape)
- if isinstance(arg, ops.Tensor) else arg for arg in nest.flatten(args)
+ if isinstance(arg, (ops.Tensor, tensor_spec.TensorSpec))
+ else arg for arg in nest.flatten(args)
]
return nest.pack_sequence_as(args, function_inputs)
@@ -852,7 +870,8 @@ def func_graph_from_py_func(name,
args,
kwargs,
signature=None,
- func_graph=None):
+ func_graph=None,
+ experimental_autograph=False):
"""Returns a `FuncGraph` generated from `python_func`.
Args:
@@ -869,6 +888,8 @@ def func_graph_from_py_func(name,
inputs.
func_graph: Optional. An instance of FuncGraph. If provided, we will use
this graph else a new one is built and returned.
+ experimental_autograph: whether to use autograph to compile `python_func`.
+ See https://www.tensorflow.org/guide/autograph for more information.
Returns:
A FuncGraph.
@@ -883,12 +904,12 @@ def func_graph_from_py_func(name,
with func_graph.as_default(), AutomaticControlDependencies() as a:
variable_scope.get_variable_scope().set_use_resource(True)
- if signature is None:
- func_args = _get_defun_inputs_from_args(args)
- func_kwargs = _get_defun_inputs_from_args(kwargs)
- else:
- func_args = _get_defun_inputs_from_signature(signature)
- func_kwargs = {}
+ if signature is not None:
+ args = signature
+ kwargs = {}
+
+ func_args = _get_defun_inputs_from_args(args)
+ func_kwargs = _get_defun_inputs_from_args(kwargs)
# Note: `nest.flatten` sorts by keys, as does `_deterministic_dict_values`.
# Variables to help check whether mutation happens in calling the function
@@ -914,7 +935,17 @@ def func_graph_from_py_func(name,
this_tape = tape.push_new_tape()
try:
- func_outputs = python_func(*func_args, **func_kwargs)
+ if experimental_autograph:
+ func_outputs = autograph.converted_call(
+ python_func,
+ autograph.ConversionOptions(
+ verbose=True,
+ recursive=True,
+ force_conversion=False,
+ strip_decorators=(defun,),
+ arg_types={}), *func_args, **func_kwargs)
+ else:
+ func_outputs = python_func(*func_args, **func_kwargs)
# invariant: `func_outputs` contains only Tensors and `None`s.
func_outputs = nest.map_structure(convert, func_outputs)
@@ -986,52 +1017,8 @@ def func_graph_from_py_func(name,
return func_graph
-_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
-
-
-def _encode_arg(arg):
- """A canonical representation for this argument, for use in a cache key."""
-
- # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
- # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
- # are used for both performance reasons, as much TensorFlow code specializes
- # on known shapes to produce slimmer graphs, and correctness, as some
- # high-level APIs require shapes to be fully-known.
- #
- # TODO(akshayka): Add support for sparse tensors.
- #
- # pylint: disable=protected-access
- if isinstance(arg, ops.Tensor):
- return _TensorType(arg.dtype, arg._shape_tuple())
- elif isinstance(arg, ops.IndexedSlices):
- if arg.dense_shape is not None:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
- ])
- else:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- ])
- # pylint: enable=protected-access
- elif isinstance(arg, (list, tuple)):
- return tuple([_encode_arg(elem) for elem in arg])
- elif isinstance(arg, dict):
- return tuple(
- (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
- else:
- try:
- # If possible, keep only a weak reference to Python objects. Weak
- # references hash to the same value as the original object.
- # TODO(allenl): Clean up dead functions and their cache keys if the cache
- # gets large. Right now creating objects with a defunned method, calling
- # the method, and losing a reference to the object in a loop will leak
- # memory here.
- return weakref.ref(arg)
- except TypeError:
- return arg
+pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
+pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
def _deterministic_dict_values(dictionary):
@@ -1054,7 +1041,8 @@ class PolymorphicFunction(object):
python_function,
name,
input_signature=None,
- attributes=None):
+ attributes=None,
+ experimental_autograph=False):
"""Initializes a polymorphic function.
Args:
@@ -1064,7 +1052,10 @@ class PolymorphicFunction(object):
specifying the input signature of this function. If `None`, a separate
function is instantiated for each inferred input signature.
attributes: dict, extra keyword arguments that will be added as attribute
- of the function.
+ of the function.
+ experimental_autograph: whether to use autograph to compile
+ `python_function`. See https://www.tensorflow.org/guide/autograph for
+ more information.
Raises:
ValueError: if `input_signature` is not None and the `python_function`'s
@@ -1080,6 +1071,7 @@ class PolymorphicFunction(object):
self._args_to_prepend = tuple()
self._kwargs_to_include = {}
self._name = name
+ self._experimental_autograph = experimental_autograph
self._function_cache = collections.OrderedDict()
self._function_attributes = attributes or {}
@@ -1101,6 +1093,8 @@ class PolymorphicFunction(object):
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
}
+ self._default_values = fullargspec.defaults
+ self._default_values_start_index = offset
if input_signature is None:
self._input_signature = None
else:
@@ -1161,30 +1155,29 @@ class PolymorphicFunction(object):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
- cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
else:
del args, kwargs
cache_key = self._flat_input_signature
+ ctx = context.context()
with ops.init_scope():
- init_graph = ops.get_default_graph()
-
# The graph, or whether we're executing eagerly, should be a part of the
# cache key so we don't improperly capture tensors such as variables.
- executing_eagerly = context.executing_eagerly()
- execution_context = executing_eagerly or init_graph
+ executing_eagerly = ctx.executing_eagerly()
+ execution_context = executing_eagerly or ops.get_default_graph()
- default_graph = ops.get_default_graph()
- # Putting the device in the cache key ensures that call-site device
- # annotations are respected.
- device_functions = _get_device_functions(context.context(), default_graph)
-
- # `ops.colocate_with` directives translate into `ops.device` directives when
- # eager execution is enabled.
- colocation_stack = (() if executing_eagerly else
- tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
+ if executing_eagerly:
+ device_functions = (pydev.merge_device(ctx.device_name),)
+ colocation_stack = ()
+ else:
+ default_graph = ops.get_default_graph()
+ # Putting the device in the cache key ensures that call-site device
+ # annotations are respected.
+ device_functions = tuple(default_graph._device_functions_outer_to_inner) # pylint: disable=protected-access
+ colocation_stack = tuple(default_graph._colocation_stack.peek_objs()) # pylint: disable=protected-access
- return cache_key + (execution_context, device_functions, colocation_stack)
+ return (cache_key, execution_context, device_functions, colocation_stack)
def _canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@@ -1209,35 +1202,44 @@ class PolymorphicFunction(object):
"""
args = self._args_to_prepend + args
kwargs = dict(kwargs, **self._kwargs_to_include)
- # Maps from index of arg to its corresponding value, according to `args`
- # and `kwargs`; seeded with the default values for the named args that
- # aren't in `args`.
- arg_indices_to_values = {
- index: default
- for index, default in six.iteritems(self._arg_indices_to_default_values)
- if index >= len(args)
- }
- consumed_args = []
- for arg, value in six.iteritems(kwargs):
- index = self._args_to_indices.get(arg, None)
- if index is not None:
- arg_indices_to_values[index] = value
- consumed_args.append(arg)
- elif self._input_signature is not None:
- raise ValueError("Cannot define a TensorFlow function from a Python "
- "function with keyword arguments when "
- "input_signature is provided.")
- for arg in consumed_args:
- # After this loop, `kwargs` will only contain true keyword arguments, as
- # opposed to named arguments called in a keyword-like fashion.
- kwargs.pop(arg)
- inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if not kwargs:
+ if self._default_values:
+ inputs = args + self._default_values[len(args) -
+ self._default_values_start_index:]
+ else:
+ inputs = args
+ else:
+ # Maps from index of arg to its corresponding value, according to `args`
+ # and `kwargs`; seeded with the default values for the named args that
+ # aren't in `args`.
+ arg_indices_to_values = {
+ index: default for index, default in six.iteritems(
+ self._arg_indices_to_default_values) if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwargs):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwargs` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwargs.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
# Check for NumPy arrays in arguments and convert them to Tensors.
+ # TODO(nareshmodi): Skip ndarray conversion to tensor altogether, perhaps
+ # finding a way to store them directly in the cache key (currently not
+ # possible since ndarrays are not hashable).
need_packing = False
for index, value in enumerate(flat_inputs):
- if isinstance(value, np.ndarray):
+ if type(value) == np.ndarray:
flat_inputs[index] = constant_op.constant(value)
need_packing = True
if need_packing:
@@ -1295,8 +1297,13 @@ class PolymorphicFunction(object):
if graph_function is None:
graph_function = Function(
- func_graph_from_py_func(self._name, self._python_function, args,
- kwargs, self._input_signature),
+ func_graph_from_py_func(
+ self._name,
+ self._python_function,
+ args,
+ kwargs,
+ self._input_signature,
+ experimental_autograph=self._experimental_autograph),
self._function_attributes)
self._function_cache[cache_key] = graph_function
return graph_function, [
@@ -1328,8 +1335,25 @@ def register(func, *args, **kwargs):
"Got type: %s" % type(func))
concrete_func = func.get_concrete_function(*args, **kwargs)
graph = ops.get_default_graph()
- concrete_func._inference_function.add_to_graph(graph) # pylint: disable=protected-access
- # TODO(scottzhu): support concrete_func._backward_graph_function in future.
+
+ # There are two situations for the actual call of a defun:
+ # 1. If none of the input args are resource variables or watch by any tape,
+ # it will run the _inference_function of concrete_func for forward pass, and
+ # the gradient will be generated by standard mechanism.
+ # 2. Otherwise, defun will create two functions, one for forward pass, and the
+ # backward pass will be created via tape.
+ # When registering the function, we put both cases into graph.
+ # pylint: disable=protected-access
+ concrete_func._inference_function.add_to_graph(graph)
+
+ if concrete_func._backward_graph_function is None:
+ concrete_func._construct_backprop_function()
+ forward_function = concrete_func._forward_function
+ backward_function = concrete_func._backward_graph_function._inference_function
+ forward_function.add_to_graph(graph)
+ backward_function.add_to_graph(graph)
+ # pylint: enable=protected-access
+
return concrete_func
@@ -1340,7 +1364,7 @@ def _validate_signature(signature):
"a possibly nested sequence of TensorSpec objects.")
-def defun(func=None, input_signature=None):
+def defun(func=None, input_signature=None, experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
`defun` (short for "define function") trace-compiles a Python function
@@ -1649,6 +1673,10 @@ def defun(func=None, input_signature=None):
function is instantiated for each inferred input signature. If a
signature is specified, every input to `func` must be a `Tensor`, and
`func` cannot accept `**kwargs`.
+ experimental_autograph: Whether `func` should be compiled before
+ constructing the graph. See https://www.tensorflow.org/guide/autograph
+ for more information.
+
Returns:
If `func` is not None, returns a callable that will execute the compiled
@@ -1660,10 +1688,16 @@ def defun(func=None, input_signature=None):
TypeError: If `input_signature` is neither `None` nor a sequence of
`tf.contrib.eager.TensorSpec` objects.
"""
- return defun_with_attributes(func=func, input_signature=input_signature)
+ return defun_with_attributes(
+ func=func,
+ input_signature=input_signature,
+ experimental_autograph=experimental_autograph)
-def defun_with_attributes(func=None, input_signature=None, attributes=None):
+def defun_with_attributes(func=None,
+ input_signature=None,
+ attributes=None,
+ experimental_autograph=False):
"""Compiles a Python function into a callable TensorFlow graph.
This function supports adding extra function attributes. See detailed
@@ -1678,6 +1712,7 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
attributes. Currently only support primitive types as value, and only
whitelisted attribute name is allowed. Unwhitelisted attribute name or
unsupported value will result into ValueError.
+ experimental_autograph: same as defun()'s experimental_autograph.
Returns:
Same as the return value of defun, with attributes added to the function in
@@ -1694,8 +1729,12 @@ def defun_with_attributes(func=None, input_signature=None, attributes=None):
name = "function"
return tf_decorator.make_decorator(
function,
- PolymorphicFunction(function, name, input_signature=input_signature,
- attributes=attributes))
+ PolymorphicFunction(
+ function,
+ name,
+ input_signature=input_signature,
+ attributes=attributes,
+ experimental_autograph=experimental_autograph))
# This code path is for the `foo = tfe.defun(foo, ...)` use case
if func is not None:
@@ -1898,8 +1937,10 @@ class AutomaticControlDependencies(object):
last_op_using_resource_tensor[inp] = op
ops_which_must_run = set([op])
continue
+ found_resource = False
for inp in op.inputs:
if inp.dtype == dtypes_module.resource:
+ found_resource = True
# Deal with switches, finally.
if inp.op.type == "Switch":
self._process_switch(inp.op, ops_which_must_run,
@@ -1914,6 +1955,11 @@ class AutomaticControlDependencies(object):
if inp in merge_for_resource:
merge_for_resource[inp]._add_control_input(op) # pylint: disable=protected-access
last_op_using_resource_tensor[inp] = op
+ if (op.op_def.is_stateful and not found_resource
+ and op._control_flow_context is None): # pylint: disable=protected-access
+ if None in last_op_using_resource_tensor:
+ op._add_control_input(last_op_using_resource_tensor[None]) # pylint: disable=protected-access
+ last_op_using_resource_tensor[None] = op
control_inputs = [c for c in control_inputs
if c._control_flow_context is op._control_flow_context] # pylint: disable=protected-access
op._add_control_inputs(control_inputs) # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 59faf967c5..57e545be69 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -172,6 +172,43 @@ class FunctionTest(test.TestCase):
out = sq_op(t)
self.assertAllEqual(out, math_ops.matmul(t, t).numpy())
+ def testInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(a):
+ return matmul(a, a)
+
+ sq_op = sq.get_concrete_function(
+ tensor_spec.TensorSpec((None, None), dtypes.float32))
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out1 = sq_op(t1)
+ self.assertAllEqual(out1, math_ops.matmul(t1, t1).numpy())
+
+ t2 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ out2 = sq_op(t2)
+ self.assertAllEqual(out2, math_ops.matmul(t2, t2).numpy())
+
+ def testNestedInputSpecGraphFunction(self):
+ matmul = function.defun(math_ops.matmul)
+
+ @function.defun
+ def sq(mats):
+ ((a, b),) = mats
+ return matmul(a, b)
+
+ sq_op = sq.get_concrete_function(
+ [(tensor_spec.TensorSpec((None, None), dtypes.float32),
+ tensor_spec.TensorSpec((None, None), dtypes.float32))])
+ self.assertEqual([None, None], sq_op.output_shapes.as_list())
+
+ t1 = constant_op.constant([[1.0, 2.0], [3.0, 4.0]])
+ t2 = constant_op.constant([[1.4, 2.4], [3.4, 4.4]])
+ out = sq_op(t1, t2) # Flattened structure for inputs to the graph function
+ self.assertAllEqual(out, math_ops.matmul(t1, t2).numpy())
+
def testExecutingStatelessDefunConcurrently(self):
@function.defun
@@ -1237,6 +1274,62 @@ class FunctionTest(test.TestCase):
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
+ def testCacheObjectHashCollisions(self):
+
+ class Foo(object):
+
+ def __hash__(self):
+ return 42
+
+ def func(foo):
+ del foo
+ return
+
+ defined = function.defun(func)
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 1)
+
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 2)
+
+ def testCacheTensorShapeDtypeCollision(self):
+
+ def func(t):
+ return t + t
+
+ defined = function.defun(func)
+ t = constant_op.constant([[1.0]], dtype=dtypes.complex64)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ t = constant_op.constant([1.0], dtype=dtypes.complex128)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ def testCacheTensorUnknownShapesCollision(self):
+
+ def func(t):
+ return t + t
+
+ with context.graph_mode(), self.cached_session():
+ defined = function.defun(func)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=None)
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 1)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 2)
+
+ p = array_ops.placeholder(dtype=dtypes.float32, shape=[None, None])
+ defined(p)
+ self.assertEqual(len(defined._function_cache), 3)
+
+ t = constant_op.constant(1.0, dtype=dtypes.float32)
+ defined(t)
+ self.assertEqual(len(defined._function_cache), 4)
+
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
@@ -1250,20 +1343,20 @@ class FunctionTest(test.TestCase):
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
- return tuple(key[:3] for key in defined._function_cache)
+ return tuple(key[0] for key in defined._function_cache)
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn((0, 1, 20), cache_keys())
+ self.assertIn(('URRR', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2), cache_keys())
+ self.assertIn(('URRR', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3), cache_keys())
+ self.assertIn(('URRR', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
@@ -1669,12 +1762,38 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
+ # two sets of functions, each of them are (inference, forward, backward)
functions = list(graph._functions.values())
- pre_register_matmul_func_name = functions[0].definition.signature.name
- self.assertRegexpMatches(pre_register_matmul_func_name, '.*matmul.*')
- pre_register_add_func_name = functions[1].definition.signature.name
- self.assertRegexpMatches(pre_register_add_func_name, '.*add.*')
+ captured_function_names = [
+ f.definition.signature.name for f in functions
+ ]
+ expected_func_name_regex = [
+ '.*inference.*matmul.*',
+ '.*forward.*matmul.*',
+ '.*inference.*backward.*matmul.*',
+ '.*inference.*add.*',
+ '.*forward.*add.*',
+ '.*inference.*backward.*add.*',
+ ]
+ for i in range(len(functions)):
+ self.assertRegexpMatches(captured_function_names[i],
+ expected_func_name_regex[i])
+
+ # Check the forward and backward function has the correct attributes.
+ self.assertEquals(
+ functions[1].definition.attr['backward_function_name'].s,
+ functions[2].name)
+ self.assertEquals(
+ functions[2].definition.attr['forward_function_name'].s,
+ functions[1].name)
+
+ self.assertEquals(
+ functions[4].definition.attr['backward_function_name'].s,
+ functions[5].name)
+ self.assertEquals(
+ functions[5].definition.attr['forward_function_name'].s,
+ functions[4].name)
sq = defun_matmul(t, t)
double = add(t, t)
@@ -1682,12 +1801,11 @@ class FunctionTest(test.TestCase):
self.assertAllEqual(double.eval().reshape(-1), [2, 4, 6, 8])
# Make sure the pre registered function is used, and no other function
# is added.
- self.assertEqual(len(graph._functions), 2)
+ self.assertEqual(len(graph._functions), 6)
functions = list(graph._functions.values())
- called_func_name = functions[0].definition.signature.name
- self.assertEqual(pre_register_matmul_func_name, called_func_name)
- called_func_name = functions[1].definition.signature.name
- self.assertEqual(pre_register_add_func_name, called_func_name)
+ for i in range(len(functions)):
+ self.assertEquals(captured_function_names[i],
+ functions[i].definition.signature.name)
def testRegisterFunctionWithInputSignature(self):
def matmul(x, y):
@@ -1705,7 +1823,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
# Test input param shape mismatch
t2 = constant_op.constant([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
@@ -1728,7 +1846,7 @@ class FunctionTest(test.TestCase):
graph = ops.get_default_graph()
# Only one function is registered since the input param are in same type
# pylint: disable=protected-access
- self.assertEqual(len(graph._functions), 1)
+ self.assertEqual(len(graph._functions), 3)
def testCallingFunctionWithDifferentVariables(self):
@@ -1767,7 +1885,8 @@ class FunctionTest(test.TestCase):
'be Tensors;.*'):
graph_function('Not a Tensor.')
- def testSwapImplementationWithGrapplerPlugin(self):
+ # TODO(scottzhu): Revive the test once the grappler plugin is updated.
+ def disabled_testSwapImplementationWithGrapplerPlugin(self):
rewrites = rewriter_config_pb2.RewriterConfig()
# function_optimizer has to be turn off, otherwise it will delete the
# registered function if it does not get called.
diff --git a/tensorflow/python/eager/imperative_grad.py b/tensorflow/python/eager/imperative_grad.py
index 5f5af4ab6c..5c35860e9d 100644
--- a/tensorflow/python/eager/imperative_grad.py
+++ b/tensorflow/python/eager/imperative_grad.py
@@ -51,11 +51,6 @@ def imperative_grad(
Raises:
RuntimeError: if something goes wrong.
- ValueError: if there is no sequence of differentiable operations connecting
- a source and any target Tensor. This can happen either if the target is
- not computed based on the source, if the tracing was set up incorrectly,
- or if only non-differentiable functions of the source were used in the
- computation of target.
"""
return pywrap_tensorflow.TFE_Py_TapeGradient(
tape._tape, # pylint: disable=protected-access
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index f1b4042ec9..decd635b58 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -224,4 +224,8 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
// The shape is represented as a Python tuple of integers.
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
+// Encodes the object as a tuple that is meant to be used as part of the key
+// for the defun function cache.
+PyObject* TFE_Py_EncodeArg(PyObject*);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 159b1c1218..9789dbadee 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@@ -567,11 +568,8 @@ bool SetOpAttrScalar(
return false;
}
}
- TFE_Op* func = TFE_NewOp(
- ctx, string(func_name.data(), func_name.size()).c_str(), status);
- if (TF_GetCode(status) != TF_OK) return false;
- TFE_OpSetAttrFunction(op, key, func);
- TFE_DeleteOp(func);
+ TF_SetStatus(status, TF_OK, "");
+ TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
} else {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
@@ -1230,8 +1228,9 @@ static PyTypeObject TFE_Py_Tape_Type = {
// GIL, which is always held when any TFE_Py_* methods are called. We should
// revisit this if/when decide to not hold the GIL while manipulating the tape
// stack.
-static tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set = nullptr;
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* GetTapeSet() {
+ thread_local tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>* tape_set{
+ nullptr};
if (tape_set == nullptr) {
tape_set = new tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*>;
}
@@ -1266,27 +1265,10 @@ class SafeTapeSet {
tensorflow::gtl::CompactPointerSet<TFE_Py_Tape*> tape_set_;
};
-// xcode 7 doesn't define thread_local, so for compatibility we implement our
-// own. TODO(apassos) remove once we can deprecate xcode 7.
-#ifndef __APPLE__
bool* ThreadTapeIsStopped() {
thread_local bool thread_tape_is_stopped{false};
return &thread_tape_is_stopped;
}
-#else
-static std::unordered_map<std::thread::id, bool>* tape_is_stopped = nullptr;
-bool* ThreadTapeIsStopped() {
- if (tape_is_stopped == nullptr) {
- tape_is_stopped = new std::unordered_map<std::thread::id, bool>;
- }
- auto it = tape_is_stopped->find(std::this_thread::get_id());
- if (it != tape_is_stopped->end()) {
- return &(it->second);
- }
- return &(tape_is_stopped->emplace(std::this_thread::get_id(), false)
- .first->second);
-}
-#endif
void TFE_Py_TapeSetStopOnThread() { *ThreadTapeIsStopped() = true; }
@@ -1569,9 +1551,8 @@ void TapeSetRecordOperation(
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- auto* function = backward_function_getter();
tape->tape->RecordOperation(op_type_str, output_info, input_ids,
- input_dtypes, function,
+ input_dtypes, backward_function_getter,
backward_function_killer);
}
}
@@ -2750,3 +2731,218 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
return RecordGradient(op_name, inputs, attrs, results, name);
}
+
+namespace {
+const char kTensor[] = "T";
+const char kIndexedSlices[] = "I";
+const char kList[] = "L";
+const char kTuple[] = "U";
+const char kDict[] = "D";
+const char kRaw[] = "R";
+const char kShape[] = "s";
+const char kDType[] = "d";
+const char kNone[] = "n";
+
+struct EncodeResult {
+ string str;
+ std::vector<PyObject*> objects;
+
+ PyObject* ToPyTuple() {
+ PyObject* result = PyTuple_New(2);
+
+ PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));
+
+ if (objects.empty()) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(result, 1, Py_None);
+ } else {
+ PyObject* objects_tuple = PyTuple_New(objects.size());
+
+ for (int i = 0; i < objects.size(); i++) {
+ PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
+ }
+
+ PyTuple_SET_ITEM(result, 1, objects_tuple);
+ }
+
+ return result;
+ }
+};
+
+tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
+ if (EagerTensor_CheckExact(arg)) {
+ TFE_TensorHandle* t = EagerTensor_Handle(arg);
+ tensorflow::TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
+
+ absl::StrAppend(&result->str, kDType, t->handle->dtype);
+
+ absl::StrAppend(&result->str, kShape);
+ for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
+ absl::StrAppend(&result->str, dim_size);
+ }
+
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_object(
+ PyObject_GetAttrString(arg, "dtype"));
+
+ if (dtype_object == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have dtype() attr.");
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_enum(
+ PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
+
+ if (dtype_enum == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor's dtype object doesn't have _type_enum() attr.");
+ }
+
+ tensorflow::DataType dtype =
+ static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
+
+ absl::StrAppend(&result->str, kDType, dtype);
+ static char _shape_tuple[] = "_shape_tuple";
+ tensorflow::Safe_PyObjectPtr shape_tuple(
+ PyObject_CallMethod(arg, _shape_tuple, nullptr));
+
+ if (shape_tuple == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have _shape_tuple() method.");
+ }
+
+ if (shape_tuple.get() == Py_None) {
+ // Unknown shape, encode that directly.
+ absl::StrAppend(&result->str, kNone);
+ return tensorflow::Status::OK();
+ }
+
+ absl::StrAppend(&result->str, kShape);
+ tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
+ shape_tuple.get(), "shape_tuple didn't return a sequence"));
+
+ int len = PySequence_Fast_GET_SIZE(shape_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, kNone);
+ } else {
+ absl::StrAppend(&result->str, MakeInt(item));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
+
+// This function doesn't set the type of sequence before
+tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
+ EncodeResult* result) {
+ tensorflow::Safe_PyObjectPtr arg_seq(
+ PySequence_Fast(arg, "unable to create seq from list/tuple"));
+
+ absl::StrAppend(&result->str, type);
+ int len = PySequence_Fast_GET_SIZE(arg_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, kNone);
+ } else {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
+ if (tensorflow::swig::IsTensor(arg)) {
+ absl::StrAppend(&result->str, kTensor);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result));
+ } else if (tensorflow::swig::IsIndexedSlices(arg)) {
+ absl::StrAppend(&result->str, kIndexedSlices);
+ tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
+ if (values == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a values attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result));
+
+ tensorflow::Safe_PyObjectPtr indices(
+ PyObject_GetAttrString(arg, "indices"));
+ if (indices == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a indices attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result));
+
+ tensorflow::Safe_PyObjectPtr dense_shape(
+ PyObject_GetAttrString(arg, "dense_shape"));
+ if (dense_shape == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a dense_shape attr");
+ }
+ if (dense_shape.get() != Py_None) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result));
+ }
+ } else if (PyList_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, result));
+ } else if (PyTuple_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, result));
+ } else if (PyDict_Check(arg)) {
+ tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
+ if (PyList_Sort(keys.get()) == -1) {
+ return tensorflow::errors::Internal("Unable to sort keys");
+ }
+
+ absl::StrAppend(&result->str, kDict);
+ int len = PyList_Size(keys.get());
+
+ for (int i = 0; i < len; i++) {
+ PyObject* key = PyList_GetItem(keys.get(), i);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result));
+ PyObject* value = PyDict_GetItem(arg, key);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result));
+ }
+ } else {
+ PyObject* object = PyWeakref_NewRef(arg, nullptr);
+
+ if (object == nullptr) {
+ PyErr_Clear();
+
+ object = arg;
+ Py_INCREF(object);
+ }
+
+ absl::StrAppend(&result->str, kRaw);
+ result->objects.push_back(object);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
+// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
+// are used for both performance reasons, as much TensorFlow code specializes
+// on known shapes to produce slimmer graphs, and correctness, as some
+// high-level APIs require shapes to be fully-known.
+//
+// TODO(nareshmodi): Add support for sparse tensors.
+PyObject* TFE_Py_EncodeArg(PyObject* arg) {
+ EncodeResult result;
+ const auto status = TFE_Py_EncodeArgHelper(arg, &result);
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ return result.ToPyTuple();
+}
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 7f2349954d..1c4c5951df 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -281,6 +281,7 @@ py_library(
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
"//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
@@ -303,6 +304,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
@@ -342,6 +344,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 1c0c4581c0..a6c2aaa7d9 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -24,7 +24,10 @@ from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.engine import training
from tensorflow.python.layers import core as core_layers
from tensorflow.python.layers import normalization
from tensorflow.python.ops import init_ops
@@ -45,8 +48,14 @@ def _add_hidden_layer_summary(value, tag):
summary.histogram('%s/activation' % tag, value)
-def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
- dropout, input_layer_partitioner, batch_norm):
+def _dnn_logit_fn_builder(units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager=None):
"""Function builder for a dnn logit_fn.
Args:
@@ -60,6 +69,8 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
coordinate.
input_layer_partitioner: Partitioner for input layer.
batch_norm: Whether to use batch normalization after each hidden layer.
+ shared_state_manager: A SharedEmbeddingStateManager object to hold the
+ shared state for SharedEmbeddingColumn's.
Returns:
A logit_fn (see below).
@@ -85,50 +96,129 @@ def _dnn_logit_fn_builder(units, hidden_units, feature_columns, activation_fn,
A `Tensor` representing the logits, or a list of `Tensor`'s representing
multiple logits in the MultiHead case.
"""
- is_training = mode == model_fn.ModeKeys.TRAIN
- with variable_scope.variable_scope(
- 'input_from_feature_columns',
- values=tuple(six.itervalues(features)),
- partitioner=input_layer_partitioner):
- net = feature_column_lib.input_layer(
- features=features, feature_columns=feature_columns)
+ dnn_model = _DNNModel(
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name='dnn')
+ return dnn_model(features, mode)
+
+ return dnn_logit_fn
+
+
+def _get_previous_name_scope():
+ current_name_scope = ops.get_name_scope()
+ return current_name_scope.rsplit('/', 1)[0] + '/'
+
+
+class _DNNModel(training.Model):
+ """A DNN Model."""
+
+ def __init__(self,
+ units,
+ hidden_units,
+ feature_columns,
+ activation_fn,
+ dropout,
+ input_layer_partitioner,
+ batch_norm,
+ shared_state_manager,
+ name=None,
+ **kwargs):
+ super(_DNNModel, self).__init__(name=name, **kwargs)
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ self._input_layer = feature_column_v2.FeatureLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ shared_state_manager=shared_state_manager)
+ else:
+ self._input_layer = feature_column.InputLayer(
+ feature_columns=feature_columns,
+ name='input_layer',
+ create_scope_now=False)
+
+ self._add_layer(self._input_layer, 'input_layer')
+
+ self._dropout = dropout
+ self._batch_norm = batch_norm
+
+ self._hidden_layers = []
+ self._dropout_layers = []
+ self._batch_norm_layers = []
+ self._hidden_layer_scope_names = []
for layer_id, num_hidden_units in enumerate(hidden_units):
with variable_scope.variable_scope(
- 'hiddenlayer_%d' % layer_id, values=(net,)) as hidden_layer_scope:
- net = core_layers.dense(
- net,
+ 'hiddenlayer_%d' % layer_id) as hidden_layer_scope:
+ hidden_layer = core_layers.Dense(
units=num_hidden_units,
activation=activation_fn,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=hidden_layer_scope)
- if dropout is not None and is_training:
- net = core_layers.dropout(net, rate=dropout, training=True)
- if batch_norm:
- # TODO(hjm): In future, if this becomes popular, we can enable
- # customization of the batch normalization params by accepting a
- # list of `BatchNormalization` instances as `batch_norm`.
- net = normalization.batch_normalization(
- net,
+ name=hidden_layer_scope,
+ _scope=hidden_layer_scope)
+ self._add_layer(hidden_layer, hidden_layer_scope.name)
+ self._hidden_layer_scope_names.append(hidden_layer_scope.name)
+ self._hidden_layers.append(hidden_layer)
+ if self._dropout is not None:
+ dropout_layer = core_layers.Dropout(rate=self._dropout)
+ self._add_layer(dropout_layer, dropout_layer.name)
+ self._dropout_layers.append(dropout_layer)
+ if self._batch_norm:
+ batch_norm_layer = normalization.BatchNormalization(
# The default momentum 0.99 actually crashes on certain
# problem, so here we use 0.999, which is the default of
# tf.contrib.layers.batch_norm.
momentum=0.999,
- training=is_training,
- name='batchnorm_%d' % layer_id)
- _add_hidden_layer_summary(net, hidden_layer_scope.name)
-
- with variable_scope.variable_scope('logits', values=(net,)) as logits_scope:
- logits = core_layers.dense(
- net,
+ trainable=True,
+ name='batchnorm_%d' % layer_id,
+ _scope='batchnorm_%d' % layer_id)
+ self._add_layer(batch_norm_layer, batch_norm_layer.name)
+ self._batch_norm_layers.append(batch_norm_layer)
+
+ with variable_scope.variable_scope('logits') as logits_scope:
+ self._logits_layer = core_layers.Dense(
units=units,
activation=None,
kernel_initializer=init_ops.glorot_uniform_initializer(),
- name=logits_scope)
- _add_hidden_layer_summary(logits, logits_scope.name)
-
- return logits
+ name=logits_scope,
+ _scope=logits_scope)
+ self._add_layer(self._logits_layer, logits_scope.name)
+ self._logits_scope_name = logits_scope.name
+ self._input_layer_partitioner = input_layer_partitioner
- return dnn_logit_fn
+ def call(self, features, mode):
+ is_training = mode == model_fn.ModeKeys.TRAIN
+ # The Keras training.Model adds a name_scope with the name of the model
+ # which modifies the constructed graph. Hence we add another name_scope
+ # here which is the one before the training.Model one was applied.
+ # TODO(rohanj): Remove this in TF 2.0 (b/116728605)
+ with ops.name_scope(name=_get_previous_name_scope()):
+ # TODO(rohanj): Remove dependence on variable scope for partitioning.
+ with variable_scope.variable_scope(
+ 'input_from_feature_columns',
+ partitioner=self._input_layer_partitioner):
+ net = self._input_layer(features)
+ for i in range(len(self._hidden_layers)):
+ net = self._hidden_layers[i](net)
+ if self._dropout is not None and is_training:
+ net = self._dropout_layers[i](net, training=True)
+ if self._batch_norm:
+ net = self._batch_norm_layers[i](net, training=is_training)
+ _add_hidden_layer_summary(net, self._hidden_layer_scope_names[i])
+
+ logits = self._logits_layer(net)
+ _add_hidden_layer_summary(logits, self._logits_scope_name)
+ return logits
+
+ def _add_layer(self, layer, layer_name):
+ # "Magic" required for keras.Model classes to track all the variables in
+ # a list of layers.Layer objects.
+ # TODO(ashankar): Figure out API so user code doesn't have to do this.
+ setattr(self, layer_name, layer)
def _dnn_model_fn(features,
@@ -143,7 +233,8 @@ def _dnn_model_fn(features,
input_layer_partitioner=None,
config=None,
use_tpu=False,
- batch_norm=False):
+ batch_norm=False,
+ shared_state_manager=None):
"""Deep Neural Net model_fn.
Args:
@@ -167,6 +258,8 @@ def _dnn_model_fn(features,
use_tpu: Whether to make a DNN model able to run on TPU. Will make function
return a `_TPUEstimatorSpec` instance and disable variable partitioning.
batch_norm: Whether to use batch normalization after each hidden layer.
+ shared_state_manager: A SharedEmbeddingStateManager object to hold the
+ shared state for SharedEmbeddingColumn's.
Returns:
An `EstimatorSpec` instance.
@@ -202,7 +295,8 @@ def _dnn_model_fn(features,
activation_fn=activation_fn,
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
logits = logit_fn(features=features, mode=mode)
if use_tpu:
@@ -370,6 +464,10 @@ class DNNClassifier(estimator.Estimator):
"""
head = head_lib._binary_logistic_or_multi_class_head( # pylint: disable=protected-access
n_classes, weight_column, label_vocabulary, loss_reduction)
+
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ feature_columns)
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -384,7 +482,8 @@ class DNNClassifier(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNClassifier, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
@@ -532,6 +631,10 @@ class DNNRegressor(estimator.Estimator):
batch_norm: Whether to use batch normalization after each hidden layer.
"""
+ shared_state_manager = None
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+
def _model_fn(features, labels, mode, config):
"""Call the defined shared _dnn_model_fn."""
return _dnn_model_fn(
@@ -539,7 +642,8 @@ class DNNRegressor(estimator.Estimator):
labels=labels,
mode=mode,
head=head_lib._regression_head( # pylint: disable=protected-access
- label_dimension=label_dimension, weight_column=weight_column,
+ label_dimension=label_dimension,
+ weight_column=weight_column,
loss_reduction=loss_reduction),
hidden_units=hidden_units,
feature_columns=tuple(feature_columns or []),
@@ -548,7 +652,8 @@ class DNNRegressor(estimator.Estimator):
dropout=dropout,
input_layer_partitioner=input_layer_partitioner,
config=config,
- batch_norm=batch_norm)
+ batch_norm=batch_norm,
+ shared_state_manager=shared_state_manager)
super(DNNRegressor, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config,
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined.py b/tensorflow/python/estimator/canned/dnn_linear_combined.py
index 9799cf9e98..f712244c8d 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined.py
@@ -27,6 +27,7 @@ from tensorflow.python.estimator.canned import dnn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import optimizers
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import nn
@@ -142,6 +143,9 @@ def _dnn_linear_combined_model_fn(features,
max_partitions=num_ps_replicas,
min_slice_size=64 << 20))
+ shared_state_manager = feature_column_v2.maybe_create_shared_state_manager(
+ list(linear_feature_columns) + list(dnn_feature_columns))
+
# Build DNN Logits.
dnn_parent_scope = 'dnn'
@@ -169,8 +173,9 @@ def _dnn_linear_combined_model_fn(features,
feature_columns=dnn_feature_columns,
activation_fn=dnn_activation_fn,
dropout=dnn_dropout,
+ batch_norm=batch_norm,
input_layer_partitioner=input_layer_partitioner,
- batch_norm=batch_norm)
+ shared_state_manager=shared_state_manager)
dnn_logits = dnn_logit_fn(features=features, mode=mode)
linear_parent_scope = 'linear'
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index d16318659b..ae968e717a 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -35,6 +36,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
@@ -119,7 +121,16 @@ class LinearOnlyRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorEvaluationTest(
@@ -128,7 +139,16 @@ class LinearOnlyRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorPredictTest(
@@ -137,7 +157,16 @@ class LinearOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorIntegrationTest(
@@ -146,7 +175,16 @@ class LinearOnlyRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorTrainingTest(
@@ -155,7 +193,16 @@ class LinearOnlyRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
def _linear_classifier_fn(feature_columns,
@@ -185,7 +232,18 @@ class LinearOnlyClassifierTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierClassesEvaluationTest(
@@ -194,7 +252,18 @@ class LinearOnlyClassifierClassesEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierClassesEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierPredictTest(
@@ -203,7 +272,18 @@ class LinearOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierIntegrationTest(
@@ -212,9 +292,21 @@ class LinearOnlyClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
def setUp(self):
@@ -225,13 +317,15 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
@@ -257,14 +351,14 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -293,9 +387,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -326,9 +421,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -376,7 +472,8 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
# A function to mimic dnn-classifier init reuse same tests.
@@ -407,7 +504,16 @@ class DNNOnlyClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierPredictTest(
@@ -416,7 +522,16 @@ class DNNOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierPredictV2Test(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierTrainTest(
@@ -425,7 +540,16 @@ class DNNOnlyClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
# A function to mimic dnn-regressor init reuse same tests.
@@ -454,7 +578,16 @@ class DNNOnlyRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorEvaluateV2Test(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorPredictTest(
@@ -463,7 +596,16 @@ class DNNOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorPredictV2Test(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorTrainTest(
@@ -472,9 +614,19 @@ class DNNOnlyRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+class DNNOnlyRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
+
+
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -488,13 +640,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=linear_feature_columns,
@@ -520,14 +673,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -559,9 +712,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -593,9 +747,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -647,9 +802,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedTests(test.TestCase):
def setUp(self):
@@ -681,9 +838,9 @@ class DNNLinearCombinedTests(test.TestCase):
return optimizer_mock
- def test_train_op_calls_both_dnn_and_linear(self):
+ def test_train_op_calls_both_dnn_and_linear(self, fc_impl):
opt = gradient_descent.GradientDescentOptimizer(1.)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[0.], [1.]])},
y=np.array([[0.], [1.]]),
@@ -708,7 +865,7 @@ class DNNLinearCombinedTests(test.TestCase):
checkpoint_utils.load_variable(
self._model_dir, 'dnn_called'))
- def test_dnn_and_linear_logits_are_added(self):
+ def test_dnn_and_linear_logits_are_added(self, fc_impl):
with ops.Graph().as_default():
variables_lib.Variable([[1.0]], name='linear/linear_model/x/weights')
variables_lib.Variable([2.0], name='linear/linear_model/bias_weights')
@@ -719,7 +876,7 @@ class DNNLinearCombinedTests(test.TestCase):
variables_lib.Variable(1, name='global_step', dtype=dtypes.int64)
linear_testing_utils.save_variables_to_ckpt(self._model_dir)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=[x_column],
dnn_hidden_units=[1],
@@ -737,6 +894,7 @@ class DNNLinearCombinedTests(test.TestCase):
next(est.predict(input_fn=input_fn)))
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedWarmStartingTest(test.TestCase):
def setUp(self):
@@ -758,11 +916,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._ckpt_and_vocab_dir)
- def test_classifier_basic_warm_starting(self):
+ def test_classifier_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedClassifier default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -798,11 +956,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_classifier.get_variable_value(variable_name),
warm_started_dnn_lc_classifier.get_variable_value(variable_name))
- def test_regressor_basic_warm_starting(self):
+ def test_regressor_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedRegressor default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -836,11 +994,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_regressor.get_variable_value(variable_name),
warm_started_dnn_lc_regressor.get_variable_value(variable_name))
- def test_warm_starting_selective_variables(self):
+ def test_warm_starting_selective_variables(self, fc_impl):
"""Tests selecting variables to warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
diff --git a/tensorflow/python/estimator/canned/dnn_test.py b/tensorflow/python/estimator/canned/dnn_test.py
index fc90b7c35e..756696cea0 100644
--- a/tensorflow/python/estimator/canned/dnn_test.py
+++ b/tensorflow/python/estimator/canned/dnn_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -33,6 +34,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import data_flow_ops
@@ -62,15 +64,32 @@ class DNNModelFnTest(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNModelFnTest.__init__(self, dnn._dnn_model_fn)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column)
+
+
+class DNNModelFnV2Test(dnn_testing_utils.BaseDNNModelFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNModelFnTest.__init__(
+ self, dnn._dnn_model_fn, fc_impl=feature_column_v2)
class DNNLogitFnTest(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNLogitFnTest.__init__(self,
- dnn._dnn_logit_fn_builder)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column)
+
+
+class DNNLogitFnV2Test(dnn_testing_utils.BaseDNNLogitFnTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNLogitFnTest.__init__(
+ self, dnn._dnn_logit_fn_builder, fc_impl=feature_column_v2)
class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
@@ -78,8 +97,17 @@ class DNNWarmStartingTest(dnn_testing_utils.BaseDNNWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- dnn_testing_utils.BaseDNNWarmStartingTest.__init__(self, _dnn_classifier_fn,
- _dnn_regressor_fn)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNWarmStartingV2Test(dnn_testing_utils.BaseDNNWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNWarmStartingTest.__init__(
+ self, _dnn_classifier_fn, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNClassifierEvaluateTest(
@@ -88,7 +116,16 @@ class DNNClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierPredictTest(
@@ -97,7 +134,16 @@ class DNNClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierPredictV2Test(dnn_testing_utils.BaseDNNClassifierPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNClassifierTrainTest(
@@ -106,7 +152,16 @@ class DNNClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
def _dnn_regressor_fn(*args, **kwargs):
@@ -119,7 +174,16 @@ class DNNRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorEvaluateV2Test(dnn_testing_utils.BaseDNNRegressorEvaluateTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorPredictTest(
@@ -128,7 +192,16 @@ class DNNRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorPredictV2Test(dnn_testing_utils.BaseDNNRegressorPredictTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNRegressorTrainTest(
@@ -137,7 +210,16 @@ class DNNRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
def _queue_parsed_features(feature_map):
@@ -156,7 +238,8 @@ def _queue_parsed_features(feature_map):
return {keys[i]: dequeued_tensors[i] for i in range(len(dequeued_tensors))}
-class DNNRegressorIntegrationTest(test.TestCase):
+@parameterized.parameters((feature_column,), (feature_column_v2,))
+class DNNRegressorIntegrationTest(test.TestCase, parameterized.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -166,11 +249,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNRegressor(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -194,14 +277,14 @@ class DNNRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -230,9 +313,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -263,9 +347,10 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -313,9 +398,11 @@ class DNNRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -329,11 +416,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
- feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
+ feature_columns = [fc_impl.numeric_column('x', shape=(input_dimension,))]
+
est = dnn.DNNClassifier(
hidden_units=(2, 2),
feature_columns=feature_columns,
@@ -357,14 +443,14 @@ class DNNClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -396,9 +482,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -430,9 +517,10 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -484,7 +572,8 @@ class DNNClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/dnn_testing_utils.py b/tensorflow/python/estimator/canned/dnn_testing_utils.py
index 11f1e93630..cd66d0a3bd 100644
--- a/tensorflow/python/estimator/canned/dnn_testing_utils.py
+++ b/tensorflow/python/estimator/canned/dnn_testing_utils.py
@@ -104,6 +104,7 @@ def create_checkpoint(weights_and_biases,
weights_and_biases: Iterable of tuples of weight and bias values.
global_step: Initial global step to save in checkpoint.
model_dir: Directory into which checkpoint is saved.
+ batch_norm_vars: Variables used for batch normalization.
"""
weights, biases = zip(*weights_and_biases)
if batch_norm_vars:
@@ -244,8 +245,9 @@ def mock_optimizer(testcase, hidden_units, expected_loss=None):
class BaseDNNModelFnTest(object):
"""Tests that _dnn_model_fn passes expected logits to mock head."""
- def __init__(self, dnn_model_fn):
+ def __init__(self, dnn_model_fn, fc_impl=feature_column):
self._dnn_model_fn = dnn_model_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -272,7 +274,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -462,8 +464,8 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
optimizer=mock_optimizer(self, hidden_units))
with monitored_session.MonitoredTrainingSession(
@@ -499,7 +501,7 @@ class BaseDNNModelFnTest(object):
head=head,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
optimizer=mock_optimizer(self, hidden_units))
@@ -508,8 +510,9 @@ class BaseDNNModelFnTest(object):
class BaseDNNLogitFnTest(object):
"""Tests correctness of logits calculated from _dnn_logit_fn_builder."""
- def __init__(self, dnn_logit_fn_builder):
+ def __init__(self, dnn_logit_fn_builder, fc_impl=feature_column):
self._dnn_logit_fn_builder = dnn_logit_fn_builder
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -541,7 +544,7 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column(
+ self._fc_impl.numeric_column(
'age', shape=np.array(inputs).shape[1:])
],
activation_fn=nn.relu,
@@ -786,8 +789,8 @@ class BaseDNNLogitFnTest(object):
units=logits_dimension,
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age'),
- feature_column.numeric_column('height')
+ self._fc_impl.numeric_column('age'),
+ self._fc_impl.numeric_column('height')
],
activation_fn=nn.relu,
dropout=None,
@@ -806,9 +809,13 @@ class BaseDNNLogitFnTest(object):
class BaseDNNWarmStartingTest(object):
- def __init__(self, _dnn_classifier_fn, _dnn_regressor_fn):
+ def __init__(self,
+ _dnn_classifier_fn,
+ _dnn_regressor_fn,
+ fc_impl=feature_column):
self._dnn_classifier_fn = _dnn_classifier_fn
self._dnn_regressor_fn = _dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -843,8 +850,8 @@ class BaseDNNWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of DNNClassifier default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -875,8 +882,8 @@ class BaseDNNWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of DNNRegressor default warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -905,8 +912,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -958,8 +965,8 @@ class BaseDNNWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list)),
@@ -985,8 +992,8 @@ class BaseDNNWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list)),
@@ -1051,8 +1058,8 @@ class BaseDNNWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- locality = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ locality = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'locality', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -1068,8 +1075,8 @@ class BaseDNNWarmStartingTest(object):
# Create a second DNNClassifier, warm-started from the first. Use a
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ city = self._fc_impl.embedding_column(
+ self._fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
warm_started_dnn_classifier = self._dnn_classifier_fn(
@@ -1101,8 +1108,9 @@ class BaseDNNWarmStartingTest(object):
class BaseDNNClassifierEvaluateTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1121,7 +1129,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1161,7 +1169,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
model_dir=self._model_dir)
def _input_fn():
@@ -1192,7 +1200,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
# batch_size = 2, one false label, and one true.
@@ -1218,7 +1226,7 @@ class BaseDNNClassifierEvaluateTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1238,8 +1246,9 @@ class BaseDNNClassifierEvaluateTest(object):
class BaseDNNRegressorEvaluateTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1259,7 +1268,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age')],
+ feature_columns=[self._fc_impl.numeric_column('age')],
model_dir=self._model_dir)
def _input_fn():
return {'age': [[10.]]}, [[1.]]
@@ -1289,7 +1298,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
model_dir=self._model_dir)
def _input_fn():
@@ -1320,7 +1329,7 @@ class BaseDNNRegressorEvaluateTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=[feature_column.numeric_column('age', shape=[2])],
+ feature_columns=[self._fc_impl.numeric_column('age', shape=[2])],
label_dimension=label_dimension,
weight_column='w',
model_dir=self._model_dir)
@@ -1339,8 +1348,9 @@ class BaseDNNRegressorEvaluateTest(object):
class BaseDNNClassifierPredictTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1361,7 +1371,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
label_vocabulary=label_vocabulary,
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1405,7 +1415,7 @@ class BaseDNNClassifierPredictTest(object):
dnn_classifier = self._dnn_classifier_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_vocabulary=label_vocabulary,
n_classes=3,
model_dir=self._model_dir)
@@ -1453,8 +1463,9 @@ class BaseDNNClassifierPredictTest(object):
class BaseDNNRegressorPredictTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1475,7 +1486,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x'),),
+ feature_columns=(self._fc_impl.numeric_column('x'),),
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[10.]])}, batch_size=1, shuffle=False)
@@ -1497,7 +1508,7 @@ class BaseDNNRegressorPredictTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=(2, 2),
- feature_columns=(feature_column.numeric_column('x', shape=(2,)),),
+ feature_columns=(self._fc_impl.numeric_column('x', shape=(2,)),),
label_dimension=3,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -1594,8 +1605,9 @@ def _assert_simple_summary(testcase, expected_values, actual_summary):
class BaseDNNClassifierTrainTest(object):
- def __init__(self, dnn_classifier_fn):
+ def __init__(self, dnn_classifier_fn, fc_impl=feature_column):
self._dnn_classifier_fn = dnn_classifier_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1609,7 +1621,7 @@ class BaseDNNClassifierTrainTest(object):
hidden_units = (2, 2)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1625,7 +1637,7 @@ class BaseDNNClassifierTrainTest(object):
n_classes = 3
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1643,7 +1655,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1682,7 +1694,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1728,7 +1740,7 @@ class BaseDNNClassifierTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_classifier = self._dnn_classifier_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1759,7 +1771,7 @@ class BaseDNNClassifierTrainTest(object):
dnn_classifier = self._dnn_classifier_fn(
n_classes=n_classes,
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1793,8 +1805,9 @@ class BaseDNNClassifierTrainTest(object):
class BaseDNNRegressorTrainTest(object):
- def __init__(self, dnn_regressor_fn):
+ def __init__(self, dnn_regressor_fn, fc_impl=feature_column):
self._dnn_regressor_fn = dnn_regressor_fn
+ self._fc_impl = fc_impl
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1808,7 +1821,7 @@ class BaseDNNRegressorTrainTest(object):
hidden_units = (2, 2)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, then validate final checkpoint.
@@ -1824,7 +1837,7 @@ class BaseDNNRegressorTrainTest(object):
opt = mock_optimizer(self, hidden_units=hidden_units)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1864,7 +1877,7 @@ class BaseDNNRegressorTrainTest(object):
self, hidden_units=hidden_units, expected_loss=expected_loss)
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
- feature_columns=(feature_column.numeric_column('age'),),
+ feature_columns=(self._fc_impl.numeric_column('age'),),
optimizer=opt,
model_dir=self._model_dir)
self.assertEqual(0, opt.minimize.call_count)
@@ -1917,7 +1930,8 @@ class BaseDNNRegressorTrainTest(object):
dnn_regressor = self._dnn_regressor_fn(
hidden_units=hidden_units,
feature_columns=[
- feature_column.numeric_column('age', shape=[input_dimension])],
+ self._fc_impl.numeric_column('age', shape=[input_dimension])
+ ],
label_dimension=label_dimension,
optimizer=opt,
model_dir=self._model_dir)
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 115dd18518..8b96284bd3 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -25,14 +25,18 @@ import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variable_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import estimator_export
@@ -46,23 +50,42 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-def _compute_fraction_of_zero(cols_to_vars):
- """Given a linear cols_to_vars dict, compute the fraction of zero weights.
+def _get_expanded_variable_list(var_list):
+ """Given a list of variables, expands them if they are partitioned.
Args:
- cols_to_vars: A dictionary mapping FeatureColumns to lists of tf.Variables
- like one returned from feature_column_lib.linear_model.
+ var_list: A list of variables.
+
+ Returns:
+ A list of variables where each partitioned variable is expanded to its
+ components.
+ """
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variable_ops.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
+
+
+# TODO(rohanj): Consider making this a public utility method.
+def _compute_fraction_of_zero(variables):
+ """Given a linear variables list, compute the fraction of zero weights.
+
+ Args:
+ variables: A list or list of list of variables
Returns:
The fraction of zeros (sparsity) in the linear model.
"""
all_weight_vars = []
- for var_or_var_list in cols_to_vars.values():
+ for var_or_var_list in variables:
+ var_list = nest.flatten(var_or_var_list)
# Skip empty-lists associated with columns that created no Variables.
- if var_or_var_list:
- all_weight_vars += [
- array_ops.reshape(var, [-1]) for var in var_or_var_list
- ]
+ if var_list:
+ all_weight_vars += [array_ops.reshape(var, [-1]) for var in var_list]
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
@@ -92,14 +115,36 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
Returns:
A `Tensor` representing the logits.
"""
- cols_to_vars = {}
- logits = feature_column_lib.linear_model(
- features=features,
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- cols_to_vars=cols_to_vars)
- bias = cols_to_vars.pop('bias')
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ linear_model = feature_column_v2.LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ shared_state_manager=shared_state_manager)
+ logits = linear_model(features)
+ bias = linear_model.bias_variable
+
+ # We'd like to get all the non-bias variables associated with this
+ # LinearModel. This includes the shared embedding variables as well.
+ variables = linear_model.variables
+ variables.remove(bias)
+ variables.extend(shared_state_manager.variables)
+
+ # Expand (potential) Partitioned variables
+ bias = _get_expanded_variable_list([bias])
+ variables = _get_expanded_variable_list(variables)
+ else:
+ linear_model = feature_column._LinearModel( # pylint: disable=protected-access
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ name='linear_model')
+ logits = linear_model(features)
+ cols_to_vars = linear_model.cols_to_vars()
+ bias = cols_to_vars.pop('bias')
+ variables = cols_to_vars.values()
+
if units > 1:
summary.histogram('bias', bias)
else:
@@ -107,7 +152,7 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
# so we should provide a scalar summary.
summary.scalar('bias', bias[0][0])
summary.scalar('fraction_of_zero_weights',
- _compute_fraction_of_zero(cols_to_vars))
+ _compute_fraction_of_zero(variables))
return logits
return linear_logit_fn
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index 59a230417d..3e6da5de22 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test
@@ -40,7 +42,16 @@ class LinearRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorEvaluationTest(
@@ -49,7 +60,16 @@ class LinearRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorPredictTest(
@@ -58,7 +78,16 @@ class LinearRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorIntegrationTest(
@@ -67,7 +96,16 @@ class LinearRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorTrainingTest(
@@ -76,19 +114,37 @@ class LinearRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
-# Tests for Linear Classifier.
+class LinearRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
+
+# Tests for Linear Classifier.
class LinearClassifierTrainingTest(
linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierEvaluationTest(
@@ -97,7 +153,18 @@ class LinearClassifierEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierPredictTest(
@@ -106,7 +173,18 @@ class LinearClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierIntegrationTest(
@@ -115,7 +193,18 @@ class LinearClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
# Tests for Linear logit_fn.
@@ -124,7 +213,17 @@ class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- linear_testing_utils.BaseLinearLogitFnTest.__init__(self)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column)
+
+
+class LinearLogitFnV2Test(linear_testing_utils.BaseLinearLogitFnTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column_v2)
# Tests for warm-starting with Linear logit_fn.
@@ -134,7 +233,22 @@ class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearWarmStartingTest.__init__(
- self, _linear_classifier_fn, _linear_regressor_fn)
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column)
+
+
+class LinearWarmStartingV2Test(linear_testing_utils.BaseLinearWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearWarmStartingTest.__init__(
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column_v2)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 65cdd50061..827352a70b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -37,7 +37,8 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -152,8 +153,9 @@ class CheckPartitionerVarHook(session_run_hook.SessionRunHook):
class BaseLinearRegressorPartitionerTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -173,7 +175,7 @@ class BaseLinearRegressorPartitionerTest(object):
return [partitions, 1] if shape[0] == x_dim else [1]
regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.categorical_column_with_hash_bucket(
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
'language', hash_bucket_size=x_dim),),
partitioner=_partitioner,
model_dir=self._model_dir)
@@ -209,9 +211,8 @@ class BaseLinearRegressorPartitionerTest(object):
'_get_replica_device_setter',
return_value=lambda _: '/cpu:0'):
linear_regressor = self._linear_regressor_fn(
- feature_columns=(
- feature_column_lib.categorical_column_with_hash_bucket(
- 'language', hash_bucket_size=x_dim),),
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
+ 'language', hash_bucket_size=x_dim),),
config=FakeRunConfig(),
model_dir=self._model_dir)
@@ -232,8 +233,9 @@ class BaseLinearRegressorPartitionerTest(object):
# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
class BaseLinearRegressorEvaluationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -252,7 +254,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
@@ -276,7 +278,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
@@ -308,7 +310,7 @@ class BaseLinearRegressorEvaluationTest(object):
return features, labels
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='weights',
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)
@@ -336,8 +338,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column(
- 'age', shape=(x_dim,)),),
+ feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),
label_dimension=label_dim,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -374,8 +375,8 @@ class BaseLinearRegressorEvaluationTest(object):
batch_size = 2
feature_columns = [
- feature_column_lib.numeric_column('age'),
- feature_column_lib.numeric_column('height')
+ self._fc_lib.numeric_column('age'),
+ self._fc_lib.numeric_column('height')
]
input_fn = numpy_io.numpy_input_fn(
x={'age': np.array([20, 40]),
@@ -402,8 +403,9 @@ class BaseLinearRegressorEvaluationTest(object):
class BaseLinearRegressorPredictTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -422,7 +424,7 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x'),),
+ feature_columns=(self._fc_lib.numeric_column('x'),),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -441,7 +443,7 @@ class BaseLinearRegressorPredictTest(object):
batch_size = 2
label_dimension = 3
x_dim = 4
- feature_columns = (feature_column_lib.numeric_column('x', shape=(x_dim,)),)
+ feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)
with ops.Graph().as_default():
variables_lib.Variable( # shape=[x_dim, label_dimension]
[[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],
@@ -479,8 +481,8 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x0'),
- feature_column_lib.numeric_column('x1')),
+ feature_columns=(self._fc_lib.numeric_column('x0'),
+ self._fc_lib.numeric_column('x1')),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -515,9 +517,8 @@ class BaseLinearRegressorPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -561,8 +562,9 @@ class BaseLinearRegressorPredictTest(object):
class BaseLinearRegressorIntegrationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -575,7 +577,7 @@ class BaseLinearRegressorIntegrationTest(object):
def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
input_dimension, label_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
@@ -597,7 +599,7 @@ class BaseLinearRegressorIntegrationTest(object):
self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -729,8 +731,9 @@ class BaseLinearRegressorIntegrationTest(object):
class BaseLinearRegressorTrainingTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -808,7 +811,7 @@ class BaseLinearRegressorTrainingTest(object):
label = 5.
age = 17
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, and validate final checkpoint.
@@ -820,7 +823,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimLabel(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -840,7 +843,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimWeight(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -867,7 +870,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (0 - 5.)^2 = 25.
mock_optimizer = self._mock_optimizer(expected_loss=25.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -900,7 +903,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (175 - 5)^2 = 28900
mock_optimizer = self._mock_optimizer(expected_loss=28900.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -935,7 +938,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004
mock_optimizer = self._mock_optimizer(expected_loss=52004.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -954,8 +957,9 @@ class BaseLinearRegressorTrainingTest(object):
class BaseLinearClassifierTrainingTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1031,7 +1035,7 @@ class BaseLinearClassifierTrainingTest(object):
label = 0
age = 17
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1051,7 +1055,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1078,7 +1082,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1103,7 +1107,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1129,7 +1133,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1166,7 +1170,7 @@ class BaseLinearClassifierTrainingTest(object):
expected_loss=-1 * math.log(1.0/n_classes))
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1229,7 +1233,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1277,7 +1281,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1341,7 +1345,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1368,8 +1372,9 @@ class BaseLinearClassifierTrainingTest(object):
class BaseLinearClassifierEvaluationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1398,7 +1403,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1464,7 +1469,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1540,7 +1545,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1605,8 +1610,9 @@ class BaseLinearClassifierEvaluationTest(object):
class BaseLinearClassifierPredictTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1634,7 +1640,7 @@ class BaseLinearClassifierPredictTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
label_vocabulary=label_vocabulary,
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1730,9 +1736,8 @@ class BaseLinearClassifierPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -1776,8 +1781,9 @@ class BaseLinearClassifierPredictTest(object):
class BaseLinearClassifierIntegrationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1789,7 +1795,7 @@ class BaseLinearClassifierIntegrationTest(object):
def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
predict_input_fn, input_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_classifier_fn(
feature_columns=feature_columns,
@@ -1811,7 +1817,7 @@ class BaseLinearClassifierIntegrationTest(object):
self.assertAllEqual((prediction_length, 1), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -1961,9 +1967,12 @@ class BaseLinearClassifierIntegrationTest(object):
class BaseLinearLogitFnTest(object):
+ def __init__(self, fc_lib=feature_column):
+ self._fc_lib = fc_lib
+
def test_basic_logit_correctness(self):
"""linear_logit_fn simply wraps feature_column_lib.linear_model."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
with ops.Graph().as_default():
logit_fn = linear._linear_logit_fn_builder(units=2, feature_columns=[age])
logits = logit_fn(features={'age': [[23.], [31.]]})
@@ -1983,12 +1992,14 @@ class BaseLinearLogitFnTest(object):
def test_compute_fraction_of_zero(self):
"""Tests the calculation of sparsity."""
- age = feature_column_lib.numeric_column('age')
- occupation = feature_column_lib.categorical_column_with_hash_bucket(
+ if self._fc_lib != feature_column:
+ return
+ age = feature_column.numeric_column('age')
+ occupation = feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=5)
with ops.Graph().as_default():
cols_to_vars = {}
- feature_column_lib.linear_model(
+ feature_column.linear_model(
features={
'age': [[23.], [31.]],
'occupation': [['doctor'], ['engineer']]
@@ -1997,7 +2008,42 @@ class BaseLinearLogitFnTest(object):
units=3,
cols_to_vars=cols_to_vars)
cols_to_vars.pop('bias')
- fraction_zero = linear._compute_fraction_of_zero(cols_to_vars)
+ fraction_zero = linear._compute_fraction_of_zero(cols_to_vars.values())
+ age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ 'linear_model/age')[0]
+ with tf_session.Session() as sess:
+ sess.run([variables_lib.global_variables_initializer()])
+ # Upon initialization, all variables will be zero.
+ self.assertAllClose(1, fraction_zero.eval())
+
+ sess.run(age_var.assign([[2.0, 0.0, -1.0]]))
+ # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets
+ # x 3-dim output) are zero.
+ self.assertAllClose(16. / 18., fraction_zero.eval())
+
+ def test_compute_fraction_of_zero_v2(self):
+ """Tests the calculation of sparsity."""
+ if self._fc_lib != feature_column_v2:
+ return
+
+ age = feature_column_v2.numeric_column('age')
+ occupation = feature_column_v2.categorical_column_with_hash_bucket(
+ 'occupation', hash_bucket_size=5)
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ with ops.Graph().as_default():
+ model = feature_column_v2.LinearModel(
+ feature_columns=[age, occupation],
+ units=3,
+ shared_state_manager=shared_state_manager)
+ features = {
+ 'age': [[23.], [31.]],
+ 'occupation': [['doctor'], ['engineer']]
+ }
+ model(features)
+ variables = model.variables
+ variables.remove(model.bias_variable)
+ variables.extend(shared_state_manager.variables)
+ fraction_zero = linear._compute_fraction_of_zero(variables)
age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
'linear_model/age')[0]
with tf_session.Session() as sess:
@@ -2013,9 +2059,13 @@ class BaseLinearLogitFnTest(object):
class BaseLinearWarmStartingTest(object):
- def __init__(self, _linear_classifier_fn, _linear_regressor_fn):
+ def __init__(self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column):
self._linear_classifier_fn = _linear_classifier_fn
self._linear_regressor_fn = _linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -2039,7 +2089,7 @@ class BaseLinearWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of LinearClassifier default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2066,7 +2116,7 @@ class BaseLinearWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of LinearRegressor default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearRegressor and train to save a checkpoint.
linear_regressor = self._linear_regressor_fn(
@@ -2091,7 +2141,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2128,7 +2178,7 @@ class BaseLinearWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list))
@@ -2152,7 +2202,7 @@ class BaseLinearWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list))
@@ -2205,7 +2255,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- age_in_years = feature_column_lib.numeric_column('age_in_years')
+ age_in_years = self._fc_lib.numeric_column('age_in_years')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2219,7 +2269,7 @@ class BaseLinearWarmStartingTest(object):
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
warm_started_linear_classifier = self._linear_classifier_fn(
- feature_columns=[feature_column_lib.numeric_column('age')],
+ feature_columns=[self._fc_lib.numeric_column('age')],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The 'age' variable correspond to the 'age_in_years' variable in the
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index eec64ad452..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -144,7 +144,7 @@ class Estimator(object):
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models).
- If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
be passed. If the `model_fn`'s signature does not accept
`mode`, the `model_fn` must still be able to handle
`labels=None`.
@@ -468,17 +468,41 @@ class Estimator(object):
with ops.Graph().as_default():
if self._eval_distribution:
+ # We want to create the iterations variable outside the distribution
+ # scope as that is just stored on the host and mainly used to drive
+ # the loop and doesn't need to be a Mirrored/Device variable.
+ training.get_or_create_steps_per_run_variable()
with self._eval_distribution.scope():
return _evaluate()
else:
return _evaluate()
def _convert_eval_steps_to_hooks(self, steps):
+ """Create hooks to run correct number of steps in evaluation.
+
+ Args:
+ steps: number of steps to run during evaluation.
+
+ Raises:
+ ValueError: if steps is less than or equal to zero.
+
+ Returns:
+ List of hooks to be passed to the estimator.
+ """
if steps is None:
return []
if steps <= 0:
raise ValueError('Must specify steps > 0, given: {}'.format(steps))
+
+ # The hooks are declared as private in evaluation.py discourage the use
+ # by other libraries or open source users. This should be the only usage
+ # of the estimator evaluation hooks.
+ if self._eval_distribution:
+ steps_per_run = getattr(self._eval_distribution, 'steps_per_run', 1)
+ if steps_per_run > 1:
+ return [evaluation._MultiStepStopAfterNEvalsHook( # pylint: disable=protected-access
+ num_evals=steps, steps_per_run=steps_per_run)]
return [evaluation._StopAfterNEvalsHook(num_evals=steps)] # pylint: disable=protected-access
def predict(self,
@@ -783,9 +807,9 @@ class Estimator(object):
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
Only one of the modes is used for saving variables to the `SavedModel`
- (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
- @{tf.estimator.ModeKeys#EVAL$EVAL}, then
- @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ (order of preference: `tf.estimator.ModeKeys.TRAIN`,
+ `tf.estimator.ModeKeys.EVAL`, then
+ `tf.estimator.ModeKeys.PREDICT`), such that up to three
`tf.MetaGraphDefs` are saved with a single set of variables in a single
`SavedModel` directory.
@@ -1081,7 +1105,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
@@ -1394,6 +1418,36 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+
+ # Add summary hooks to worker 0 if we are running with a master, to ensure
+ # that summaries are written at correct intervals even with long-running
+ # evaluations.
+ save_summary_steps = self._config.save_summary_steps
+ log_step_count_steps = self._config.log_step_count_steps
+ if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+ (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
+ # Update config values to prevent the default hooks from being created on
+ # the master or other workers.
+ save_summary_steps = 0
+ log_step_count_steps = None
+
+ if (self._config.task_type == run_config.TaskType.WORKER and
+ self._config.task_id == 0):
+ if (self._config.save_summary_steps and
+ self._config.save_summary_steps > 0):
+ worker_hooks.append(
+ training.SummarySaverHook(
+ save_steps=self._config.save_summary_steps,
+ output_dir=self._config.model_dir,
+ scaffold=estimator_spec.scaffold))
+
+ if (self._config.log_step_count_steps and
+ self._config.log_step_count_steps > 0):
+ worker_hooks.append(
+ training.StepCounterHook(
+ every_n_steps=self._config.log_step_count_steps,
+ output_dir=self._config.model_dir))
+
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1403,9 +1457,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
+ save_summaries_steps=save_summary_steps,
config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
@@ -1474,6 +1528,7 @@ class Estimator(object):
self._eval_distribution.__class__.__name__ == 'TPUStrategy')
if is_tpu_strategy:
+ steps_per_run_variable = training.get_or_create_steps_per_run_variable()
def step_fn(ctx, features, labels=None):
"""Runs one step of the eval computation and captures outputs."""
estimator_spec = self._eval_distribution.call_for_each_tower(
@@ -1490,7 +1545,7 @@ class Estimator(object):
# TODO(priyag): Fix eval step hook to account for steps_per_run.
ctx = self._eval_distribution.run_steps_on_dataset(
- step_fn, iterator, iterations=self._eval_distribution.steps_per_run)
+ step_fn, iterator, iterations=steps_per_run_variable)
update_op = ctx.run_op
eval_dict = ctx.non_tensor_outputs['eval_dict']
grouped_estimator_spec = ctx.non_tensor_outputs['estimator_spec']
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index 1ed5e30b0e..246dfb1a4b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import glob
+import json
import os
import tempfile
@@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
+ def test_master_distributed_hooks(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.MASTER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_0(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_nonzero(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 1
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
@@ -1017,7 +1111,7 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
+ variables.VariableV1(1., name='one')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1033,8 +1127,8 @@ class EstimatorGetVariablesTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='one')
- variables.Variable(3., name='three')
+ variables.VariableV1(1., name='one')
+ variables.VariableV1(3., name='three')
return model_fn_lib.EstimatorSpec(
mode=mode,
loss=constant_op.constant(0.),
@@ -1178,13 +1272,13 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn(features, labels, mode, params):
del features, labels, params
mean = metrics_module.Mean()
- mean.update_state(variables.Variable(2.) + 1)
+ mean.update_state(variables.VariableV1(2.) + 1)
return model_fn_lib.EstimatorSpec(
mode,
loss=constant_op.constant(1.),
eval_metric_ops={
'mean1': mean,
- 'mean2': metrics_lib.mean(variables.Variable(2.) + 1)
+ 'mean2': metrics_lib.mean(variables.VariableV1(2.) + 1)
})
est = estimator.Estimator(model_fn=_model_fn)
@@ -1332,7 +1426,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_with_incremental_loss(features, labels, mode):
_, _ = features, labels
- local_weight = variables.Variable(
+ local_weight = variables.VariableV1(
0., name='local_weight', collections=[ops.GraphKeys.LOCAL_VARIABLES])
# Loss will be 2, 4, 6, ...
loss = 2 * state_ops.assign_add(local_weight, 1.)
@@ -1385,7 +1479,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _get_model_fn(val=1):
def _model_fn(features, labels, mode):
del features, labels # unused
- variables.Variable(val, name='weight')
+ variables.VariableV1(val, name='weight')
return model_fn_lib.EstimatorSpec(
mode=mode,
predictions=constant_op.constant([[1.]]),
@@ -1409,7 +1503,7 @@ class EstimatorEvaluateTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -1603,7 +1697,7 @@ class EstimatorPredictTest(test.TestCase):
def test_no_checkpoint_uses_init(self):
def _model_fn(features, labels, mode, params, config):
del features, labels, params, config
- x = variables.Variable([[3.]], name='x')
+ x = variables.VariableV1([[3.]], name='x')
return model_fn_lib.EstimatorSpec(mode, predictions=math_ops.add(x, 1.))
est = estimator.Estimator(model_fn=_model_fn)
# Expected prediction value is 1 + the value of the Variable that is newly
@@ -1614,7 +1708,7 @@ class EstimatorPredictTest(test.TestCase):
def _make_model_fn(x):
def _variable_creating_and_export_model_fn(features, labels, mode):
_, _ = features, labels
- x_var = variables.Variable([[x]], name='x')
+ x_var = variables.VariableV1([[x]], name='x')
return model_fn_lib.EstimatorSpec(
mode,
predictions=math_ops.add(x_var, 1.),
@@ -1936,7 +2030,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1953,7 +2047,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- v = variables.Variable([[16.]], name='weight')
+ v = variables.VariableV1([[16.]], name='weight')
prediction = v * 2
return model_fn_lib.EstimatorSpec(
mode,
@@ -1974,7 +2068,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
return model_fn_lib.EstimatorSpec(
mode=mode,
@@ -2029,7 +2123,7 @@ class EstimatorPredictTest(test.TestCase):
def _model_fn_for_export_tests(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
update_global_step = state_ops.assign_add(training.get_global_step(), 1)
@@ -2052,11 +2146,11 @@ def _x_y_input_fn():
def _model_fn_with_x_y(features, labels, mode):
_ = labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
classes = constant_op.constant(['wumpus'])
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(36., name='name_collision')
+ variables.VariableV1(36., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -2076,8 +2170,8 @@ def _model_fn_with_x_y(features, labels, mode):
metrics_lib.mean(
features['x'] - features['y'], name='{}mean'.format(prefix))
}
- variables.Variable(1., name='later_var')
- variables.Variable(3., name='name_collision')
+ variables.VariableV1(1., name='later_var')
+ variables.VariableV1(3., name='name_collision')
return model_fn_lib.EstimatorSpec(
mode,
predictions=multiplied,
@@ -2411,9 +2505,9 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_with_predict_only_vars(features, labels, mode):
_, _ = features, labels
if mode == model_fn_lib.ModeKeys.PREDICT:
- variables.Variable(1., name='only_in_predict')
+ variables.VariableV1(1., name='only_in_predict')
else:
- variables.Variable(1., name='otherwise')
+ variables.VariableV1(1., name='otherwise')
prediction = constant_op.constant(1.)
return model_fn_lib.EstimatorSpec(
@@ -2684,7 +2778,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
self.mock_saver = get_mock_saver()
scores = constant_op.constant([3.])
return model_fn_lib.EstimatorSpec(
@@ -2717,7 +2811,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
scores = constant_op.constant([3.])
if mode == model_fn_lib.ModeKeys.PREDICT:
@@ -2762,8 +2856,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
_ = training.get_or_create_steps_per_run_variable()
scores = constant_op.constant([3.])
with ops.control_dependencies([
@@ -2808,8 +2902,8 @@ class EstimatorExportTest(test.TestCase):
def _model_fn_scaffold(features, labels, mode):
_, _ = features, labels
- my_int = variables.Variable(1, name='my_int',
- collections=[ops.GraphKeys.LOCAL_VARIABLES])
+ my_int = variables.VariableV1(1, name='my_int',
+ collections=[ops.GraphKeys.LOCAL_VARIABLES])
scores = constant_op.constant([3.])
with ops.control_dependencies([
variables.local_variables_initializer(),
@@ -3038,7 +3132,7 @@ class EstimatorExportTest(test.TestCase):
def _model_fn(features, labels, mode):
_, _ = features, labels
- variables.Variable(1., name='weight')
+ variables.VariableV1(1., name='weight')
return model_fn_lib.EstimatorSpec(
mode,
predictions=constant_op.constant(10.),
@@ -3081,7 +3175,7 @@ class EstimatorHookOrderingTest(test.TestCase):
"""A graph that generates NaN's for testing."""
del features, labels
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, name='global_step')
inc_global_step = state_ops.assign_add(global_step, 1)
nan_const = constant_op.constant(np.nan, dtype=dtypes.float32)
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 6b2765be82..5d5ed81fbb 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import os
import re
+import six
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator as estimator_lib
@@ -31,6 +32,7 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics
from tensorflow.python.keras import models
from tensorflow.python.keras import optimizers
from tensorflow.python.ops import check_ops
@@ -214,25 +216,40 @@ def _convert_keras_metrics_to_estimator(model):
if not getattr(model, 'metrics', None):
return None
- # TODO(psv/fchollet): support stateful metrics
eval_metric_ops = {}
+
+ def get_metric_name(metric):
+ if isinstance(metric, metrics.Metric):
+ return metric.name
+ if callable(metric):
+ return metric.__name__
+ assert isinstance(metric, six.string_types)
+ return metric
+
# When each metric maps to an output
if isinstance(model.metrics, dict):
for i, output_name in enumerate(model.metrics.keys()):
- metric_name = model.metrics[output_name]
- if callable(metric_name):
- metric_name = metric_name.__name__
+ # `metric` is the user given metric value in `compile`. This can be
+ # metric name (`acc`), metric function (binary_accuracy) or a metric
+ # object (BinaryAccuracy()).
+ metric = model.metrics[output_name]
+ metric_name = get_metric_name(metric)
# When some outputs use the same metric
if list(model.metrics.values()).count(metric_name) > 1:
metric_name += '_' + output_name
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i - len(model.metrics)])
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i - len(model.metrics)])
else:
- for i, metric_name in enumerate(model.metrics):
- if callable(metric_name):
- metric_name = metric_name.__name__
- eval_metric_ops[metric_name] = metrics_module.mean(
- model.metrics_tensors[i])
+ for i, metric in enumerate(model.metrics):
+ metric_name = get_metric_name(metric)
+ if isinstance(metric, metrics.Metric):
+ eval_metric_ops[metric_name] = metric
+ else:
+ eval_metric_ops[metric_name] = metrics_module.mean(
+ model.metrics_tensors[i])
return eval_metric_ops
@@ -351,6 +368,44 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
return latest_path
+def _get_file_from_google_storage(keras_model_path, model_dir):
+ """Get file from google storage and download to local file.
+
+ Args:
+ keras_model_path: a google storage path for compiled keras model.
+ model_dir: the directory from estimator config.
+
+ Returns:
+ The path where keras model is saved.
+
+ Raises:
+ ValueError: if storage object name does not end with .h5.
+ """
+ try:
+ from google.cloud import storage # pylint:disable=g-import-not-at-top
+ except ImportError:
+ raise TypeError('Could not save model to Google cloud storage; please '
+ 'install `google-cloud-storage` via '
+ '`pip install google-cloud-storage`.')
+ storage_client = storage.Client()
+ path, blob_name = os.path.split(keras_model_path)
+ _, bucket_name = os.path.split(path)
+ keras_model_dir = os.path.join(model_dir, 'keras')
+ if not gfile.Exists(keras_model_dir):
+ gfile.MakeDirs(keras_model_dir)
+ file_name = os.path.join(keras_model_dir, 'keras_model.h5')
+ try:
+ blob = storage_client.get_bucket(bucket_name).blob(blob_name)
+ blob.download_to_filename(file_name)
+ except:
+ raise ValueError('Failed to download keras model, please check '
+ 'environment variable GOOGLE_APPLICATION_CREDENTIALS '
+ 'and model path storage.googleapis.com/{bucket}/{object}.')
+ logging.info('Saving model to {}'.format(file_name))
+ del storage_client
+ return file_name
+
+
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@@ -390,12 +445,13 @@ def model_to_estimator(keras_model=None,
'Please specity either `keras_model` or `keras_model_path`, '
'but not both.')
+ config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
+ config, model_dir)
if not keras_model:
if keras_model_path.startswith(
'gs://') or 'storage.googleapis.com' in keras_model_path:
- raise ValueError(
- '%s is not a local path. Please copy the model locally first.' %
- keras_model_path)
+ keras_model_path = _get_file_from_google_storage(keras_model_path,
+ config.model_dir)
logging.info('Loading models from %s', keras_model_path)
keras_model = models.load_model(keras_model_path)
else:
@@ -408,9 +464,6 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
- model_dir)
-
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
if _any_weight_initialized(keras_model):
# Warn if config passed to estimator tries to update GPUOptions. If a
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 3758243d7b..4e285fa25a 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -257,7 +257,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -281,7 +281,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
@@ -306,7 +306,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
my_hook = MyHook()
with self.cached_session():
keras_model.fit(x_train, y_train, epochs=1)
@@ -328,7 +328,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -351,7 +351,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -370,7 +370,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
# Create state
@@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.cached_session():
- keras_model = simple_sequential_model()
- with self.assertRaisesRegexp(ValueError, 'not a local path'):
- keras_lib.model_to_estimator(
- keras_model_path='gs://bucket/object')
-
def test_invalid_ionames_error(self):
(x_train, y_train), (_, _) = testing_utils.get_test_data(
train_samples=_TRAIN_SIZE,
@@ -662,7 +656,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
tf_config = json.dumps({
'cluster': {
@@ -687,7 +681,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
gpu_options = config_pb2.GPUOptions(per_process_gpu_memory_fraction=0.3)
sess_config = config_pb2.ConfigProto(gpu_options=gpu_options)
@@ -706,7 +700,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
est_keras = keras_lib.model_to_estimator(
@@ -736,7 +730,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with test.mock.patch.object(tempfile, 'mkdtemp', return_value=_TMP_DIR):
@@ -751,7 +745,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer='rmsprop',
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
with self.assertRaisesRegexp(ValueError, '`model_dir` are set both in '
@@ -765,7 +759,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=rmsprop.RMSPropOptimizer(1e-3),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session():
keras_model.train_on_batch(
np.random.random((10,) + _INPUT_SIZE),
@@ -776,7 +770,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=SGD(lr=0.0001, momentum=0.9),
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
keras_lib.model_to_estimator(
keras_model=keras_model, config=self._config)
@@ -786,7 +780,7 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
keras_model.compile(
loss='categorical_crossentropy',
optimizer=optimizer,
- metrics=['mse', keras.metrics.categorical_accuracy])
+ metrics=['mse', keras.metrics.CategoricalAccuracy()])
with self.cached_session() as sess:
keras_model_fn = keras_lib._create_keras_model_fn(keras_model)
global_step = training_util.create_global_step()
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import os
import time
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
self._finalize_fn = finalize_fn
def begin(self):
+ # We only create the init ops, but don't run it. We rely on SessionManager
+ # to run it for us.
self._init_ops = self._initialization_fn()
self._finalize_ops = self._finalize_fn()
- def after_create_session(self, session, coord):
- logging.info('Initialize system')
- session.run(self._init_ops,
- options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
def end(self, session):
logging.info('Finalize system.')
session.run(self._finalize_ops)
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 5800b693b4..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column.py b/tensorflow/python/feature_column/feature_column.py
index 9984379e9d..28a8286544 100644
--- a/tensorflow/python/feature_column/feature_column.py
+++ b/tensorflow/python/feature_column/feature_column.py
@@ -170,7 +170,8 @@ def _internal_input_layer(features,
trainable=True,
cols_to_vars=None,
scope=None,
- cols_to_output_tensors=None):
+ cols_to_output_tensors=None,
+ from_template=False):
"""See input_layer. `scope` is a name or variable scope to use."""
feature_columns = _normalize_feature_columns(feature_columns)
@@ -186,10 +187,7 @@ def _internal_input_layer(features,
if ops.GraphKeys.MODEL_VARIABLES not in weight_collections:
weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
- # a non-None `scope` can allow for variable reuse, when, e.g., this function
- # is wrapped by a `make_template`.
- with variable_scope.variable_scope(
- scope, default_name='input_layer', values=features.values()):
+ def _get_logits(): # pylint: disable=missing-docstring
builder = _LazyBuilder(features)
output_tensors = []
ordered_columns = []
@@ -217,6 +215,16 @@ def _internal_input_layer(features,
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
+ # If we're constructing from the `make_template`, that by default adds a
+ # variable scope with the name of the layer. In that case, we dont want to
+ # add another `variable_scope` as that would break checkpoints.
+ if from_template:
+ return _get_logits()
+ else:
+ with variable_scope.variable_scope(
+ scope, default_name='input_layer', values=features.values()):
+ return _get_logits()
+
@tf_export('feature_column.input_layer')
def input_layer(features,
@@ -301,17 +309,18 @@ class InputLayer(object):
feature_columns,
weight_collections=None,
trainable=True,
- cols_to_vars=None):
+ cols_to_vars=None,
+ name='feature_column_input_layer',
+ create_scope_now=True):
"""See `input_layer`."""
self._feature_columns = feature_columns
self._weight_collections = weight_collections
self._trainable = trainable
self._cols_to_vars = cols_to_vars
+ self._name = name
self._input_layer_template = template.make_template(
- 'feature_column_input_layer',
- _internal_input_layer,
- create_scope_now_=True)
+ self._name, _internal_input_layer, create_scope_now_=create_scope_now)
self._scope = self._input_layer_template.variable_scope
def __call__(self, features):
@@ -321,7 +330,11 @@ class InputLayer(object):
weight_collections=self._weight_collections,
trainable=self._trainable,
cols_to_vars=None,
- scope=self._scope)
+ from_template=True)
+
+ @property
+ def name(self):
+ return self._name
@property
def non_trainable_variables(self):
@@ -2305,7 +2318,7 @@ class _LazyBuilder(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
@@ -2647,6 +2660,7 @@ class _EmbeddingColumn(
inputs=inputs,
weight_collections=weight_collections,
trainable=trainable)
+
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
sequence_length = _sequence_length_from_sparse_tensor(
sparse_tensors.id_tensor)
@@ -2816,7 +2830,7 @@ def _check_shape(shape, key):
shape = [shape]
shape = tuple(shape)
for dimension in shape:
- if not isinstance(dimension, int):
+ if not isinstance(dimension, six.integer_types):
raise TypeError('shape dimensions must be integer. '
'shape: {}, key: {}'.format(shape, key))
if dimension < 1:
@@ -3370,6 +3384,16 @@ class _IndicatorColumn(_DenseColumn, _SequenceDenseColumn,
def _verify_static_batch_size_equality(tensors, columns):
+ """Validates that the first dim (batch size) of all tensors are equal or None.
+
+ Args:
+ tensors: list of tensors to check.
+ columns: list of feature columns matching tensors. Will be used for error
+ messaging.
+
+ Raises:
+ ValueError: if one of the tensors has a variant batch size
+ """
# bath_size is a tf.Dimension object.
expected_batch_size = None
for i in range(0, len(tensors)):
@@ -3390,9 +3414,18 @@ def _sequence_length_from_sparse_tensor(sp_tensor, num_elements=1):
with ops.name_scope(None, 'sequence_length') as name_scope:
row_ids = sp_tensor.indices[:, 0]
column_ids = sp_tensor.indices[:, 1]
+ # Add one to convert column indices to element length
column_ids += array_ops.ones_like(column_ids)
- seq_length = math_ops.to_int64(
- math_ops.segment_max(column_ids, segment_ids=row_ids) / num_elements)
+ # Get the number of elements we will have per example/row
+ seq_length = math_ops.segment_max(column_ids, segment_ids=row_ids)
+
+ # The raw values are grouped according to num_elements;
+ # how many entities will we have after grouping?
+ # Example: orig tensor [[1, 2], [3]], col_ids = (0, 1, 1),
+ # row_ids = (0, 0, 1), seq_length = [2, 1]. If num_elements = 2,
+ # these will get grouped, and the final seq_length is [1, 1]
+ seq_length = math_ops.to_int64(math_ops.ceil(seq_length / num_elements))
+
# If the last n rows do not have ids, seq_length will have shape
# [batch_size - n]. Pad the remaining values with zeros.
n_pad = array_ops.shape(sp_tensor)[:1] - array_ops.shape(seq_length)[:1]
@@ -3426,25 +3459,14 @@ class _SequenceCategoricalColumn(
sparse_tensors = self.categorical_column._get_sparse_tensors(inputs) # pylint: disable=protected-access
id_tensor = sparse_tensors.id_tensor
weight_tensor = sparse_tensors.weight_tensor
- # Expands final dimension, so that embeddings are not combined during
- # embedding lookup.
- check_id_rank = check_ops.assert_equal(
- array_ops.rank(id_tensor), 2,
- data=[
- 'Column {} expected ID tensor of rank 2. '.format(self.name),
- 'id_tensor shape: ', array_ops.shape(id_tensor)])
- with ops.control_dependencies([check_id_rank]):
- id_tensor = sparse_ops.sparse_reshape(
- id_tensor,
- shape=array_ops.concat([id_tensor.dense_shape, [1]], axis=0))
+
+ # Expands third dimension, if necessary so that embeddings are not
+ # combined during embedding lookup. If the tensor is already 3D, leave
+ # as-is.
+ shape = array_ops.shape(id_tensor)
+ target_shape = [shape[0], shape[1], -1]
+ id_tensor = sparse_ops.sparse_reshape(id_tensor, target_shape)
if weight_tensor is not None:
- check_weight_rank = check_ops.assert_equal(
- array_ops.rank(weight_tensor), 2,
- data=[
- 'Column {} expected weight tensor of rank 2.'.format(self.name),
- 'weight_tensor shape:', array_ops.shape(weight_tensor)])
- with ops.control_dependencies([check_weight_rank]):
- weight_tensor = sparse_ops.sparse_reshape(
- weight_tensor,
- shape=array_ops.concat([weight_tensor.dense_shape, [1]], axis=0))
+ weight_tensor = sparse_ops.sparse_reshape(weight_tensor, target_shape)
+
return _CategoricalColumn.IdWeightPair(id_tensor, weight_tensor)
diff --git a/tensorflow/python/feature_column/feature_column_test.py b/tensorflow/python/feature_column/feature_column_test.py
index abb79efa68..1ae510250c 100644
--- a/tensorflow/python/feature_column/feature_column_test.py
+++ b/tensorflow/python/feature_column/feature_column_test.py
@@ -169,6 +169,18 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "_FeatureColumn".'):
builder.get(NotAFeatureColumn())
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ builder = _LazyBuilder(features={'a': sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))})
+ with self.cached_session():
+ spv = builder.get('a').eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index 57f7af7635..b79373c475 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -136,14 +136,11 @@ import six
from tensorflow.python.eager import context
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras.engine import training
from tensorflow.python.keras.engine.base_layer import Layer
-from tensorflow.python.layers import base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
@@ -153,7 +150,6 @@ from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops import variable_scope
@@ -245,28 +241,19 @@ class StateManager(object):
raise NotImplementedError('StateManager.get_resource')
-class _InputLayerStateManager(StateManager):
- """Manages the state of InputLayer."""
+class _StateManagerImpl(StateManager):
+ """Manages the state of FeatureLayer and LinearModel."""
- def __init__(self, layer, feature_columns, trainable):
- """Creates an _InputLayerStateManager object.
+ def __init__(self, layer, trainable):
+ """Creates an _StateManagerImpl object.
Args:
layer: The input layer this state manager is associated with.
- feature_columns: List of feature columns for the input layer
trainable: Whether by default, variables created are trainable or not.
"""
self._trainable = trainable
self._layer = layer
- self._cols_to_vars_map = {}
- self._cols_to_names_map = {}
- for column in sorted(feature_columns, key=lambda x: x.name):
- self._cols_to_vars_map[column] = {}
- base_name = column.name
- if isinstance(column, SharedEmbeddingColumn):
- base_name = column.shared_collection_name
- with variable_scope.variable_scope(base_name) as vs:
- self._cols_to_names_map[column] = _strip_leading_slashes(vs.name)
+ self._cols_to_vars_map = collections.defaultdict(lambda: {})
def create_variable(self,
feature_column,
@@ -277,19 +264,20 @@ class _InputLayerStateManager(StateManager):
initializer=None):
if name in self._cols_to_vars_map[feature_column]:
raise ValueError('Variable already exists.')
- with variable_scope.variable_scope(self._cols_to_names_map[feature_column]):
- var = self._layer.add_variable(
- name=name,
- shape=shape,
- dtype=dtype,
- initializer=initializer,
- trainable=self._trainable and trainable,
- # TODO(rohanj): Get rid of this hack once we have a mechanism for
- # specifying a default partitioner for an entire layer. In that case,
- # the default getter for Layers should work.
- getter=variable_scope.get_variable)
- self._cols_to_vars_map[feature_column][name] = var
- return var
+
+ var = self._layer.add_variable(
+ name=name,
+ shape=shape,
+ dtype=dtype,
+ initializer=initializer,
+ trainable=self._trainable and trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
+ self._cols_to_vars_map[feature_column][name] = var
+ return var
def get_variable(self, feature_column, name):
if name in self._cols_to_vars_map[feature_column]:
@@ -313,12 +301,15 @@ class FeatureLayer(Layer):
keywords_embedded = embedding_column(
categorical_column_with_hash_bucket("keywords", 10K), dimensions=16)
columns = [price, keywords_embedded, ...]
- features = tf.parse_example(..., features=make_parse_example_spec(columns))
feature_layer = FeatureLayer(columns)
+
+ features = tf.parse_example(..., features=make_parse_example_spec(columns))
dense_tensor = feature_layer(features)
for units in [128, 64, 32]:
dense_tensor = tf.layers.dense(dense_tensor, units, tf.nn.relu)
- prediction = tf.layers.dense(dense_tensor, 1)."""
+ prediction = tf.layers.dense(dense_tensor, 1).
+ ```
+ """
def __init__(self,
feature_columns,
@@ -375,8 +366,7 @@ class FeatureLayer(Layer):
super(FeatureLayer, self).__init__(name=name, trainable=trainable, **kwargs)
self._feature_columns = _normalize_feature_columns(feature_columns)
- self._state_manager = _InputLayerStateManager(self, self._feature_columns,
- self.trainable)
+ self._state_manager = _StateManagerImpl(self, self.trainable)
self._shared_state_manager = shared_state_manager
for column in sorted(self._feature_columns, key=lambda x: x.name):
if not isinstance(column, DenseColumn):
@@ -394,8 +384,9 @@ class FeatureLayer(Layer):
if isinstance(column, SharedEmbeddingColumn):
column.create_state(self._shared_state_manager)
else:
- with variable_scope.variable_scope(None, default_name=self.name):
- column.create_state(self._state_manager)
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
def call(self, features, cols_to_output_tensors=None):
@@ -424,19 +415,20 @@ class FeatureLayer(Layer):
output_tensors = []
ordered_columns = []
for column in sorted(self._feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- if isinstance(column, SharedEmbeddingColumn):
- tensor = column.get_dense_tensor(transformation_cache,
- self._shared_state_manager)
- else:
- tensor = column.get_dense_tensor(transformation_cache,
- self._state_manager)
- num_elements = column.variable_shape.num_elements()
- batch_size = array_ops.shape(tensor)[0]
- tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- output_tensors.append(tensor)
- if cols_to_output_tensors is not None:
- cols_to_output_tensors[column] = tensor
+ with ops.name_scope(column.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -448,20 +440,18 @@ class FeatureLayer(Layer):
return (input_shape[0], total_elements)
-def linear_model(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- """Returns a linear prediction `Tensor` based on given `feature_columns`.
+def _strip_leading_slashes(name):
+ return name.rsplit('/', 1)[-1]
+
+
+class LinearModel(Layer):
+ """Produces a linear prediction `Tensor` based on given `feature_columns`.
- This function generates a weighted sum based on output dimension `units`.
+ This layer generates a weighted sum based on output dimension `units`.
Weighted sum refers to logits in classification problems. It refers to the
prediction itself for linear regression problems.
- Note on supported columns: `linear_model` treats categorical columns as
+ Note on supported columns: `LinearModel` treats categorical columns as
`indicator_column`s. To be specific, assume the input as `SparseTensor` looks
like:
@@ -486,308 +476,195 @@ def linear_model(features,
keywords = categorical_column_with_hash_bucket("keywords", 10K)
keywords_price = crossed_column('keywords', price_buckets, ...)
columns = [price_buckets, keywords, keywords_price ...]
+ linear_model = LinearModel(columns)
+
features = tf.parse_example(..., features=make_parse_example_spec(columns))
- prediction = linear_model(features, columns)
+ prediction = linear_model(features)
```
-
- Args:
- features: A mapping from key to tensors. `_FeatureColumn`s look up via these
- keys. For example `numeric_column('price')` will look at 'price' key in
- this dict. Values are `Tensor` or `SparseTensor` depending on
- corresponding `_FeatureColumn`.
- feature_columns: An iterable containing the FeatureColumns to use as inputs
- to your model. All items should be instances of classes derived from
- `_FeatureColumn`s.
- units: An integer, dimensionality of the output space. Default value is 1.
- sparse_combiner: A string specifying how to reduce if a categorical column
- is multivalent. Except `numeric_column`, almost all columns passed to
- `linear_model` are considered as categorical columns. It combines each
- categorical column independently. Currently "mean", "sqrtn" and "sum" are
- supported, with "sum" the default for linear model. "sqrtn" often achieves
- good accuracy, in particular with bag-of-words columns.
- * "sum": do not normalize features in the column
- * "mean": do l1 normalization on features in the column
- * "sqrtn": do l2 normalization on features in the column
- For example, for two features represented as the categorical columns:
-
- ```python
- # Feature 1
-
- shape = [2, 2]
- {
- [0, 0]: "a"
- [0, 1]: "b"
- [1, 0]: "c"
- }
-
- # Feature 2
-
- shape = [2, 3]
- {
- [0, 0]: "d"
- [1, 0]: "e"
- [1, 1]: "f"
- [1, 2]: "g"
- }
- ```
- with `sparse_combiner` as "mean", the linear model outputs conceptly are:
- ```
- y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
- y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
- ```
- where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
- assigned to the presence of `x` in the input features.
- weight_collections: A list of collection names to which the Variable will be
- added. Note that, variables will also be added to collections
- `tf.GraphKeys.GLOBAL_VARIABLES` and `ops.GraphKeys.MODEL_VARIABLES`.
- trainable: If `True` also add the variable to the graph collection
- `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
- cols_to_vars: If not `None`, must be a dictionary that will be filled with a
- mapping from `_FeatureColumn` to associated list of `Variable`s. For
- example, after the call, we might have cols_to_vars = {
- _NumericColumn(
- key='numeric_feature1', shape=(1,):
- [<tf.Variable 'linear_model/price2/weights:0' shape=(1, 1)>],
- 'bias': [<tf.Variable 'linear_model/bias_weights:0' shape=(1,)>],
- _NumericColumn(
- key='numeric_feature2', shape=(2,)):
- [<tf.Variable 'linear_model/price1/weights:0' shape=(2, 1)>]}
- If a column creates no variables, its value will be an empty list. Note
- that cols_to_vars will also contain a string key 'bias' that maps to a
- list of Variables.
-
- Returns:
- A `Tensor` which represents predictions/logits of a linear model. Its shape
- is (batch_size, units) and its dtype is `float32`.
-
- Raises:
- ValueError: if an item in `feature_columns` is neither a `_DenseColumn`
- nor `_CategoricalColumn`.
- """
- with variable_scope.variable_scope(None, 'linear_model') as vs:
- model_name = _strip_leading_slashes(vs.name)
- linear_model_layer = _LinearModel(
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
- name=model_name)
- retval = linear_model_layer(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(linear_model_layer.cols_to_vars())
- return retval
-
-
-def _add_to_collections(var, weight_collections):
- """Adds a var to the list of weight_collections provided.
-
- Handles the case for partitioned and non-partitioned variables.
-
- Args:
- var: A variable or Partitioned Variable.
- weight_collections: List of collections to add variable to.
- """
- for weight_collection in weight_collections:
- # The layer self.add_variable call already adds it to GLOBAL_VARIABLES.
- if weight_collection == ops.GraphKeys.GLOBAL_VARIABLES:
- continue
- # TODO(rohanj): Explore adding a _get_variable_list method on `Variable`
- # so that we don't have to do this check.
- if isinstance(var, variables.PartitionedVariable):
- for constituent_var in list(var):
- ops.add_to_collection(weight_collection, constituent_var)
- else:
- ops.add_to_collection(weight_collection, var)
-
-
-class _FCLinearWrapper(base.Layer):
- """Wraps a _FeatureColumn in a layer for use in a linear model.
-
- See `linear_model` above.
"""
def __init__(self,
- feature_column,
+ feature_columns,
units=1,
sparse_combiner='sum',
- weight_collections=None,
trainable=True,
name=None,
+ shared_state_manager=None,
**kwargs):
- super(_FCLinearWrapper, self).__init__(
- trainable=trainable, name=name, **kwargs)
- self._feature_column = feature_column
- self._units = units
- self._sparse_combiner = sparse_combiner
- self._weight_collections = weight_collections
+ """Constructs a LinearModel.
- def build(self, _):
- if isinstance(self._feature_column, fc_old._CategoricalColumn): # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=(self._feature_column._num_buckets, self._units), # pylint: disable=protected-access
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- else:
- num_elements = self._feature_column._variable_shape.num_elements() # pylint: disable=protected-access
- weight = self.add_variable(
- name='weights',
- shape=[num_elements, self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(weight, self._weight_collections)
- self._weight_var = weight
- self.built = True
-
- def call(self, builder):
- weighted_sum = fc_old._create_weighted_sum( # pylint: disable=protected-access
- column=self._feature_column,
- builder=builder,
- units=self._units,
- sparse_combiner=self._sparse_combiner,
- weight_collections=self._weight_collections,
- trainable=self.trainable,
- weight_var=self._weight_var)
- return weighted_sum
+ Args:
+ feature_columns: An iterable containing the FeatureColumns to use as
+ inputs to your model. All items should be instances of classes derived
+ from `_FeatureColumn`s.
+ units: An integer, dimensionality of the output space. Default value is 1.
+ sparse_combiner: A string specifying how to reduce if a categorical column
+ is multivalent. Except `numeric_column`, almost all columns passed to
+ `linear_model` are considered as categorical columns. It combines each
+ categorical column independently. Currently "mean", "sqrtn" and "sum"
+ are supported, with "sum" the default for linear model. "sqrtn" often
+ achieves good accuracy, in particular with bag-of-words columns.
+ * "sum": do not normalize features in the column
+ * "mean": do l1 normalization on features in the column
+ * "sqrtn": do l2 normalization on features in the column
+ For example, for two features represented as the categorical columns:
+
+ ```python
+ # Feature 1
+
+ shape = [2, 2]
+ {
+ [0, 0]: "a"
+ [0, 1]: "b"
+ [1, 0]: "c"
+ }
+
+ # Feature 2
+
+ shape = [2, 3]
+ {
+ [0, 0]: "d"
+ [1, 0]: "e"
+ [1, 1]: "f"
+ [1, 2]: "g"
+ }
+ ```
+
+ with `sparse_combiner` as "mean", the linear model outputs conceptly are
+ ```
+ y_0 = 1.0 / 2.0 * ( w_a + w_ b) + w_c + b_0
+ y_1 = w_d + 1.0 / 3.0 * ( w_e + w_ f + w_g) + b_1
+ ```
+ where `y_i` is the output, `b_i` is the bias, and `w_x` is the weight
+ assigned to the presence of `x` in the input features.
+ trainable: If `True` also add the variable to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES` (see `tf.Variable`).
+ name: Name to give to the Linear Model. All variables and ops created will
+ be scoped by this name.
+ shared_state_manager: SharedEmbeddingStateManager that manages the state
+ of SharedEmbeddingColumns. For more info, look at `FeatureLayer`.
+ **kwargs: Keyword arguments to construct a layer.
+ Raises:
+ ValueError: if an item in `feature_columns` is neither a `DenseColumn`
+ nor `CategoricalColumn`.
+ """
+ super(LinearModel, self).__init__(name=name, trainable=trainable, **kwargs)
-class _BiasLayer(base.Layer):
- """A layer for the bias term.
- """
+ self._feature_columns = _normalize_feature_columns(feature_columns)
+ self._feature_columns = sorted(self._feature_columns, key=lambda x: x.name)
+ for column in self._feature_columns:
+ if not isinstance(column, (DenseColumn, CategoricalColumn)):
+ raise ValueError(
+ 'Items of feature_columns must be either a '
+ 'DenseColumn or CategoricalColumn. Given: {}'.format(column))
- def __init__(self,
- units=1,
- trainable=True,
- weight_collections=None,
- name=None,
- **kwargs):
- super(_BiasLayer, self).__init__(trainable=trainable, name=name, **kwargs)
self._units = units
- self._weight_collections = weight_collections
-
- def build(self, _):
- self._bias_variable = self.add_variable(
- 'bias_weights',
- shape=[self._units],
- initializer=init_ops.zeros_initializer(),
- trainable=self.trainable)
- _add_to_collections(self._bias_variable, self._weight_collections)
- self.built = True
-
- def call(self, _):
- return self._bias_variable
+ self._sparse_combiner = sparse_combiner
+ self._state_manager = _StateManagerImpl(self, self.trainable)
+ self._shared_state_manager = shared_state_manager
+ self._bias_variable = None
-def _get_expanded_variable_list(var_list):
- returned_list = []
- for variable in var_list:
- if (isinstance(variable, variables.Variable) or
- resource_variable_ops.is_resource_variable(variable)):
- returned_list.append(variable) # Single variable case.
- else: # Must be a PartitionedVariable, so convert into a list.
- returned_list.extend(list(variable))
- return returned_list
+ def build(self, _):
+ # Create state for shared embedding columns.
+ for column in self._feature_columns:
+ if isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._shared_state_manager)
+ # We need variable scopes for now because we want the variable partitioning
+ # information to percolate down. We also use _pure_variable_scope's here
+ # since we want to open up a name_scope in the `call` method while creating
+ # the ops.
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ for column in self._feature_columns:
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
+ # Create the state for each feature column
+ if not isinstance(column, SharedEmbeddingColumn):
+ column.create_state(self._state_manager)
+
+ # Create a weight variable for each column.
+ if isinstance(column, CategoricalColumn):
+ first_dim = column.num_buckets
+ else:
+ first_dim = column.variable_shape.num_elements()
+ self._state_manager.create_variable(
+ column,
+ name='weights',
+ dtype=dtypes.float32,
+ shape=(first_dim, self._units),
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable)
+
+ # Create a bias variable.
+ self._bias_variable = self.add_variable(
+ name='bias_weights',
+ dtype=dtypes.float32,
+ shape=[self._units],
+ initializer=init_ops.zeros_initializer(),
+ trainable=self.trainable,
+ use_resource=True,
+ # TODO(rohanj): Get rid of this hack once we have a mechanism for
+ # specifying a default partitioner for an entire layer. In that case,
+ # the default getter for Layers should work.
+ getter=variable_scope.get_variable)
-def _strip_leading_slashes(name):
- return name.rsplit('/', 1)[-1]
+ super(LinearModel, self).build(None)
+ def call(self, features):
+ """Returns a `Tensor` the represents the predictions of a linear model.
-class _LinearModel(training.Model):
- """Creates a linear model using feature columns.
+ Args:
+ features: A mapping from key to tensors. `_FeatureColumn`s look up via
+ these keys. For example `numeric_column('price')` will look at 'price'
+ key in this dict. Values are `Tensor` or `SparseTensor` depending on
+ corresponding `_FeatureColumn`.
- See `linear_model` for details.
- """
+ Returns:
+ A `Tensor` which represents predictions/logits of a linear model. Its
+ shape is (batch_size, units) and its dtype is `float32`.
- def __init__(self,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- name=None,
- **kwargs):
- super(_LinearModel, self).__init__(name=name, **kwargs)
- self._feature_columns = fc_old._normalize_feature_columns( # pylint: disable=protected-access
- feature_columns)
- self._weight_collections = list(weight_collections or [])
- if ops.GraphKeys.GLOBAL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.GLOBAL_VARIABLES)
- if ops.GraphKeys.MODEL_VARIABLES not in self._weight_collections:
- self._weight_collections.append(ops.GraphKeys.MODEL_VARIABLES)
-
- column_layers = {}
- for column in sorted(self._feature_columns, key=lambda x: x.name):
- with variable_scope.variable_scope(
- None, default_name=column._var_scope_name) as vs: # pylint: disable=protected-access
- # Having the fully expressed variable scope name ends up doubly
- # expressing the outer scope (scope with which this method was called)
- # in the name of the variable that would get created.
- column_name = _strip_leading_slashes(vs.name)
- column_layer = _FCLinearWrapper(column, units, sparse_combiner,
- self._weight_collections, trainable,
- column_name, **kwargs)
- column_layers[column_name] = column_layer
- self._column_layers = self._add_layers(column_layers)
- self._bias_layer = _BiasLayer(
- units=units,
- trainable=trainable,
- weight_collections=self._weight_collections,
- name='bias_layer',
- **kwargs)
- self._cols_to_vars = {}
-
- def cols_to_vars(self):
- """Returns a dict mapping _FeatureColumns to variables.
-
- See `linear_model` for more information.
- This is not populated till `call` is called i.e. layer is built.
+ Raises:
+ ValueError: If features are not a dictionary.
"""
- return self._cols_to_vars
-
- def call(self, features):
- with variable_scope.variable_scope(self.name):
- for column in self._feature_columns:
- if not isinstance(
- column,
- (
- fc_old._DenseColumn, # pylint: disable=protected-access
- fc_old._CategoricalColumn)): # pylint: disable=protected-access
- raise ValueError(
- 'Items of feature_columns must be either a '
- '_DenseColumn or _CategoricalColumn. Given: {}'.format(column))
+ if not isinstance(features, dict):
+ raise ValueError('We expected a dictionary here. Instead we got: ',
+ features)
+ with ops.name_scope(self.name):
+ transformation_cache = FeatureTransformationCache(features)
weighted_sums = []
- ordered_columns = []
- builder = fc_old._LazyBuilder(features) # pylint: disable=protected-access
- for layer in sorted(self._column_layers.values(), key=lambda x: x.name):
- column = layer._feature_column # pylint: disable=protected-access
- ordered_columns.append(column)
- weighted_sum = layer(builder)
- weighted_sums.append(weighted_sum)
- self._cols_to_vars[column] = ops.get_collection(
- ops.GraphKeys.GLOBAL_VARIABLES, scope=layer.scope_name)
-
- _verify_static_batch_size_equality(weighted_sums, ordered_columns)
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
predictions_no_bias = math_ops.add_n(
weighted_sums, name='weighted_sum_no_bias')
predictions = nn_ops.bias_add(
- predictions_no_bias,
- self._bias_layer( # pylint: disable=not-callable
- builder,
- scope=variable_scope.get_variable_scope()), # pylint: disable=not-callable
- name='weighted_sum')
- bias = self._bias_layer.variables[0]
- self._cols_to_vars['bias'] = _get_expanded_variable_list([bias])
- return predictions
-
- def _add_layers(self, layers):
- # "Magic" required for keras.Model classes to track all the variables in
- # a list of layers.Layer objects.
- # TODO(ashankar): Figure out API so user code doesn't have to do this.
- for name, layer in layers.items():
- setattr(self, 'layer-%s' % name, layer)
- return layers
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
+
+ @property
+ def bias_variable(self):
+ return self._bias_variable
def _transform_features(features, feature_columns, state_manager):
@@ -2045,58 +1922,40 @@ class DenseColumn(FeatureColumn):
pass
-def _create_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def is_feature_column_v2(feature_columns):
+ """Returns True if all feature columns are V2."""
+ for feature_column in feature_columns:
+ if not isinstance(feature_column, FeatureColumn):
+ return False
+ return True
+
+
+def _create_weighted_sum(column, transformation_cache, state_manager,
+ sparse_combiner, weight_var):
"""Creates a weighted sum for a dense/categorical column for linear_model."""
if isinstance(column, CategoricalColumn):
return _create_categorical_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
sparse_combiner=sparse_combiner,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
else:
return _create_dense_column_weighted_sum(
column=column,
transformation_cache=transformation_cache,
state_manager=state_manager,
- units=units,
- weight_collections=weight_collections,
- trainable=trainable,
weight_var=weight_var)
-def _create_dense_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_dense_column_weighted_sum(column, transformation_cache,
+ state_manager, weight_var):
"""Create a weighted sum of a dense column for linear_model."""
tensor = column.get_dense_tensor(transformation_cache, state_manager)
num_elements = column.variable_shape.num_elements()
batch_size = array_ops.shape(tensor)[0]
tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=[num_elements, units],
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
- return math_ops.matmul(tensor, weight, name='weighted_sum')
+ return math_ops.matmul(tensor, weight_var, name='weighted_sum')
class CategoricalColumn(FeatureColumn):
@@ -2137,14 +1996,8 @@ class CategoricalColumn(FeatureColumn):
pass
-def _create_categorical_column_weighted_sum(column,
- transformation_cache,
- state_manager,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- weight_var=None):
+def _create_categorical_column_weighted_sum(
+ column, transformation_cache, state_manager, sparse_combiner, weight_var):
# pylint: disable=g-doc-return-or-yield,g-doc-args
"""Create a weighted sum of a categorical column for linear_model.
@@ -2183,17 +2036,8 @@ def _create_categorical_column_weighted_sum(column,
weight_tensor = sparse_ops.sparse_reshape(
weight_tensor, [array_ops.shape(weight_tensor)[0], -1])
- if weight_var is not None:
- weight = weight_var
- else:
- weight = variable_scope.get_variable(
- name='weights',
- shape=(column.num_buckets, units),
- initializer=init_ops.zeros_initializer(),
- trainable=trainable,
- collections=weight_collections)
return _safe_embedding_lookup_sparse(
- weight,
+ weight_var,
id_tensor,
sparse_weights=weight_tensor,
combiner=sparse_combiner,
@@ -2333,7 +2177,7 @@ class FeatureTransformationCache(object):
# Input_tensor must have rank 1.
if isinstance(input_tensor, sparse_tensor_lib.SparseTensor):
return sparse_ops.sparse_reshape(
- input_tensor, [array_ops.shape(input_tensor)[0], -1])
+ input_tensor, [array_ops.shape(input_tensor)[0], 1])
else:
return array_ops.expand_dims(input_tensor, -1)
@@ -2769,6 +2613,7 @@ class SharedEmbeddingStateManager(Layer):
dtype=dtype,
trainable=self.trainable and trainable,
initializer=initializer,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -2782,6 +2627,12 @@ class SharedEmbeddingStateManager(Layer):
return self._var_dict[name]
+def maybe_create_shared_state_manager(feature_columns):
+ if is_feature_column_v2(feature_columns):
+ return SharedEmbeddingStateManager()
+ return None
+
+
class SharedEmbeddingColumn(
DenseColumn, SequenceDenseColumn,
collections.namedtuple(
@@ -2822,6 +2673,10 @@ class SharedEmbeddingColumn(
def create_state(self, state_manager):
"""Creates the shared embedding lookup variable."""
+ if not isinstance(state_manager, SharedEmbeddingStateManager):
+ raise ValueError('Expected state_manager to be of type '
+ 'SharedEmbeddingStateManager. Obtained type: {}'.format(
+ type(state_manager)))
embedding_shape = (self.categorical_column.num_buckets, self.dimension)
state_manager.create_variable(
name=self.shared_collection_name,
@@ -3433,11 +3288,10 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- if not isinstance(embedding_weights[0],
- resource_variable_ops.ResourceVariable):
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ # TODO(rohanj): Look into removing this convert_to_tensor call.
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index 58168e0f9e..d3787146ed 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -31,9 +31,7 @@ from tensorflow.python.client import session
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.estimator.inputs import numpy_io
-from tensorflow.python.feature_column import feature_column as fc_old
from tensorflow.python.feature_column import feature_column_v2 as fc
-from tensorflow.python.feature_column.feature_column_v2 import _LinearModel
from tensorflow.python.feature_column.feature_column_v2 import _transform_features
from tensorflow.python.feature_column.feature_column_v2 import FeatureColumn
from tensorflow.python.feature_column.feature_column_v2 import FeatureLayer
@@ -48,7 +46,6 @@ from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import partitioned_variables
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import test
@@ -177,6 +174,22 @@ class LazyColumnTest(test.TestCase):
TypeError, '"key" must be either a "str" or "FeatureColumn".'):
transformation_cache.get(NotAFeatureColumn(), None)
+ def test_expand_dim_rank_1_sparse_tensor_empty_batch(self):
+ # empty 1-D sparse tensor:
+ transformation_cache = FeatureTransformationCache(
+ features={
+ 'a':
+ sparse_tensor.SparseTensor(
+ indices=np.reshape(np.array([], dtype=np.int64), (0, 1)),
+ dense_shape=[0],
+ values=np.array([]))
+ })
+ with self.cached_session():
+ spv = transformation_cache.get('a', None).eval()
+ self.assertAllEqual(np.array([0, 1], dtype=np.int64), spv.dense_shape)
+ self.assertAllEqual(
+ np.reshape(np.array([], dtype=np.int64), (0, 2)), spv.indices)
+
class NumericColumnTest(test.TestCase):
@@ -344,26 +357,12 @@ class NumericColumnTest(test.TestCase):
self.assertEqual(a.default_value, ((3., 2.),))
def test_linear_model(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[10.], [50.]], predictions.eval())
-
- def test_keras_linear_model(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.]], price_var.eval())
@@ -548,13 +547,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_one_input_value(self):
"""Tests linear_model() for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[1])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight variable per bucket, all initialized to zero.
@@ -573,13 +572,13 @@ class BucketizedColumnTest(test.TestCase):
def test_linear_model_two_input_values(self):
"""Tests linear_model() for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
+ price = fc.numeric_column('price', shape=[2])
+ bucketized_price = fc.bucketized_column(price, boundaries=[0, 2, 4, 6])
with ops.Graph().as_default():
features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = fc.linear_model(features, [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
+ model = fc.LinearModel([bucketized_price])
+ predictions = model(features)
+ bucketized_price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
# One weight per bucket per input column, all initialized to zero.
@@ -600,62 +599,6 @@ class BucketizedColumnTest(test.TestCase):
sess.run(bias.assign([1.]))
self.assertAllClose([[81.], [141.]], predictions.eval())
- def test_keras_linear_model_one_input_value(self):
- """Tests _LinearModel for input with shape=[1]."""
- price = fc_old.numeric_column('price', shape=[1])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1.], [1.], [5.], [6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight variable per bucket, all initialized to zero.
- self.assertAllClose([[0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.]]))
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 1st bucket, whose weight is 20.
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 4th bucket, whose weight is 50.
- self.assertAllClose([[10.], [20.], [40.], [50.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[11.], [21.], [41.], [51.]], predictions.eval())
-
- def test_keras_linear_model_two_input_values(self):
- """Tests _LinearModel for input with shape=[2]."""
- price = fc_old.numeric_column('price', shape=[2])
- bucketized_price = fc_old.bucketized_column(price, boundaries=[0, 2, 4, 6])
- with ops.Graph().as_default():
- features = {'price': [[-1., 1.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features,
- [bucketized_price])
- bias = get_linear_model_bias()
- bucketized_price_var = get_linear_model_column_var(bucketized_price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- # One weight per bucket per input column, all initialized to zero.
- self.assertAllClose(
- [[0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.], [0.]],
- bucketized_price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(
- bucketized_price_var.assign([[10.], [20.], [30.], [40.], [50.],
- [60.], [70.], [80.], [90.], [100.]]))
- # 1st example:
- # price -1. is in the 0th bucket, whose weight is 10.
- # price 1. is in the 6th bucket, whose weight is 70.
- # 2nd example:
- # price 5. is in the 3rd bucket, whose weight is 40.
- # price 6. is in the 9th bucket, whose weight is 100.
- self.assertAllClose([[80.], [140.]], predictions.eval())
- sess.run(bias.assign([1.]))
- self.assertAllClose([[81.], [141.]], predictions.eval())
-
class HashedCategoricalColumnTest(test.TestCase):
@@ -836,39 +779,18 @@ class HashedCategoricalColumnTest(test.TestCase):
transformation_cache.get(hashed_sparse, None), id_weight_pair.id_tensor)
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 3: wire_var[3] = 4
- # 'skywalker' -> 2, 'omar' -> 2: wire_var[2] + wire_var[2] = 3+3 = 6
- self.assertAllClose(((4.,), (6.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_hash_bucket('wire', 4)
- self.assertEqual(4, wire_column._num_buckets)
+ wire_column = fc.categorical_column_with_hash_bucket('wire', 4)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -1087,93 +1009,12 @@ class CrossedColumnTest(test.TestCase):
Uses data from test_get_sparse_tesnsors_simple.
"""
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'a': constant_op.constant(((-1., .5), (.5, 1.))),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
- with _initialized_session() as sess:
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(
- ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
- # Expected ids after cross = (1, 0, 1, 3, 4, 2)
- self.assertAllClose(((3.,), (14.,)), predictions.eval())
- sess.run(bias.assign((.1,)))
- self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
-
- def test_linear_model_with_weights(self):
-
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
- """Produces sparse IDs and sparse weights."""
-
- @property
- def name(self):
- return 'test_column'
-
- @property
- def _parse_example_spec(self):
- return {
- self.name: parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
- dtypes.float32),
- }
-
- @property
- def _num_buckets(self):
- return 5
-
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
-
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
- """Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
- id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
-
- t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError,
- 'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- fc.linear_model({
- t.name: sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[0, 1, 2],
- dense_shape=(2, 2)),
- '{}_weights'.format(t.name): sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=[1., 10., 2.],
- dense_shape=(2, 2)),
- 'c': sparse_tensor.SparseTensor(
- indices=((0, 0), (1, 0), (1, 1)),
- values=['cA', 'cB', 'cC'],
- dense_shape=(2, 2)),
- }, (crossed,))
-
- def test_keras_linear_model(self):
- """Tests _LinearModel.
-
- Uses data from test_get_sparse_tesnsors_simple.
- """
- a = fc_old.numeric_column('a', dtype=dtypes.int32, shape=(2,))
- b = fc_old.bucketized_column(a, boundaries=(0, 1))
- crossed = fc_old.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
+ a = fc.numeric_column('a', dtype=dtypes.int32, shape=(2,))
+ b = fc.bucketized_column(a, boundaries=(0, 1))
+ crossed = fc.crossed_column([b, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ predictions = model({
'a':
constant_op.constant(((-1., .5), (.5, 1.))),
'c':
@@ -1181,13 +1022,12 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
- bias = get_linear_model_bias()
- crossed_var = get_linear_model_column_var(crossed)
+ })
+ crossed_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,), (0.,)),
- crossed_var.eval())
+ self.assertAllClose(
+ ((0.,), (0.,), (0.,), (0.,), (0.,)), crossed_var.eval())
self.assertAllClose(((0.,), (0.,)), predictions.eval())
sess.run(crossed_var.assign(((1.,), (2.,), (3.,), (4.,), (5.,))))
# Expected ids after cross = (1, 0, 1, 3, 4, 2)
@@ -1195,9 +1035,9 @@ class CrossedColumnTest(test.TestCase):
sess.run(bias.assign((.1,)))
self.assertAllClose(((3.1,), (14.1,)), predictions.eval())
- def test_keras_linear_model_with_weights(self):
+ def test_linear_model_with_weights(self):
- class _TestColumnWithWeights(fc_old._CategoricalColumn):
+ class _TestColumnWithWeights(fc.CategoricalColumn):
"""Produces sparse IDs and sparse weights."""
@property
@@ -1205,38 +1045,36 @@ class CrossedColumnTest(test.TestCase):
return 'test_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {
- self.name:
- parsing_ops.VarLenFeature(dtypes.int32),
- '{}_weights'.format(self.name):
- parsing_ops.VarLenFeature(dtypes.float32),
- }
+ self.name: parsing_ops.VarLenFeature(dtypes.int32),
+ '{}_weights'.format(self.name): parsing_ops.VarLenFeature(
+ dtypes.float32),
+ }
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 5
- def _transform_feature(self, inputs):
- return (inputs.get(self.name),
- inputs.get('{}_weights'.format(self.name)))
+ def transform_feature(self, transformation_cache, state_manager):
+ return (transformation_cache.get(self.name, state_manager),
+ transformation_cache.get('{}_weights'.format(self.name),
+ state_manager))
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
"""Populates both id_tensor and weight_tensor."""
- ids_and_weights = inputs.get(self)
- return fc_old._CategoricalColumn.IdWeightPair(
+ ids_and_weights = transformation_cache.get(self, state_manager)
+ return fc.CategoricalColumn.IdWeightPair(
id_tensor=ids_and_weights[0], weight_tensor=ids_and_weights[1])
t = _TestColumnWithWeights()
- crossed = fc_old.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
+ crossed = fc.crossed_column([t, 'c'], hash_bucket_size=5, hash_key=5)
with ops.Graph().as_default():
with self.assertRaisesRegexp(
ValueError,
'crossed_column does not support weight_tensor.*{}'.format(t.name)):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((crossed,))
+ model({
t.name:
sparse_tensor.SparseTensor(
indices=((0, 0), (1, 0), (1, 1)),
@@ -1252,37 +1090,7 @@ class CrossedColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=['cA', 'cB', 'cC'],
dense_shape=(2, 2)),
- }, (crossed,))
-
-
-def get_linear_model_bias(name='linear_model'):
- with variable_scope.variable_scope(name, reuse=True):
- return variable_scope.get_variable('bias_weights')
-
-
-def get_linear_model_column_var(column, name='linear_model'):
- return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
- name + '/' + column.name)[0]
-
-
-def get_keras_linear_model_predictions(features,
- feature_columns,
- units=1,
- sparse_combiner='sum',
- weight_collections=None,
- trainable=True,
- cols_to_vars=None):
- keras_linear_model = _LinearModel(
- feature_columns,
- units,
- sparse_combiner,
- weight_collections,
- trainable,
- name='linear_model')
- retval = keras_linear_model(features) # pylint: disable=not-callable
- if cols_to_vars is not None:
- cols_to_vars.update(keras_linear_model.cols_to_vars())
- return retval
+ })
class LinearModelTest(test.TestCase):
@@ -1290,56 +1098,50 @@ class LinearModelTest(test.TestCase):
def test_raises_if_empty_feature_columns(self):
with self.assertRaisesRegexp(ValueError,
'feature_columns must not be empty'):
- fc.linear_model(features={}, feature_columns=[])
+ fc.LinearModel(feature_columns=[])
def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- fc.linear_model(features={'a': [[0]]}, feature_columns='NotSupported')
+ with self.assertRaisesRegexp(ValueError, 'must be a FeatureColumn'):
+ fc.LinearModel(feature_columns='NotSupported')
def test_should_be_dense_or_categorical_column(self):
- class NotSupportedColumn(fc_old._FeatureColumn):
+ class NotSupportedColumn(fc.FeatureColumn):
@property
def name(self):
return 'NotSupportedColumn'
- def _transform_feature(self, cache):
+ def transform_feature(self, transformation_cache, state_manager):
pass
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
pass
with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- fc.linear_model(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
+ ValueError, 'must be either a DenseColumn or CategoricalColumn'):
+ fc.LinearModel(feature_columns=[NotSupportedColumn()])
def test_does_not_support_dict_columns(self):
with self.assertRaisesRegexp(
ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
+ fc.LinearModel(feature_columns={'a': fc.numeric_column('a')})
def test_raises_if_duplicate_name(self):
with self.assertRaisesRegexp(
ValueError, 'Duplicate feature column name found for columns'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
+ fc.LinearModel(
+ feature_columns=[fc.numeric_column('a'),
+ fc.numeric_column('a')])
def test_dense_bias(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
sess.run(price_var.assign([[10.]]))
@@ -1347,16 +1149,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[15.], [55.]], predictions.eval())
def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
@@ -1365,18 +1167,17 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([wire_cast, price])
+ predictions = model(features)
+ price_var, wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
@@ -1386,38 +1187,36 @@ class LinearModelTest(test.TestCase):
def test_dense_and_sparse_column(self):
"""When the column is both dense and sparse, uses sparse tensors."""
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
+ class _DenseAndSparseColumn(fc.DenseColumn, fc.CategoricalColumn):
@property
def name(self):
return 'dense_and_sparse_column'
@property
- def _parse_example_spec(self):
+ def parse_example_spec(self):
return {self.name: parsing_ops.VarLenFeature(self.dtype)}
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
+ def transform_feature(self, transformation_cache, state_manager):
+ return transformation_cache.get(self.name, state_manager)
@property
- def _variable_shape(self):
+ def variable_shape(self):
raise ValueError('Should not use this method.')
- def _get_dense_tensor(self, inputs, weight_collections=None,
- trainable=None):
+ def get_dense_tensor(self, transformation_cache, state_manager):
raise ValueError('Should not use this method.')
@property
- def _num_buckets(self):
+ def num_buckets(self):
return 4
- def _get_sparse_tensors(self, inputs, weight_collections=None,
- trainable=None):
+ def get_sparse_tensors(self, transformation_cache, state_manager):
sp_tensor = sparse_tensor.SparseTensor(
indices=[[0, 0], [1, 0], [1, 1]],
values=[2, 0, 3],
dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
+ return fc.CategoricalColumn.IdWeightPair(sp_tensor, None)
dense_and_sparse_column = _DenseAndSparseColumn()
with ops.Graph().as_default():
@@ -1426,10 +1225,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {dense_and_sparse_column.name: sp_tensor}
- predictions = fc.linear_model(features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
+ model = fc.LinearModel([dense_and_sparse_column])
+ predictions = model(features)
+ dense_and_sparse_column_var, bias = model.variables
with _initialized_session() as sess:
sess.run(dense_and_sparse_column_var.assign(
[[10.], [100.], [1000.], [10000.]]))
@@ -1437,12 +1235,12 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[1005.], [10015.]], predictions.eval())
def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((1, 3)), price_var.eval())
@@ -1452,16 +1250,16 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], units=3)
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
@@ -1474,18 +1272,19 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price])
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose([[0.], [0.]], price_var.eval())
sess.run(price_var.assign([[10.], [100.]]))
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = array_ops.sparse_placeholder(dtypes.string)
wire_value = sparse_tensor.SparseTensorValue(
@@ -1493,8 +1292,9 @@ class LinearModelTest(test.TestCase):
indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
dense_shape=[2, 2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast])
+ predictions = model(features)
+ wire_cast_var, _ = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
self.assertAllClose(
@@ -1506,25 +1306,24 @@ class LinearModelTest(test.TestCase):
predictions.eval(feed_dict={wire_tensor: wire_value}))
def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
indices=[[0, 0], [1, 0], [1, 1]],
dense_shape=[2, 2])
features = {'wire_cast': wire_tensor}
- predictions = fc.linear_model(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast], sparse_combiner='mean')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [5010.]], predictions.eval())
def test_sparse_combiner_with_negative_weights(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- wire_cast_weights = fc_old.weighted_categorical_column(wire_cast, 'weights')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast_weights = fc.weighted_categorical_column(wire_cast, 'weights')
with ops.Graph().as_default():
wire_tensor = sparse_tensor.SparseTensor(
@@ -1535,22 +1334,21 @@ class LinearModelTest(test.TestCase):
'wire_cast': wire_tensor,
'weights': constant_op.constant([[1., 1., -1.0]])
}
- predictions = fc.linear_model(
- features, [wire_cast_weights], sparse_combiner='sum')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ model = fc.LinearModel([wire_cast_weights], sparse_combiner='sum')
+ predictions = model(features)
+ wire_cast_var, bias = model.variables
with _initialized_session() as sess:
sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(bias.assign([5.]))
self.assertAllClose([[1005.], [-9985.]], predictions.eval())
def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1., 2.], [5., 6.]]}
- predictions = fc.linear_model(features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price], units=3)
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose(np.zeros((3,)), bias.eval())
self.assertAllClose(np.zeros((2, 3)), price_var.eval())
@@ -1560,21 +1358,22 @@ class LinearModelTest(test.TestCase):
predictions.eval())
def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
+ price = fc.numeric_column('price', shape=2)
with ops.Graph().as_default():
features = {'price': [[1.], [5.]]}
with self.assertRaisesRegexp(
Exception,
r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
+ price = fc.numeric_column('price', shape=[1, 2])
with ops.Graph().as_default():
features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ predictions = model(features)
+ price_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price_var.eval())
@@ -1583,17 +1382,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[210.], [650.]], predictions.eval())
def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1', shape=2)
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1., 2.], [5., 6.]],
'price2': [[3.], [4.]]
}
- predictions = fc.linear_model(features, [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
+ price1_var, price2_var, bias = model.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias.eval())
self.assertAllClose([[0.], [0.]], price1_var.eval())
@@ -1604,115 +1402,55 @@ class LinearModelTest(test.TestCase):
sess.run(bias.assign([7.]))
self.assertAllClose([[3217.], [4657.]], predictions.eval())
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- fc.linear_model(features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- fc.linear_model(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
+ model = fc.LinearModel([price])
+ model(features)
+ price_var, bias = model.variables
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertIn(bias, trainable_vars)
self.assertIn(price_var, trainable_vars)
def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast])
+ model = fc.LinearModel([wire_cast])
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
+ wire_cast_var, bias = model.variables
self.assertIn(bias, trainable_vars)
self.assertIn(wire_cast_var, trainable_vars)
def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default() as g:
features = {'price': [[1.], [5.]]}
- fc.linear_model(features, [price], trainable=False)
+ model = fc.LinearModel([price], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
with ops.Graph().as_default() as g:
wire_tensor = sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
features = {'wire_cast': wire_tensor}
- fc.linear_model(features, [wire_cast], trainable=False)
+ model = fc.LinearModel([wire_cast], trainable=False)
+ model(features)
trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
self.assertEqual([], trainable_vars)
def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
+ price_a = fc.numeric_column('price_a')
+ price_b = fc.numeric_column('price_b')
+ wire_cast = fc.categorical_column_with_hash_bucket('wire_cast', 4)
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1720,15 +1458,15 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([price_a, wire_cast, price_b])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
- with ops.Graph().as_default() as g:
+ with ops.Graph().as_default():
features = {
'price_a': [[1.]],
'price_b': [[3.]],
@@ -1736,17 +1474,45 @@ class LinearModelTest(test.TestCase):
sparse_tensor.SparseTensor(
values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
}
- fc.linear_model(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
+ model = fc.LinearModel([wire_cast, price_b, price_a])
+ model(features)
+
+ my_vars = model.variables
self.assertIn('price_a', my_vars[0].name)
self.assertIn('price_b', my_vars[1].name)
self.assertIn('wire_cast', my_vars[2].name)
+ def test_variable_names(self):
+ price1 = fc.numeric_column('price1')
+ dense_feature = fc.numeric_column('dense_feature')
+ dense_feature_bucketized = fc.bucketized_column(
+ dense_feature, boundaries=[0.])
+ some_sparse_column = fc.categorical_column_with_hash_bucket(
+ 'sparse_feature', hash_bucket_size=5)
+ some_embedding_column = fc.embedding_column(
+ some_sparse_column, dimension=10)
+ all_cols = [price1, dense_feature_bucketized, some_embedding_column]
+
+ with ops.Graph().as_default():
+ model = fc.LinearModel(all_cols)
+ features = {
+ 'price1': [[3.], [4.]],
+ 'dense_feature': [[-1.], [4.]],
+ 'sparse_feature': [['a'], ['x']],
+ }
+ model(features)
+ variable_names = [var.name for var in model.variables]
+ self.assertItemsEqual([
+ 'linear_model/dense_feature_bucketized/weights:0',
+ 'linear_model/price1/weights:0',
+ 'linear_model/sparse_feature_embedding/embedding_weights:0',
+ 'linear_model/sparse_feature_embedding/weights:0',
+ 'linear_model/bias_weights:0',
+ ], variable_names)
+
def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': [[1.], [5.], [7.]], # batchsize = 3
@@ -1755,12 +1521,13 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ model(features)
def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
+ price3 = fc.numeric_column('price3')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
@@ -1770,17 +1537,19 @@ class LinearModelTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- fc.linear_model(features, [price1, price2, price3])
+ model = fc.LinearModel([price1, price2, price3])
+ model(features)
def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
'price2': [[3.], [4.]] # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
with self.assertRaisesRegexp(errors.OpError,
'must have the same size and shape'):
@@ -1788,14 +1557,15 @@ class LinearModelTest(test.TestCase):
predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
+ price1 = fc.numeric_column('price1')
+ price2 = fc.numeric_column('price2')
with ops.Graph().as_default():
features = {
'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
}
- predictions = fc.linear_model(features, [price1, price2])
+ model = fc.LinearModel([price1, price2])
+ predictions = model(features)
with _initialized_session() as sess:
sess.run(
predictions,
@@ -1805,14 +1575,14 @@ class LinearModelTest(test.TestCase):
})
def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
input_fn = numpy_io.numpy_input_fn(
@@ -1823,15 +1593,14 @@ class LinearModelTest(test.TestCase):
batch_size=2,
shuffle=False)
features = input_fn()
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
# self.assertEqual(1 + 3 + 5, net.shape[1])
with _initialized_session() as sess:
coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1843,14 +1612,14 @@ class LinearModelTest(test.TestCase):
coord.join(threads)
def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
# Provides 1-dim tensor and dense tensor.
@@ -1864,11 +1633,10 @@ class LinearModelTest(test.TestCase):
self.assertEqual(1, features['price'].shape.ndims)
self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
- net = fc.linear_model(features, [price_buckets, body_style])
+ model = fc.LinearModel([price_buckets, body_style])
+ net = model(features)
with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ body_style_var, price_buckets_var, bias = model.variables
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1877,16 +1645,16 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
+ price = fc.numeric_column('price')
+ price_buckets = fc.bucketized_column(
price, boundaries=[
0.,
10.,
100.,
])
- body_style = fc_old.categorical_column_with_vocabulary_list(
+ body_style = fc.categorical_column_with_vocabulary_list(
'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
+ country = fc.categorical_column_with_vocabulary_list(
'country', vocabulary_list=['US', 'JP', 'CA'])
# Provides 1-dim tensor and dense tensor.
@@ -1905,10 +1673,9 @@ class LinearModelTest(test.TestCase):
dense_shape=(2,))
country_data = np.array(['US', 'CA'])
- net = fc.linear_model(features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
+ model = fc.LinearModel([price_buckets, body_style, country])
+ net = model(features)
+ body_style_var, _, price_buckets_var, bias = model.variables
with _initialized_session() as sess:
sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
@@ -1924,7 +1691,7 @@ class LinearModelTest(test.TestCase):
}))
def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
features = {
'price': constant_op.constant(0),
}
@@ -1932,29 +1699,31 @@ class LinearModelTest(test.TestCase):
# Static rank 0 should fail
with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ model(features)
# Dynamic rank 0 should fail
features = {
'price': array_ops.placeholder(dtypes.float32),
}
- net = fc.linear_model(features, [price])
+ model = fc.LinearModel([price])
+ net = model(features)
self.assertEqual(1, net.shape[1])
with _initialized_session() as sess:
with self.assertRaisesOpError('Feature .* cannot have rank 0'):
sess.run(net, feed_dict={features['price']: np.array(1)})
def test_multiple_linear_models(self):
- price = fc_old.numeric_column('price')
+ price = fc.numeric_column('price')
with ops.Graph().as_default():
features1 = {'price': [[1.], [5.]]}
features2 = {'price': [[2.], [10.]]}
- predictions1 = fc.linear_model(features1, [price])
- predictions2 = fc.linear_model(features2, [price])
- bias1 = get_linear_model_bias(name='linear_model')
- bias2 = get_linear_model_bias(name='linear_model_1')
- price_var1 = get_linear_model_column_var(price, name='linear_model')
- price_var2 = get_linear_model_column_var(price, name='linear_model_1')
+ model1 = fc.LinearModel([price])
+ model2 = fc.LinearModel([price])
+ predictions1 = model1(features1)
+ predictions2 = model2(features2)
+ price_var1, bias1 = model1.variables
+ price_var2, bias2 = model2.variables
with _initialized_session() as sess:
self.assertAllClose([0.], bias1.eval())
sess.run(price_var1.assign([[10.]]))
@@ -1966,664 +1735,6 @@ class LinearModelTest(test.TestCase):
self.assertAllClose([[25.], [105.]], predictions2.eval())
-class _LinearModelTest(test.TestCase):
-
- def test_raises_if_empty_feature_columns(self):
- with self.assertRaisesRegexp(ValueError,
- 'feature_columns must not be empty'):
- get_keras_linear_model_predictions(features={}, feature_columns=[])
-
- def test_should_be_feature_column(self):
- with self.assertRaisesRegexp(ValueError, 'must be a _FeatureColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns='NotSupported')
-
- def test_should_be_dense_or_categorical_column(self):
-
- class NotSupportedColumn(fc_old._FeatureColumn):
-
- @property
- def name(self):
- return 'NotSupportedColumn'
-
- def _transform_feature(self, cache):
- pass
-
- @property
- def _parse_example_spec(self):
- pass
-
- with self.assertRaisesRegexp(
- ValueError, 'must be either a _DenseColumn or _CategoricalColumn'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]}, feature_columns=[NotSupportedColumn()])
-
- def test_does_not_support_dict_columns(self):
- with self.assertRaisesRegexp(
- ValueError, 'Expected feature_columns to be iterable, found dict.'):
- fc.linear_model(
- features={'a': [[0]]},
- feature_columns={'a': fc_old.numeric_column('a')})
-
- def test_raises_if_duplicate_name(self):
- with self.assertRaisesRegexp(
- ValueError, 'Duplicate feature column name found for columns'):
- get_keras_linear_model_predictions(
- features={'a': [[0]]},
- feature_columns=[
- fc_old.numeric_column('a'),
- fc_old.numeric_column('a')
- ])
-
- def test_dense_bias(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- sess.run(price_var.assign([[10.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[15.], [55.]], predictions.eval())
-
- def test_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.], [0.], [0.]], wire_cast_var.eval())
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_and_sparse_bias(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor, 'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(features,
- [wire_cast, price])
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- sess.run(price_var.assign([[10.]]))
- self.assertAllClose([[1015.], [10065.]], predictions.eval())
-
- def test_dense_and_sparse_column(self):
- """When the column is both dense and sparse, uses sparse tensors."""
-
- class _DenseAndSparseColumn(fc_old._DenseColumn, fc_old._CategoricalColumn):
-
- @property
- def name(self):
- return 'dense_and_sparse_column'
-
- @property
- def _parse_example_spec(self):
- return {self.name: parsing_ops.VarLenFeature(self.dtype)}
-
- def _transform_feature(self, inputs):
- return inputs.get(self.name)
-
- @property
- def _variable_shape(self):
- raise ValueError('Should not use this method.')
-
- def _get_dense_tensor(self,
- inputs,
- weight_collections=None,
- trainable=None):
- raise ValueError('Should not use this method.')
-
- @property
- def _num_buckets(self):
- return 4
-
- def _get_sparse_tensors(self,
- inputs,
- weight_collections=None,
- trainable=None):
- sp_tensor = sparse_tensor.SparseTensor(
- indices=[[0, 0], [1, 0], [1, 1]],
- values=[2, 0, 3],
- dense_shape=[2, 2])
- return fc_old._CategoricalColumn.IdWeightPair(sp_tensor, None)
-
- dense_and_sparse_column = _DenseAndSparseColumn()
- with ops.Graph().as_default():
- sp_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'],
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {dense_and_sparse_column.name: sp_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [dense_and_sparse_column])
- bias = get_linear_model_bias()
- dense_and_sparse_column_var = get_linear_model_column_var(
- dense_and_sparse_column)
- with _initialized_session() as sess:
- sess.run(
- dense_and_sparse_column_var.assign([[10.], [100.], [1000.],
- [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [10015.]], predictions.eval())
-
- def test_dense_multi_output(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((1, 3)), price_var.eval())
- sess.run(price_var.assign([[10., 100., 1000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[15., 106., 1007.], [55., 506., 5007.]],
- predictions.eval())
-
- def test_sparse_multi_output(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], units=3)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((4, 3)), wire_cast_var.eval())
- sess.run(
- wire_cast_var.assign([[10., 11., 12.], [100., 110., 120.],
- [1000., 1100.,
- 1200.], [10000., 11000., 12000.]]))
- sess.run(bias.assign([5., 6., 7.]))
- self.assertAllClose([[1005., 1106., 1207.], [10015., 11017., 12019.]],
- predictions.eval())
-
- def test_dense_multi_dimension(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([[0.], [0.]], price_var.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_sparse_multi_rank(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = array_ops.sparse_placeholder(dtypes.string)
- wire_value = sparse_tensor.SparseTensorValue(
- values=['omar', 'stringer', 'marlo', 'omar'], # hashed = [2, 0, 3, 2]
- indices=[[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 0, 1]],
- dense_shape=[2, 2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(features, [wire_cast])
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((4, 1)), wire_cast_var.eval())
- self.assertAllClose(
- np.zeros((2, 1)),
- predictions.eval(feed_dict={wire_tensor: wire_value}))
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- self.assertAllClose(
- [[1010.], [11000.]],
- predictions.eval(feed_dict={wire_tensor: wire_value}))
-
- def test_sparse_combiner(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default():
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar', 'stringer', 'marlo'], # hashed to = [2, 0, 3]
- indices=[[0, 0], [1, 0], [1, 1]],
- dense_shape=[2, 2])
- features = {'wire_cast': wire_tensor}
- predictions = get_keras_linear_model_predictions(
- features, [wire_cast], sparse_combiner='mean')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- with _initialized_session() as sess:
- sess.run(wire_cast_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(bias.assign([5.]))
- self.assertAllClose([[1005.], [5010.]], predictions.eval())
-
- def test_dense_multi_dimension_multi_output(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1., 2.], [5., 6.]]}
- predictions = get_keras_linear_model_predictions(
- features, [price], units=3)
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose(np.zeros((3,)), bias.eval())
- self.assertAllClose(np.zeros((2, 3)), price_var.eval())
- sess.run(price_var.assign([[1., 2., 3.], [10., 100., 1000.]]))
- sess.run(bias.assign([2., 3., 4.]))
- self.assertAllClose([[23., 205., 2007.], [67., 613., 6019.]],
- predictions.eval())
-
- def test_raises_if_shape_mismatch(self):
- price = fc_old.numeric_column('price', shape=2)
- with ops.Graph().as_default():
- features = {'price': [[1.], [5.]]}
- with self.assertRaisesRegexp(
- Exception,
- r'Cannot reshape a tensor with 2 elements to shape \[2,2\]'):
- get_keras_linear_model_predictions(features, [price])
-
- def test_dense_reshaping(self):
- price = fc_old.numeric_column('price', shape=[1, 2])
- with ops.Graph().as_default():
- features = {'price': [[[1., 2.]], [[5., 6.]]]}
- predictions = get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price_var.assign([[10.], [100.]]))
- self.assertAllClose([[210.], [650.]], predictions.eval())
-
- def test_dense_multi_column(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- with _initialized_session() as sess:
- self.assertAllClose([0.], bias.eval())
- self.assertAllClose([[0.], [0.]], price1_var.eval())
- self.assertAllClose([[0.]], price2_var.eval())
- self.assertAllClose([[0.], [0.]], predictions.eval())
- sess.run(price1_var.assign([[10.], [100.]]))
- sess.run(price2_var.assign([[1000.]]))
- sess.run(bias.assign([7.]))
- self.assertAllClose([[3217.], [4657.]], predictions.eval())
-
- def test_fills_cols_to_vars(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {'price1': [[1., 2.], [5., 6.]], 'price2': [[3.], [4.]]}
- cols_to_vars = {}
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- bias = get_linear_model_bias()
- price1_var = get_linear_model_column_var(price1)
- price2_var = get_linear_model_column_var(price2)
- self.assertAllEqual(cols_to_vars['bias'], [bias])
- self.assertAllEqual(cols_to_vars[price1], [price1_var])
- self.assertAllEqual(cols_to_vars[price2], [price2_var])
-
- def test_fills_cols_to_vars_partitioned_variables(self):
- price1 = fc_old.numeric_column('price1', shape=2)
- price2 = fc_old.numeric_column('price2', shape=3)
- with ops.Graph().as_default():
- features = {
- 'price1': [[1., 2.], [6., 7.]],
- 'price2': [[3., 4., 5.], [8., 9., 10.]]
- }
- cols_to_vars = {}
- with variable_scope.variable_scope(
- 'linear',
- partitioner=partitioned_variables.fixed_size_partitioner(2, axis=0)):
- get_keras_linear_model_predictions(
- features, [price1, price2], cols_to_vars=cols_to_vars)
- with _initialized_session():
- self.assertEqual([0.], cols_to_vars['bias'][0].eval())
- # Partitioning shards the [2, 1] price1 var into 2 [1, 1] Variables.
- self.assertAllEqual([[0.]], cols_to_vars[price1][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price1][1].eval())
- # Partitioning shards the [3, 1] price2 var into a [2, 1] Variable and
- # a [1, 1] Variable.
- self.assertAllEqual([[0.], [0.]], cols_to_vars[price2][0].eval())
- self.assertAllEqual([[0.]], cols_to_vars[price2][1].eval())
-
- def test_dense_collection(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(
- features, [price], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- self.assertIn(bias, my_vars)
- self.assertIn(price_var, my_vars)
-
- def test_sparse_collection(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(
- features, [wire_cast], weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, my_vars)
- self.assertIn(wire_cast_var, my_vars)
-
- def test_dense_trainable_default(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price])
- bias = get_linear_model_bias()
- price_var = get_linear_model_column_var(price)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertIn(bias, trainable_vars)
- self.assertIn(price_var, trainable_vars)
-
- def test_sparse_trainable_default(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast])
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- bias = get_linear_model_bias()
- wire_cast_var = get_linear_model_column_var(wire_cast)
- self.assertIn(bias, trainable_vars)
- self.assertIn(wire_cast_var, trainable_vars)
-
- def test_dense_trainable_false(self):
- price = fc_old.numeric_column('price')
- with ops.Graph().as_default() as g:
- features = {'price': [[1.], [5.]]}
- get_keras_linear_model_predictions(features, [price], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_sparse_trainable_false(self):
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- wire_tensor = sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- features = {'wire_cast': wire_tensor}
- get_keras_linear_model_predictions(features, [wire_cast], trainable=False)
- trainable_vars = g.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- self.assertEqual([], trainable_vars)
-
- def test_column_order(self):
- price_a = fc_old.numeric_column('price_a')
- price_b = fc_old.numeric_column('price_b')
- wire_cast = fc_old.categorical_column_with_hash_bucket('wire_cast', 4)
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [price_a, wire_cast, price_b],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- with ops.Graph().as_default() as g:
- features = {
- 'price_a': [[1.]],
- 'price_b': [[3.]],
- 'wire_cast':
- sparse_tensor.SparseTensor(
- values=['omar'], indices=[[0, 0]], dense_shape=[1, 1])
- }
- get_keras_linear_model_predictions(
- features, [wire_cast, price_b, price_a],
- weight_collections=['my-vars'])
- my_vars = g.get_collection('my-vars')
- self.assertIn('price_a', my_vars[0].name)
- self.assertIn('price_b', my_vars[1].name)
- self.assertIn('wire_cast', my_vars[2].name)
-
- def test_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': [[1.], [5.], [7.]], # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2])
-
- def test_subset_of_static_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- price3 = fc_old.numeric_column('price3')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]], # batchsize = 2
- 'price3': [[3.], [4.], [5.]] # batchsize = 3
- }
- with self.assertRaisesRegexp(
- ValueError,
- 'Batch size \(first dimension\) of each feature must be same.'): # pylint: disable=anomalous-backslash-in-string
- get_keras_linear_model_predictions(features, [price1, price2, price3])
-
- def test_runtime_batch_size_mismatch(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 3
- 'price2': [[3.], [4.]] # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- with self.assertRaisesRegexp(errors.OpError,
- 'must have the same size and shape'):
- sess.run(
- predictions, feed_dict={features['price1']: [[1.], [5.], [7.]]})
-
- def test_runtime_batch_size_matches(self):
- price1 = fc_old.numeric_column('price1')
- price2 = fc_old.numeric_column('price2')
- with ops.Graph().as_default():
- features = {
- 'price1': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- 'price2': array_ops.placeholder(dtype=dtypes.int64), # batchsize = 2
- }
- predictions = get_keras_linear_model_predictions(features,
- [price1, price2])
- with _initialized_session() as sess:
- sess.run(
- predictions,
- feed_dict={
- features['price1']: [[1.], [5.]],
- features['price2']: [[1.], [5.]],
- })
-
- def test_with_numpy_input_fn(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- input_fn = numpy_io.numpy_input_fn(
- x={
- 'price': np.array([-1., 2., 13., 104.]),
- 'body-style': np.array(['sedan', 'hardtop', 'wagon', 'sedan']),
- },
- batch_size=2,
- shuffle=False)
- features = input_fn()
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- # self.assertEqual(1 + 3 + 5, net.shape[1])
- with _initialized_session() as sess:
- coord = coordinator.Coordinator()
- threads = queue_runner_impl.start_queue_runners(sess, coord=coord)
-
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [100 - 10 + 5.]], sess.run(net))
-
- coord.request_stop()
- coord.join(threads)
-
- def test_with_1d_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price':
- constant_op.constant([
- -1.,
- 12.,
- ]),
- 'body-style':
- sparse_tensor.SparseTensor(
- indices=((0,), (1,)),
- values=('sedan', 'hardtop'),
- dense_shape=(2,)),
- }
- self.assertEqual(1, features['price'].shape.ndims)
- self.assertEqual(1, features['body-style'].dense_shape.get_shape()[0])
-
- net = get_keras_linear_model_predictions(features,
- [price_buckets, body_style])
- with _initialized_session() as sess:
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
-
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]], sess.run(net))
-
- def test_with_1d_unknown_shape_sparse_tensor(self):
- price = fc_old.numeric_column('price')
- price_buckets = fc_old.bucketized_column(
- price, boundaries=[
- 0.,
- 10.,
- 100.,
- ])
- body_style = fc_old.categorical_column_with_vocabulary_list(
- 'body-style', vocabulary_list=['hardtop', 'wagon', 'sedan'])
- country = fc_old.categorical_column_with_vocabulary_list(
- 'country', vocabulary_list=['US', 'JP', 'CA'])
-
- # Provides 1-dim tensor and dense tensor.
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- 'body-style': array_ops.sparse_placeholder(dtypes.string),
- 'country': array_ops.placeholder(dtypes.string),
- }
- self.assertIsNone(features['price'].shape.ndims)
- self.assertIsNone(features['body-style'].get_shape().ndims)
-
- price_data = np.array([-1., 12.])
- body_style_data = sparse_tensor.SparseTensorValue(
- indices=((0,), (1,)), values=('sedan', 'hardtop'), dense_shape=(2,))
- country_data = np.array(['US', 'CA'])
-
- net = get_keras_linear_model_predictions(
- features, [price_buckets, body_style, country])
- bias = get_linear_model_bias()
- price_buckets_var = get_linear_model_column_var(price_buckets)
- body_style_var = get_linear_model_column_var(body_style)
- with _initialized_session() as sess:
- sess.run(price_buckets_var.assign([[10.], [100.], [1000.], [10000.]]))
- sess.run(body_style_var.assign([[-10.], [-100.], [-1000.]]))
- sess.run(bias.assign([5.]))
-
- self.assertAllClose([[10 - 1000 + 5.], [1000 - 10 + 5.]],
- sess.run(
- net,
- feed_dict={
- features['price']: price_data,
- features['body-style']: body_style_data,
- features['country']: country_data
- }))
-
- def test_with_rank_0_feature(self):
- price = fc_old.numeric_column('price')
- features = {
- 'price': constant_op.constant(0),
- }
- self.assertEqual(0, features['price'].shape.ndims)
-
- # Static rank 0 should fail
- with self.assertRaisesRegexp(ValueError, 'Feature .* cannot have rank 0'):
- get_keras_linear_model_predictions(features, [price])
-
- # Dynamic rank 0 should fail
- features = {
- 'price': array_ops.placeholder(dtypes.float32),
- }
- net = get_keras_linear_model_predictions(features, [price])
- self.assertEqual(1, net.shape[1])
- with _initialized_session() as sess:
- with self.assertRaisesOpError('Feature .* cannot have rank 0'):
- sess.run(net, feed_dict={features['price']: np.array(1)})
-
-
class FeatureLayerTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
@@ -3723,47 +2834,22 @@ class VocabularyFileCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
- key='wire',
- vocabulary_file=self._wire_vocabulary_file_name,
- vocabulary_size=self._wire_vocabulary_size,
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_file(
+ wire_column = fc.categorical_column_with_vocabulary_file(
key='wire',
vocabulary_file=self._wire_vocabulary_file_name,
vocabulary_size=self._wire_vocabulary_size,
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4124,45 +3210,21 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
id_weight_pair.id_tensor.eval())
def test_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
- key='aaa',
- vocabulary_list=('omar', 'stringer', 'marlo'),
- num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- wire_column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=('marlo', 'skywalker', 'omar'),
- dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- wire_var.assign(((1.,), (2.,), (3.,), (4.,))).eval()
- # 'marlo' -> 2: wire_var[2] = 3
- # 'skywalker' -> 3, 'omar' -> 0: wire_var[3] + wire_var[0] = 4+1 = 5
- self.assertAllClose(((3.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- wire_column = fc_old.categorical_column_with_vocabulary_list(
+ wire_column = fc.categorical_column_with_vocabulary_list(
key='aaa',
vocabulary_list=('omar', 'stringer', 'marlo'),
num_oov_buckets=1)
- self.assertEqual(4, wire_column._num_buckets)
+ self.assertEqual(4, wire_column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((wire_column,))
+ predictions = model({
wire_column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=('marlo', 'skywalker', 'omar'),
dense_shape=(2, 2))
- }, (wire_column,))
- bias = get_linear_model_bias()
- wire_var = get_linear_model_column_var(wire_column)
+ })
+ wire_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,), (0.,)), wire_var.eval())
@@ -4382,39 +3444,18 @@ class IdentityCategoricalColumnTest(test.TestCase):
}))
def test_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
- self.assertEqual(3, column.num_buckets)
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- column.name: sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] = 1
- # weight_var[2] + weight_var[1] = 3+2 = 5
- self.assertAllClose(((1.,), (5.,)), predictions.eval())
-
- def test_keras_linear_model(self):
- column = fc_old.categorical_column_with_identity(key='aaa', num_buckets=3)
+ column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
self.assertEqual(3, column.num_buckets)
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
column.name:
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -4640,27 +3681,8 @@ class IndicatorColumnTest(test.TestCase):
self.assertAllEqual([[0., 1., 1.]], indicator_tensor.eval())
def test_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
- with ops.Graph().as_default():
- features = {
- 'animal':
- sparse_tensor.SparseTensor(
- indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
- }
-
- predictions = fc.linear_model(features, [animal])
- weight_var = get_linear_model_column_var(animal)
- with _initialized_session():
- # All should be zero-initialized.
- self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
- self.assertAllClose([[0.]], predictions.eval())
- weight_var.assign([[1.], [2.], [3.], [4.]]).eval()
- self.assertAllClose([[2. + 3.]], predictions.eval())
-
- def test_keras_linear_model(self):
- animal = fc_old.indicator_column(
- fc_old.categorical_column_with_identity('animal', num_buckets=4))
+ animal = fc.indicator_column(
+ fc.categorical_column_with_identity('animal', num_buckets=4))
with ops.Graph().as_default():
features = {
'animal':
@@ -4668,8 +3690,9 @@ class IndicatorColumnTest(test.TestCase):
indices=[[0, 0], [0, 1]], values=[1, 2], dense_shape=[1, 2])
}
- predictions = get_keras_linear_model_predictions(features, [animal])
- weight_var = get_linear_model_column_var(animal)
+ model = fc.LinearModel([animal])
+ predictions = model(features)
+ weight_var, _ = model.variables
with _initialized_session():
# All should be zero-initialized.
self.assertAllClose([[0.], [0.], [0.], [0.]], weight_var.eval())
@@ -5121,17 +4144,16 @@ class EmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
+ categorical_column = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
+ embedding_column = fc.embedding_column(
categorical_column,
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
- categorical_column.name: sparse_input
- }, (embedding_column,))
+ model = fc.LinearModel((embedding_column,))
+ predictions = model({categorical_column.name: sparse_input})
expected_var_names = (
'linear_model/bias_weights:0',
'linear_model/aaa_embedding/weights:0',
@@ -5173,82 +4195,6 @@ class EmbeddingColumnTest(test.TestCase):
# = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 4
- vocabulary_size = 3
- sparse_input = sparse_tensor.SparseTensorValue(
- # example 0, ids [2]
- # example 1, ids [0, 1]
- # example 2, ids []
- # example 3, ids [1]
- indices=((0, 0), (1, 0), (1, 4), (3, 0)),
- values=(2, 0, 1, 1),
- dense_shape=(batch_size, 5))
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- embedding_column = fc_old.embedding_column(
- categorical_column,
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column.name: sparse_input
- }, (embedding_column,))
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_embedding/weights:0',
- 'linear_model/aaa_embedding/embedding_weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_embedding/embedding_weights:0']
- linear_weights = trainable_vars['linear_model/aaa_embedding/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # example 2, ids [], embedding[2] = [0, 0]
- # example 3, ids [1], embedding[3] = [3, 5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5, 4*0 + 6*0, 4*3 + 6*5] = [94, 29, 0, 42]
- self.assertAllClose(((94.,), (29.,), (0.,), (42.,)), predictions.eval())
-
def test_feature_layer(self):
# Inputs.
vocabulary_size = 3
@@ -5749,27 +4695,31 @@ class SharedEmbeddingColumnTest(test.TestCase):
return zeros_embedding_values
# Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
+ categorical_column_a = fc.categorical_column_with_identity(
key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
+ categorical_column_b = fc.categorical_column_with_identity(
key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
+ embedding_column_a, embedding_column_b = fc.shared_embedding_columns_v2(
[categorical_column_a, categorical_column_b],
dimension=embedding_dimension,
initializer=_initializer)
with ops.Graph().as_default():
- predictions = fc.linear_model({
+ model = fc.LinearModel(
+ (embedding_column_a, embedding_column_b),
+ shared_state_manager=fc.SharedEmbeddingStateManager())
+ predictions = model({
categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
+ categorical_column_b.name: input_b
+ })
+
# Linear weights do not follow the column name. But this is a rare use
# case, and fixing it would add too much complexity to the code.
expected_var_names = (
'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
+ 'linear_model/aaa_shared_embedding/weights:0',
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0',
+ 'linear_model/bbb_shared_embedding/weights:0',
)
self.assertItemsEqual(
expected_var_names,
@@ -5781,102 +4731,11 @@ class SharedEmbeddingColumnTest(test.TestCase):
self.assertItemsEqual(expected_var_names, trainable_vars.keys())
bias = trainable_vars['linear_model/bias_weights:0']
embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
- linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
- linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
- with _initialized_session():
- # Predictions with all zero weights.
- self.assertAllClose(np.zeros((1,)), bias.eval())
- self.assertAllClose(zeros_embedding_values, embedding_weights.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_a.eval())
- self.assertAllClose(
- np.zeros((embedding_dimension, 1)), linear_weights_b.eval())
- self.assertAllClose(np.zeros((batch_size, 1)), predictions.eval())
-
- # Predictions with all non-zero weights.
- embedding_weights.assign((
- (1., 2.), # id 0
- (3., 5.), # id 1
- (7., 11.) # id 2
- )).eval()
- linear_weights_a.assign(((4.,), (6.,))).eval()
- # example 0, ids [2], embedding[0] = [7, 11]
- # example 1, ids [0, 1], embedding[1] = mean([1, 2] + [3, 5]) = [2, 3.5]
- # sum(embeddings * linear_weights)
- # = [4*7 + 6*11, 4*2 + 6*3.5] = [94, 29]
- linear_weights_b.assign(((3.,), (5.,))).eval()
- # example 0, ids [0], embedding[0] = [1, 2]
- # example 1, ids [], embedding[1] = 0, 0]
- # sum(embeddings * linear_weights)
- # = [3*1 + 5*2, 3*0 +5*0] = [13, 0]
- self.assertAllClose([[94. + 13.], [29.]], predictions.eval())
-
- def test_keras_linear_model(self):
- # Inputs.
- batch_size = 2
- vocabulary_size = 3
- # -1 values are ignored.
- input_a = np.array([
- [2, -1, -1], # example 0, ids [2]
- [0, 1, -1]
- ]) # example 1, ids [0, 1]
- input_b = np.array([
- [0, -1, -1], # example 0, ids [0]
- [-1, -1, -1]
- ]) # example 1, ids []
-
- # Embedding variable.
- embedding_dimension = 2
- embedding_shape = (vocabulary_size, embedding_dimension)
- zeros_embedding_values = np.zeros(embedding_shape)
-
- def _initializer(shape, dtype, partition_info):
- self.assertAllEqual(embedding_shape, shape)
- self.assertEqual(dtypes.float32, dtype)
- self.assertIsNone(partition_info)
- return zeros_embedding_values
-
- # Build columns.
- categorical_column_a = fc_old.categorical_column_with_identity(
- key='aaa', num_buckets=vocabulary_size)
- categorical_column_b = fc_old.categorical_column_with_identity(
- key='bbb', num_buckets=vocabulary_size)
- embedding_column_a, embedding_column_b = fc_old.shared_embedding_columns(
- [categorical_column_a, categorical_column_b],
- dimension=embedding_dimension,
- initializer=_initializer)
-
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
- categorical_column_a.name: input_a,
- categorical_column_b.name: input_b,
- }, (embedding_column_a, embedding_column_b))
- # Linear weights do not follow the column name. But this is a rare use
- # case, and fixing it would add too much complexity to the code.
- expected_var_names = (
- 'linear_model/bias_weights:0',
- 'linear_model/aaa_bbb_shared_embedding/weights:0',
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0',
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0',
- )
- self.assertItemsEqual(
- expected_var_names,
- [v.name for v in ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)])
- trainable_vars = {
- v.name: v
- for v in ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
- }
- self.assertItemsEqual(expected_var_names, trainable_vars.keys())
- bias = trainable_vars['linear_model/bias_weights:0']
- embedding_weights = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/embedding_weights:0']
+ 'shared_embedding_state_manager/aaa_bbb_shared_embedding:0']
linear_weights_a = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding/weights:0']
+ 'linear_model/aaa_shared_embedding/weights:0']
linear_weights_b = trainable_vars[
- 'linear_model/aaa_bbb_shared_embedding_1/weights:0']
+ 'linear_model/bbb_shared_embedding/weights:0']
with _initialized_session():
# Predictions with all zero weights.
self.assertAllClose(np.zeros((1,)), bias.eval())
@@ -6275,13 +5134,14 @@ class WeightedCategoricalColumnTest(test.TestCase):
dense_shape=(2, 2)),
weight_tensor.eval())
- def test_keras_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6292,9 +5152,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (1, 0), (1, 1)),
values=(.5, 1., .1),
dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
@@ -6305,15 +5164,16 @@ class WeightedCategoricalColumnTest(test.TestCase):
# = 3*1 + 2*.1 = 3+.2 = 3.2
self.assertAllClose(((.5,), (3.2,)), predictions.eval())
- def test_keras_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_shape(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
with self.assertRaisesRegexp(ValueError,
r'Dimensions.*are not compatible'):
- get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,))
+ model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
@@ -6324,122 +5184,23 @@ class WeightedCategoricalColumnTest(test.TestCase):
indices=((0, 0), (0, 1), (1, 0), (1, 1)),
values=(.5, 11., 1., .1),
dense_shape=(2, 2))
- }, (column,))
+ })
- def test_keras_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
- # Disabling the constant folding optimizer here since it changes the
- # error message differently on CPU and GPU.
- config = config_pb2.ConfigProto()
- config.graph_options.rewrite_options.constant_folding = (
- rewriter_config_pb2.RewriterConfig.OFF)
- with _initialized_session(config):
- with self.assertRaisesRegexp(errors.OpError, 'Incompatible shapes'):
- predictions.eval()
-
- def test_keras_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ def test_linear_model_mismatched_dense_values(self):
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = get_keras_linear_model_predictions({
+ model = fc.LinearModel((column,), sparse_combiner='mean')
+ predictions = model({
'ids':
sparse_tensor.SparseTensorValue(
indices=((0, 0), (1, 0), (1, 1)),
values=(0, 2, 1),
dense_shape=(2, 2)),
- 'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(.5, 1., .1),
- dense_shape=(2, 2))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
- with _initialized_session():
- self.assertAllClose((0.,), bias.eval())
- self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
- self.assertAllClose(((0.,), (0.,)), predictions.eval())
- weight_var.assign(((1.,), (2.,), (3.,))).eval()
- # weight_var[0] * weights[0, 0] = 1 * .5 = .5
- # weight_var[2] * weights[1, 0] + weight_var[1] * weights[1, 1]
- # = 3*1 + 2*.1 = 3+.2 = 3.2
- self.assertAllClose(((.5,), (3.2,)), predictions.eval())
-
- def test_linear_model_mismatched_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
- fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (0, 1), (1, 0), (1, 1)),
- values=(.5, 11., 1., .1),
- dense_shape=(2, 2))
- }, (column,))
-
- def test_linear_model_mismatched_dense_values(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
- key='ids', num_buckets=3),
- weight_feature_key='values')
- with ops.Graph().as_default():
- predictions = fc.linear_model(
- {
- 'ids':
- sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
- 'values': ((.5,), (1.,))
- }, (column,),
- sparse_combiner='mean')
+ 'values': ((.5,), (1.,))
+ })
# Disabling the constant folding optimizer here since it changes the
# error message differently on CPU and GPU.
config = config_pb2.ConfigProto()
@@ -6450,20 +5211,21 @@ class WeightedCategoricalColumnTest(test.TestCase):
predictions.eval()
def test_linear_model_mismatched_dense_shape(self):
- column = fc_old.weighted_categorical_column(
- categorical_column=fc_old.categorical_column_with_identity(
+ column = fc.weighted_categorical_column(
+ categorical_column=fc.categorical_column_with_identity(
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- predictions = fc.linear_model({
- 'ids': sparse_tensor.SparseTensorValue(
- indices=((0, 0), (1, 0), (1, 1)),
- values=(0, 2, 1),
- dense_shape=(2, 2)),
+ model = fc.LinearModel((column,))
+ predictions = model({
+ 'ids':
+ sparse_tensor.SparseTensorValue(
+ indices=((0, 0), (1, 0), (1, 1)),
+ values=(0, 2, 1),
+ dense_shape=(2, 2)),
'values': ((.5,), (1.,), (.1,))
- }, (column,))
- bias = get_linear_model_bias()
- weight_var = get_linear_model_column_var(column)
+ })
+ weight_var, bias = model.variables
with _initialized_session():
self.assertAllClose((0.,), bias.eval())
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 06c653097a..7f6e0a75a5 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -87,6 +87,7 @@ class DeviceSpec(object):
else:
self.device_type = device_type
self.device_index = device_index
+ self._hash = hash(self.to_string())
def _clear(self):
self._job = None
@@ -234,7 +235,7 @@ class DeviceSpec(object):
return self.to_string() == other.to_string()
def __hash__(self):
- return hash(self.to_string())
+ return self._hash
def check_valid(spec):
@@ -266,6 +267,7 @@ def canonical_name(device):
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
+_cached_device_specs = {}
_cache_lock = threading.Lock()
@@ -297,7 +299,13 @@ def merge_device(spec):
"""
with _cache_lock:
if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
+ cached_device_spec = _cached_device_specs.get(spec, None)
+ if cached_device_spec is None:
+ device_spec = DeviceSpec.from_string(spec or "")
+ _cached_device_specs[spec] = device_spec
+ spec = device_spec
+ else:
+ spec = cached_device_spec
cached_function = _cached_device_functions.get(spec, None)
if cached_function is not None:
return cached_function
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index c3f70df7d8..64d3b42d89 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -26,7 +26,7 @@ from tensorflow.python.util.tf_export import tf_export
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
-@tf_export("DType")
+@tf_export("dtypes.DType", "DType")
class DType(object):
"""Represents the type of the elements in a `Tensor`.
@@ -658,7 +658,7 @@ _PYTHON_TO_TF = {
}
-@tf_export("as_dtype")
+@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 5af71f2cfb..8b303fa8a9 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -25,11 +25,13 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
-@tf_export("OpError", "errors.OpError")
+@tf_export("errors.OpError", "OpError")
+@deprecation.deprecated_endpoints("OpError")
class OpError(Exception):
"""A generic error that is raised when TensorFlow execution fails.
@@ -72,7 +74,7 @@ class OpError(Exception):
or `Recv` op, there will be no corresponding
`tf.Operation`
object. In that case, this will return `None`, and you should
- instead use the `tf.OpError.node_def` to
+ instead use the `tf.errors.OpError.node_def` to
discover information about the op.
Returns:
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f287289bd0..225208944e 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -134,7 +134,7 @@ class Defun(object):
# Func should not use kwargs and defaults.
argspec = tf_inspect.getargspec(func)
if argspec.keywords or argspec.defaults:
- raise ValueError("Functions with argument defaults or keyword "
+ raise ValueError("Functions with argument defaults or keywords "
"arguments are not supported.")
# Computes how many arguments 'func' has.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index f740e5cfaa..87f567db0e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -113,7 +113,7 @@ class FunctionTest(test.TestCase):
return a
with ops.Graph().as_default():
- var = variables.Variable([18.0])
+ var = variables.VariableV1([18.0])
call = MyIdentityFunc(var._ref()) # pylint: disable=protected-access
self.assertEqual("MyIdentity", call.op.name)
for cfg in _OptimizerOptions():
diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py
index be30b16f5f..47e1344eae 100644
--- a/tensorflow/python/framework/graph_io.py
+++ b/tensorflow/python/framework/graph_io.py
@@ -27,7 +27,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import tf_export
-@tf_export('train.write_graph')
+@tf_export('io.write_graph', 'train.write_graph')
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
"""Writes a graph proto to a file.
diff --git a/tensorflow/python/framework/graph_util_test.py b/tensorflow/python/framework/graph_util_test.py
index 2dafb94ba7..563a177dd0 100644
--- a/tensorflow/python/framework/graph_util_test.py
+++ b/tensorflow/python/framework/graph_util_test.py
@@ -104,13 +104,13 @@ class DeviceFunctionsTest(test.TestCase):
def testNestedDeviceFunctions(self):
with ops.Graph().as_default():
- var_0 = variables.Variable(0)
+ var_0 = variables.VariableV1(0)
with ops.device(test_device_func_pin_variable_to_cpu):
- var_1 = variables.Variable(1)
+ var_1 = variables.VariableV1(1)
with ops.device(lambda op: "/device:GPU:0"):
- var_2 = variables.Variable(2)
+ var_2 = variables.VariableV1(2)
with ops.device("/device:GPU:0"): # Implicit merging device function.
- var_3 = variables.Variable(3)
+ var_3 = variables.VariableV1(3)
self.assertDeviceEqual(var_0.device, None)
self.assertDeviceEqual(var_1.device, "/device:CPU:0")
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index e48e67c8a1..c9ac27e788 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -329,7 +329,7 @@ def _SetDefaultAttrValues(node_def, op_def):
node_def.attr[key].CopyFrom(attr_def.default_value)
-@tf_export('import_graph_def')
+@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
'https://github.com/tensorflow/tensorflow/issues if you depend'
' on this feature.', 'op_dict')
@@ -370,7 +370,8 @@ def import_graph_def(graph_def,
Returns:
A list of `Operation` and/or `Tensor` objects from the imported graph,
- corresponding to the names in `return_elements`.
+ corresponding to the names in `return_elements`,
+ and None if `returns_elements` is None.
Raises:
TypeError: If `graph_def` is not a `GraphDef` proto,
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8bb177939e..77c2bc930e 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -4140,10 +4140,7 @@ class Graph(object):
if op is None and not ignore_existing:
raise ValueError("Trying to reset colocation (op is None) but "
"ignore_existing is not True")
-
- if op is not None and not isinstance(op, Operation):
- # We always want to colocate with the reference op.
- op = internal_convert_to_tensor_or_indexed_slices(op, as_ref=True).op
+ op = _op_to_colocate_with(op)
# By default, colocate_with resets the device function stack,
# since colocate_with is typically used in specific internal
@@ -6168,4 +6165,27 @@ def _operation_conversion_error(op, dtype=None, name=None, as_ref=False):
name, as_ref))
+def _op_to_colocate_with(v):
+ """Operation object corresponding to v to use for colocation constraints."""
+ if v is None:
+ return None
+ if isinstance(v, Operation):
+ return v
+ # We always want to colocate with the reference op.
+ # When 'v' is a ResourceVariable, the reference op is the handle creating op.
+ #
+ # What this should be is:
+ # if isinstance(v, ResourceVariable):
+ # return v.handle.op
+ # However, that would require a circular import dependency.
+ # As of October 2018, there were attempts underway to remove
+ # colocation constraints altogether. Assuming that will
+ # happen soon, perhaps this hack to work around the circular
+ # import dependency is acceptable.
+ if hasattr(v, "handle") and hasattr(v.handle, "op") and isinstance(
+ v.handle.op, Operation):
+ return v.handle.op
+ return internal_convert_to_tensor_or_indexed_slices(v, as_ref=True).op
+
+
register_tensor_conversion_function(Operation, _operation_conversion_error)
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 2f9504889a..6f9f347a99 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -33,7 +34,8 @@ def _truncate_seed(seed):
return seed % _MAXINT32 # Truncate to fit into 32-bit integer
-@tf_export('get_seed')
+@tf_export('random.get_seed', 'get_seed')
+@deprecation.deprecated_endpoints('get_seed')
def get_seed(op_seed):
"""Returns the local seeds an operation should use given an op-specific seed.
@@ -80,7 +82,7 @@ def get_seed(op_seed):
return seeds
-@tf_export('set_random_seed')
+@tf_export('random.set_random_seed', 'set_random_seed')
def set_random_seed(seed):
"""Sets the graph-level random seed.
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index d1bdd9b80a..440e3a0968 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -33,7 +33,7 @@ _override_helper = ops._override_helper
# pylint: enable=protected-access
-@tf_export("SparseTensor")
+@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike):
"""Represents a sparse tensor.
@@ -245,7 +245,7 @@ class SparseTensor(_TensorLike):
SparseTensorValue = collections.namedtuple(
"SparseTensorValue", ["indices", "values", "dense_shape"])
tf_export("SparseTensorValue")(SparseTensorValue)
-pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue)
+pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("convert_to_tensor_or_sparse_tensor")
diff --git a/tensorflow/python/framework/subscribe_test.py b/tensorflow/python/framework/subscribe_test.py
index 1d594e4078..cab426844d 100644
--- a/tensorflow/python/framework/subscribe_test.py
+++ b/tensorflow/python/framework/subscribe_test.py
@@ -212,8 +212,8 @@ class SubscribeTest(test_util.TensorFlowTestCase):
def testSubscribeVariable(self):
"""Confirm that variables can be subscribed."""
- v1 = variables.Variable(0.0)
- v2 = variables.Variable(4.0)
+ v1 = variables.VariableV1(0.0)
+ v2 = variables.VariableV1(4.0)
add = math_ops.add(v1, v2)
assign_v1 = v1.assign(3.0)
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index cd0b03be43..95925bb471 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,8 +24,8 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
-import os
import math
+import os
import random
import re
import tempfile
@@ -402,11 +402,14 @@ def with_c_shapes(cls):
return cls
-def enable_cond_v2(fn):
- """Decorator for enabling CondV2 on a test.
+def enable_control_flow_v2(fn):
+ """Decorator for enabling CondV2 and WhileV2 on a test.
- Note this enables using CondV2 after running the test class's setup/teardown
- methods.
+ Note this enables using CondV2 and WhileV2 after running the test class's
+ setup/teardown methods.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
Args:
fn: the function to be wrapped
@@ -416,21 +419,56 @@ def enable_cond_v2(fn):
"""
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops.ENABLE_COND_V2
+ enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
+ enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
+ control_flow_ops.ENABLE_WHILE_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops.ENABLE_COND_V2 = prev_value
+ control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
+ control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
return wrapper
-def with_cond_v2(cls):
- """Adds methods that call original methods but with CondV2 enabled.
+def with_control_flow_v2(cls):
+ """Adds methods that call original methods with WhileV2 and CondV2 enabled.
- Note this enables CondV2 in new methods after running the test class's
- setup method.
+ Note this enables CondV2 and WhileV2 in new methods after running the test
+ class's setup method.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
+
+ If a test function has _disable_control_flow_v2 attr set to True (using the
+ @disable_control_flow_v2 decorator), the v2 function is not generated for it.
+
+ Example:
+
+ @test_util.with_control_flow_v2
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ @test_util.disable_control_flow_v2("b/xyzabc")
+ def testDisabledForV2(self):
+ ...
+
+ Generated class:
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ def testEnabledForV2WithControlFlowV2(self):
+ // Enable V2 flags.
+ testEnabledForV2(self)
+ // Restore V2 flags.
+
+ def testDisabledForV2(self):
+ ...
Args:
cls: class to decorate
@@ -438,21 +476,39 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops.ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ if (callable(value) and name.startswith("test") and
+ not getattr(value, "_disable_control_flow_v2", False)):
+ setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
return cls
+def disable_control_flow_v2(unused_msg):
+ """Decorator for a function in a with_control_flow_v2 enabled test class.
+
+ Blocks the function from being run with v2 control flow ops.
+
+ Args:
+ unused_msg: Reason for disabling.
+
+ Returns:
+ The wrapped function with _disable_control_flow_v2 attr set to True.
+ """
+ def wrapper(func):
+ func._disable_control_flow_v2 = True
+ return func
+ return wrapper
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
- Runs the test multiple times executing eagerly, first as a warmup and then
- several times to let objects accumulate. The warmup helps ignore caches which
- do not grow as the test is run repeatedly.
+ Runs the test multiple times executing eagerly, first as a warmup and then to
+ let objects accumulate. The warmup helps ignore caches which do not grow as
+ the test is run repeatedly.
Useful for checking that there are no missing Py_DECREFs in the C exercised by
a bit of Python.
@@ -462,7 +518,14 @@ def assert_no_new_pyobjects_executing_eagerly(f):
"""Warms up, gets an object count, runs the test, checks for new objects."""
with context.eager_mode():
gc.disable()
- f(self, **kwargs)
+ # Run the test 2 times as warmup, in an attempt to fill up caches, which
+ # should not grow as the test is run repeatedly below.
+ #
+ # TODO(b/117156879): Running warmup twice is black magic; we have seen
+ # tests that fail with 1 warmup run, and pass with 2, on various versions
+ # of python2.7.x.
+ for _ in range(2):
+ f(self, **kwargs)
gc.collect()
previous_count = len(gc.get_objects())
if ops.has_default_graph():
@@ -1936,10 +1999,12 @@ class TensorFlowTestCase(googletest.TestCase):
# Don't perform optimizations for tests so we don't inadvertently run
# gpu ops on cpu
config.graph_options.optimizer_options.opt_level = -1
+ # Disable Grappler constant folding since some tests & benchmarks
+ # use constant input and become meaningless after constant folding.
+ # DO NOT DISABLE GRAPPLER OPTIMIZERS WITHOUT CONSULTING WITH THE
+ # GRAPPLER TEAM.
config.graph_options.rewrite_options.constant_folding = (
rewriter_config_pb2.RewriterConfig.OFF)
- config.graph_options.rewrite_options.arithmetic_optimization = (
- rewriter_config_pb2.RewriterConfig.OFF)
config.graph_options.rewrite_options.pin_to_host_optimization = (
rewriter_config_pb2.RewriterConfig.OFF)
return config
diff --git a/tensorflow/python/grappler/item_test.py b/tensorflow/python/grappler/item_test.py
index c40de9da0a..d3d96c646c 100644
--- a/tensorflow/python/grappler/item_test.py
+++ b/tensorflow/python/grappler/item_test.py
@@ -110,7 +110,7 @@ class ItemTest(test.TestCase):
def testColocationContraints(self):
with ops.Graph().as_default() as g:
c = constant_op.constant([10])
- v = variables.Variable([3], dtype=dtypes.int32)
+ v = variables.VariableV1([3], dtype=dtypes.int32)
i = gen_array_ops.ref_identity(v)
a = state_ops.assign(i, c)
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
diff --git a/tensorflow/python/grappler/memory_optimizer_test.py b/tensorflow/python/grappler/memory_optimizer_test.py
index b658edff2d..03b42f6453 100644
--- a/tensorflow/python/grappler/memory_optimizer_test.py
+++ b/tensorflow/python/grappler/memory_optimizer_test.py
@@ -39,8 +39,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testNoSwapping(self):
"""Make sure the graph is preserved when there is nothing to swap."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -60,8 +60,8 @@ class MemoryOptimizerSwapTest(test.TestCase):
def testSimpleSwap(self):
"""Check that the swap annotations are followed."""
- a = variables.Variable(10, name='a')
- b = variables.Variable(20, name='b')
+ a = variables.VariableV1(10, name='a')
+ b = variables.VariableV1(20, name='b')
c = math_ops.add_n([a, b], name='c')
d = math_ops.add_n([b, c], name='d')
train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
@@ -244,7 +244,7 @@ class MemoryOptimizerRecomputeTest(test.TestCase):
init_op_name=init_op_name,
train_op_name=train_op_name,
loss_op_name=loss_op_name)
- self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-4)
+ self.assertAllClose(original_loss, memory_optimized_loss, rtol=1e-2)
def _annotated_graph(self):
graph = ops.Graph()
diff --git a/tensorflow/python/grappler/tf_optimizer_test.py b/tensorflow/python/grappler/tf_optimizer_test.py
index 5a9afe7257..eca0f67982 100644
--- a/tensorflow/python/grappler/tf_optimizer_test.py
+++ b/tensorflow/python/grappler/tf_optimizer_test.py
@@ -57,7 +57,7 @@ class PyWrapOptimizeGraphTest(test.TestCase):
def testKeepNodes(self):
g = ops.Graph()
with g.as_default():
- a1 = variables.Variable(
+ a1 = variables.VariableV1(
1.0) # Must be preserved since it's in the collection 'variables'.
a2 = constant_op.constant(0, shape=[50, 50], name='keep')
ops.add_to_collection('a2', a2) # Explicitly add to collection.
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index ac011a2940..c4d23f117f 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -7,7 +7,6 @@ exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
-load("@pip_deps//:requirements.bzl", "requirement")
load("//tensorflow:tensorflow.bzl", "py_test")
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
@@ -63,7 +62,7 @@ py_library(
":backend",
":engine",
":layers",
- requirement("keras_applications"),
+ ":optimizer_v2",
"//tensorflow/python/saved_model",
"//tensorflow/python:training",
],
@@ -191,6 +190,30 @@ py_library(
],
)
+py_library(
+ name = "optimizer_v2",
+ srcs = [
+ "optimizer_v2/adadelta.py",
+ "optimizer_v2/adagrad.py",
+ "optimizer_v2/adam.py",
+ "optimizer_v2/optimizer_v2.py",
+ "optimizer_v2/rmsprop.py",
+ "optimizer_v2/sgd.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:distribute",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ ],
+)
+
py_test(
name = "integration_test",
size = "medium",
@@ -829,3 +852,133 @@ py_library(
"//third_party/py/numpy",
],
)
+
+cuda_py_test(
+ name = "adadelta_test",
+ size = "medium",
+ srcs = ["optimizer_v2/adadelta_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adagrad_test",
+ size = "small",
+ srcs = ["optimizer_v2/adagrad_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "adam_test",
+ size = "small",
+ srcs = ["optimizer_v2/adam_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "checkpointable_utils_test",
+ srcs = ["optimizer_v2/checkpointable_utils_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@six_archive//:six",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:layers",
+ "//tensorflow/python:layers_base",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/eager:test",
+ "//tensorflow/python/keras",
+ ],
+ tags = ["notsan"],
+)
+
+cuda_py_test(
+ name = "sgd_test",
+ size = "medium",
+ srcs = ["optimizer_v2/sgd_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:resources",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+cuda_py_test(
+ name = "optimizer_v2_test",
+ size = "medium",
+ srcs = ["optimizer_v2/optimizer_v2_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:clip_ops",
+ "//tensorflow/python:gradients",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:variables",
+ ],
+)
+
+cuda_py_test(
+ name = "rmsprop_test",
+ size = "small",
+ srcs = ["optimizer_v2/rmsprop_test.py"],
+ additional_deps = [
+ ":optimizer_v2",
+ "@absl_py//absl/testing:parameterized",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+ tags = ["optonly"],
+)
diff --git a/tensorflow/python/keras/activations.py b/tensorflow/python/keras/activations.py
index 99645de736..d69791ce8d 100644
--- a/tensorflow/python/keras/activations.py
+++ b/tensorflow/python/keras/activations.py
@@ -160,6 +160,11 @@ def sigmoid(x):
return nn.sigmoid(x)
+@tf_export('keras.activations.exponential')
+def exponential(x):
+ return math_ops.exp(x)
+
+
@tf_export('keras.activations.hard_sigmoid')
def hard_sigmoid(x):
"""Hard sigmoid activation function.
diff --git a/tensorflow/python/keras/activations_test.py b/tensorflow/python/keras/activations_test.py
index dd0bbcff39..ad238cb0a9 100644
--- a/tensorflow/python/keras/activations_test.py
+++ b/tensorflow/python/keras/activations_test.py
@@ -169,6 +169,16 @@ class KerasActivationsTest(test.TestCase):
expected = np.tanh(test_values)
self.assertAllClose(result, expected, rtol=1e-05)
+ def test_exponential(self):
+ with self.cached_session():
+ test_values = np.random.random((2, 5))
+ x = keras.backend.placeholder(ndim=2)
+ exp = keras.activations.exponential(x)
+ f = keras.backend.function([x], [exp])
+ result = f([test_values])[0]
+ expected = np.exp(test_values)
+ self.assertAllClose(result, expected, rtol=1e-05)
+
def test_linear(self):
x = np.random.random((10, 5))
self.assertAllClose(x, keras.activations.linear(x))
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 4589c821e5..13f52fbae7 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -653,6 +653,7 @@ def variable(value, dtype=None, name=None, constraint=None):
Examples:
```python
+ >>> import numpy as np
>>> from keras import backend as K
>>> val = np.array([[1, 2], [3, 4]])
>>> kvar = K.variable(value=val, dtype='float64', name='example_var')
@@ -773,6 +774,8 @@ def is_keras_tensor(x):
Examples:
```python
+ >>> import tensorflow as tf
+ >>> import numpy
>>> from keras import backend as K
>>> from keras.layers import Input, Dense
>>> np_var = numpy.array([1, 2])
@@ -1511,12 +1514,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
@@ -2224,7 +2223,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
@tf_export('keras.backend.batch_normalization')
-def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
+def batch_normalization(x, mean, var, beta, gamma, axis=-1, epsilon=1e-3):
"""Applies batch normalization on x given mean, var, beta and gamma.
I.e. returns:
@@ -2236,11 +2235,49 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=1e-3):
var: Variance of batch.
beta: Tensor with which to center the input.
gamma: Tensor by which to scale the input.
+ axis: Integer, the axis that should be normalized.
+ (typically the features axis).
epsilon: Fuzz factor.
Returns:
A tensor.
"""
+ if ndim(x) == 4:
+ # The CPU implementation of `fused_batch_norm` only supports NHWC
+ if axis == 1 or axis == -3:
+ tf_data_format = 'NCHW'
+ elif axis == 3 or axis == -1:
+ tf_data_format = 'NHWC'
+ else:
+ tf_data_format = None
+
+ if (tf_data_format == 'NHWC' or
+ tf_data_format == 'NCHW' and _has_nchw_support()):
+ # The mean / var / beta / gamma tensors may be broadcasted
+ # so they may have extra axes of size 1, which should be squeezed.
+ if ndim(mean) > 1:
+ mean = array_ops.reshape(mean, [-1])
+ if ndim(var) > 1:
+ var = array_ops.reshape(var, [-1])
+ if beta is None:
+ beta = zeros_like(mean)
+ elif ndim(beta) > 1:
+ beta = array_ops.reshape(beta, [-1])
+ if gamma is None:
+ gamma = ones_like(mean)
+ elif ndim(gamma) > 1:
+ gamma = array_ops.reshape(gamma, [-1])
+ y, _, _ = nn.fused_batch_norm(
+ x,
+ gamma,
+ beta,
+ epsilon=epsilon,
+ mean=mean,
+ variance=var,
+ data_format=tf_data_format,
+ is_training=False
+ )
+ return y
return nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
@@ -2881,7 +2918,7 @@ class Function(object):
if session_kwargs:
raise ValueError('Some keys in session_kwargs are not supported at this '
- 'time: %s', session_kwargs.keys())
+ 'time: %s', (session_kwargs.keys(),))
self._callable_fn = None
self._feed_arrays = None
@@ -3062,7 +3099,8 @@ def rnn(step_function,
mask=None,
constants=None,
unroll=False,
- input_length=None):
+ input_length=None,
+ time_major=False):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -3091,6 +3129,13 @@ def rnn(step_function,
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
+ time_major: Boolean. If true, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
+ default this function accepts input and emits output in batch-major
+ form.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -3112,15 +3157,17 @@ def rnn(step_function,
if ndim < 3:
raise ValueError('Input should be at least 3D.')
inputs_shape = inputs.shape
- axes = [1, 0] + list(range(2, ndim))
- inputs = array_ops.transpose(inputs, (axes))
+ if not time_major:
+ axes = [1, 0] + list(range(2, ndim))
+ inputs = array_ops.transpose(inputs, axes)
if mask is not None:
if mask.dtype != dtypes_module.bool:
mask = math_ops.cast(mask, dtypes_module.bool)
if len(mask.shape) == ndim - 1:
mask = expand_dims(mask)
- mask = array_ops.transpose(mask, axes)
+ if not time_major:
+ mask = array_ops.transpose(mask, axes)
if constants is None:
constants = []
@@ -3301,10 +3348,11 @@ def rnn(step_function,
outputs = output_ta.stack()
last_output = output_ta.read(last_time - 1)
- axes = [1, 0] + list(range(2, len(outputs.shape)))
- outputs = array_ops.transpose(outputs, axes)
+ if not time_major:
+ axes = [1, 0] + list(range(2, len(outputs.shape)))
+ outputs = array_ops.transpose(outputs, axes)
- # Static shape inference: (samples, time, ...)
+ # Static shape inference: (samples, time, ...) or (time, sample, ...)
outputs_shape = outputs.shape.as_list()
outputs_shape[0] = inputs_shape[0]
outputs_shape[1] = inputs_shape[1]
@@ -3788,19 +3836,23 @@ def _preprocess_conv1d_input(x, data_format):
return x, tf_data_format
-def _preprocess_conv2d_input(x, data_format):
+def _preprocess_conv2d_input(x, data_format, force_transpose=False):
"""Transpose and cast the input before the conv2d.
Arguments:
x: input tensor.
data_format: string, `"channels_last"` or `"channels_first"`.
+ force_transpose: Boolean. If True, the input will always be transposed
+ from NCHW to NHWC if `data_format` is `"channels_first"`.
+ If False, the transposition only occurs on CPU (GPU ops are
+ assumed to support NCHW).
Returns:
A tensor.
"""
tf_data_format = 'NHWC'
if data_format == 'channels_first':
- if not _has_nchw_support():
+ if not _has_nchw_support() or force_transpose:
x = array_ops.transpose(x, (0, 2, 3, 1)) # NCHW -> NHWC
else:
tf_data_format = 'NCHW'
@@ -3948,7 +4000,8 @@ def conv2d_transpose(x,
output_shape,
strides=(1, 1),
padding='valid',
- data_format=None):
+ data_format=None,
+ dilation_rate=(1, 1)):
"""2D deconvolution (i.e.
transposed convolution).
@@ -3962,6 +4015,7 @@ def conv2d_transpose(x,
data_format: string, `"channels_last"` or `"channels_first"`.
Whether to use Theano or TensorFlow/CNTK data format
for inputs/kernels/outputs.
+ dilation_rate: Tuple of 2 integers.
Returns:
A tensor, result of transposed 2D convolution.
@@ -3977,7 +4031,13 @@ def conv2d_transpose(x,
if isinstance(output_shape, (tuple, list)):
output_shape = array_ops.stack(output_shape)
- x, tf_data_format = _preprocess_conv2d_input(x, data_format)
+ # `atrous_conv2d_transpose` only supports NHWC format, even on GPU.
+ if data_format == 'channels_first' and dilation_rate != (1, 1):
+ force_transpose = True
+ else:
+ force_transpose = False
+
+ x, tf_data_format = _preprocess_conv2d_input(x, data_format, force_transpose)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
output_shape = (output_shape[0], output_shape[2], output_shape[3],
@@ -3992,13 +4052,18 @@ def conv2d_transpose(x,
else:
strides = (1, 1) + strides
- x = nn.conv2d_transpose(
- x,
- kernel,
- output_shape,
- strides,
- padding=padding,
- data_format=tf_data_format)
+ if dilation_rate == (1, 1):
+ x = nn.conv2d_transpose(x, kernel, output_shape, strides,
+ padding=padding,
+ data_format=tf_data_format)
+ else:
+ assert dilation_rate[0] == dilation_rate[1]
+ x = nn.atrous_conv2d_transpose(
+ x,
+ kernel,
+ output_shape,
+ rate=dilation_rate[0],
+ padding=padding)
if data_format == 'channels_first' and tf_data_format == 'NHWC':
x = array_ops.transpose(x, (0, 3, 1, 2)) # NHWC -> NCHW
return x
diff --git a/tensorflow/python/keras/backend_test.py b/tensorflow/python/keras/backend_test.py
index ab71589940..0834448699 100644
--- a/tensorflow/python/keras/backend_test.py
+++ b/tensorflow/python/keras/backend_test.py
@@ -26,6 +26,7 @@ from tensorflow.python import keras
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import nn
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.util import tf_inspect
@@ -1381,6 +1382,36 @@ class BackendNNOpsTest(test.TestCase, parameterized.TestCase):
self.assertEqual(mean.get_shape().as_list(), [3,])
self.assertEqual(var.get_shape().as_list(), [3,])
+ def test_batch_normalization(self):
+ g_val = np.random.random((3,))
+ b_val = np.random.random((3,))
+ gamma = keras.backend.variable(g_val)
+ beta = keras.backend.variable(b_val)
+
+ # 3D NHC case
+ val = np.random.random((10, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 3])
+
+ # 4D NHWC case
+ val = np.random.random((10, 5, 5, 3))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 1, 2), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=-1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 5, 5, 3])
+
+ # 4D NCHW case
+ val = np.random.random((10, 3, 5, 5))
+ x = keras.backend.variable(val)
+ mean, var = nn.moments(x, (0, 2, 3), None, None, False)
+ normed = keras.backend.batch_normalization(
+ x, mean, var, beta, gamma, axis=1, epsilon=1e-3)
+ self.assertEqual(normed.shape.as_list(), [10, 3, 5, 5])
+
class TestCTC(test.TestCase):
@@ -1506,12 +1537,13 @@ class TestRandomOps(test.TestCase):
self.assertAllClose(np.min(y), -2., atol=0.1)
def test_string_input(self):
- seq = keras.Sequential([
- keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
- keras.layers.Lambda(lambda x: x[0])
- ])
- preds = seq.predict([['tensorflow eager']])
- self.assertEqual(preds.shape, (1,))
+ with self.cached_session():
+ seq = keras.Sequential([
+ keras.layers.InputLayer(input_shape=(1,), dtype=dtypes.string),
+ keras.layers.Lambda(lambda x: x[0])
+ ])
+ preds = seq.predict([['tensorflow eager']])
+ self.assertEqual(preds.shape, (1,))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/callbacks.py b/tensorflow/python/keras/callbacks.py
index 6dfbbf3694..3d6000f223 100644
--- a/tensorflow/python/keras/callbacks.py
+++ b/tensorflow/python/keras/callbacks.py
@@ -781,6 +781,10 @@ class LearningRateScheduler(Callback):
print('\nEpoch %05d: LearningRateScheduler reducing learning '
'rate to %s.' % (epoch + 1, lr))
+ def on_epoch_end(self, epoch, logs=None):
+ logs = logs or {}
+ logs['lr'] = K.get_value(self.model.optimizer.lr)
+
@tf_export('keras.callbacks.TensorBoard')
class TensorBoard(Callback):
diff --git a/tensorflow/python/keras/callbacks_test.py b/tensorflow/python/keras/callbacks_test.py
index b6fae19823..467bc4cdc4 100644
--- a/tensorflow/python/keras/callbacks_test.py
+++ b/tensorflow/python/keras/callbacks_test.py
@@ -30,6 +30,7 @@ import numpy as np
from tensorflow.core.framework import summary_pb2
from tensorflow.python import keras
+from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
@@ -1222,6 +1223,45 @@ class KerasCallbacksTest(test.TestCase):
callbacks=cbks,
epochs=1)
+ def test_fit_generator_with_callback(self):
+
+ class TestCallback(keras.callbacks.Callback):
+ def set_model(self, model):
+ # Check the model operations for the optimizer operations that
+ # the _make_train_function adds under a named scope for the
+ # optimizer. This ensurs the full model is populated before the
+ # set_model callback is called.
+ optimizer_name_scope = 'training/' + model.optimizer.__class__.__name__
+ graph_def = ops.get_default_graph().as_graph_def()
+ for node in graph_def.node:
+ if node.name.startswith(optimizer_name_scope):
+ return
+ raise RuntimeError('The optimizer operations are not present in the '
+ 'model graph when the Callback.set_model function '
+ 'is called')
+ np.random.seed(1337)
+
+ def generator():
+ x = np.random.randn(10, 100).astype(np.float32)
+ y = np.random.randn(10, 10).astype(np.float32)
+ while True:
+ yield x, y
+
+ with self.cached_session():
+ model = testing_utils.get_small_sequential_mlp(
+ num_hidden=10, num_classes=10, input_dim=100)
+ model.compile(
+ loss='categorical_crossentropy',
+ optimizer='sgd',
+ metrics=['accuracy'])
+ model.fit_generator(
+ generator(),
+ steps_per_epoch=2,
+ epochs=1,
+ validation_data=generator(),
+ validation_steps=2,
+ callbacks=[TestCallback()],
+ verbose=0)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index cb19a412a2..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
+import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
+ # A list of zero-argument lambdas which return Tensors, used for variable
+ # regularizers.
+ self._callable_losses = []
+ # A list of Tensors containing activity regularizers and losses manually
+ # added through `add_loss`. Empty when executing eagerly.
self._losses = []
+ self._in_call = False # Flag for error checking in add_loss
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
def losses(self):
"""Losses which are associated with this `Layer`.
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
Returns:
A list of tensors.
"""
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
+ collected_losses = []
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
from `Layer.call()`).
Arguments:
- losses: Loss tensor, or list/tuple of tensors.
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor. Only
+ callable losses are supported when executing eagerly.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
(e.g. weight regularization losses).
Raises:
- RuntimeError: If called in Eager mode.
+ RuntimeError: If called in Eager mode with a `Tensor` rather than a
+ callable, or if `inputs` is not None.
"""
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly:
+ if inputs is not None:
+ raise RuntimeError(
+ 'Activity regularization (via the "inputs" argument to '
+ 'Layer.add_loss) is not supported when executing eagerly. Consider '
+ 'returning activity regularization losses from a Model\'s call() '
+ 'method.')
+ if getattr(self, '_in_call', False):
+ # TODO(psv): Support activity regularization and a way to reset losses.
+ raise RuntimeError(
+ 'Adding losses inside a Layer\'s call() method is not currently '
+ 'supported when executing eagerly. Please file a feature request '
+ 'if you need this limitation lifted.')
losses = generic_utils.to_list(losses)
- losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
- if not tensor_util.is_tensor(loss) else loss for loss in losses]
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if executing_eagerly:
+ raise RuntimeError(
+ 'Layer.add_loss only supported for zero-argument lambdas when '
+ 'executing eagerly.')
+ self._losses.append(_tag_unconditional(loss))
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
return variable
def _handle_weight_regularization(self, name, variable, regularizer):
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
- init_graph = None
- if not context.executing_eagerly():
- default_graph = ops.get_default_graph()
- if default_graph.building_function:
- with ops.init_scope():
- # Retrieve the variables from the graph into which variables
- # will be lifted; if initialization ops will be lifted into
- # the eager context, then there is nothing to retrieve, since variable
- # collections are not supported when eager execution is enabled.
- if not context.executing_eagerly():
- init_graph = ops.get_default_graph()
- else:
- # Initialization ops will not be lifted out of the default graph.
- init_graph = default_graph
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
+ self._in_call = True
outputs = self.call(inputs, *args, **kwargs)
+ self._in_call = False
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
@@ -1972,7 +1961,9 @@ def make_variable(name,
if use_resource is None:
use_resource = True
- v = tf_variables.Variable(
+ # TODO(apassos,rohanj) figure out how to remove collections from here so we
+ # can remove the V1.
+ v = tf_variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 39341a931b..050602868a 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,12 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.client import session as session_module
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -304,23 +310,19 @@ def validate_inputs(x, y, distribution_strategy):
compiled.
Raises:
- ValueError: if input is not a Dataset or a numpy array.
+ ValueError: if input is not a Dataset or a numpy array(when we use
+ MirroredStrategy).
"""
- if isinstance(x, list) or isinstance(y, list):
- raise ValueError('DistributionStrategy does not support lists of numpy'
- 'arrays. You must pass a Dataset object or a numpy array '
- 'as input.')
-
if isinstance(x, dict) or isinstance(y, dict):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'dict. You must pass a Dataset object or a numpy array as '
- 'input.')
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'dict. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
- if isinstance(x, iterator_ops.Iterator) or \
- isinstance(y, iterator_ops.Iterator):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'Iterator. You must pass a Dataset object or a numpy '
- 'array as input.')
+ if (isinstance(x, iterator_ops.Iterator) or
+ isinstance(y, iterator_ops.Iterator)):
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'Iterator. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
if distribution_strategy.__class__.__name__ == 'TPUStrategy':
for i in [x, y]:
@@ -334,14 +336,14 @@ def validate_inputs(x, y, distribution_strategy):
'Found unknown shape {} in input {}.'.format(s, i))
-def get_input_batch_params(first_x_value, batch_size, current_strategy):
+def get_input_batch_params(first_x_value, batch_size, distribution_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
Args:
first_x_value: This is the first input numpy array that is passed in as the
model input.
batch_size: The specified batch_size or the default batch_size of 32.
- current_strategy: The current DistributionStrategy used to compile the
+ distribution_strategy: The current DistributionStrategy used to compile the
model.
Returns:
@@ -359,14 +361,14 @@ def get_input_batch_params(first_x_value, batch_size, current_strategy):
# TODO(anjalisridhar): TPU currently supports using the num_towers property.
# We might want to look into implementing worker_devices. In multi worker
# strategy, perhaps num_towers works better?
- steps = num_batches // current_strategy.num_towers
+ steps = num_batches // distribution_strategy.num_towers
if not steps:
# TODO(anjalisridhar): Number of towers in the error message may not convey
# what we want to the user. Is there another terminology that we can use
# that is consistent across different strategies.
raise ValueError('The number of batches %d is smaller than the number '
'of towers %d used for DistributionStrategy. ' %
- num_batches, current_strategy.num_towers)
+ (num_batches, distribution_strategy.num_towers))
return steps
@@ -376,3 +378,99 @@ def get_batch_dimension(iterator):
# all.
dims = shapes[0].dims
return dims[0] if dims else None
+
+
+def get_cpu_device(distribution_strategy):
+ """Returns the CPU device of the TPU host or the default CPU device string.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+
+ Returns:
+ A device string which is the TPU host's CPU device in case of
+ TPUDistributionStrategy or the default CPU device string in all other
+ cases.
+
+ Raises:
+ NotImplementedError: We currently don't support copying numpy data to
+ multiple hosts in the case of Cloud TPU pods.
+ """
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ if distribution_strategy.num_hosts > 1:
+ raise NotImplementedError('TPUDistributionStrategy does not '
+ 'support numpy inputs when running on Cloud'
+ 'TPU pods.')
+ return distribution_strategy.get_host_cpu_device(0)
+ else:
+ # For all strategies except TPUDistributionStrategy
+ # TODO(anjalisridhar): We may need to modify this when we add support for
+ # multi-worker strategy.
+ return '/CPU:0'
+
+
+def get_var_for_numpy(distribution_strategy, x):
+ if isinstance(x, list):
+ var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input)
+ for single_input in x])
+ else:
+ var_x = _get_var_for_numpy(distribution_strategy, x)
+ return var_x
+
+
+def _get_var_for_numpy(distribution_strategy, input_array):
+ """Creates a variable and assigns the value of the numpy array to it.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+ input_array: The input numpy array whose value will be assigned to the
+ variable we create.
+
+ Returns:
+ The variable to which we will copy the value of the input numpy array.
+
+ """
+ with ops.device(get_cpu_device(distribution_strategy)):
+ # Create and initialize a variable on the CPU device. This is the CPU
+ # device of the host in the case of TPUDistributionStrategy.
+ input_var = variables.VariableV1(array_ops.zeros(input_array.shape,
+ input_array.dtype),
+ trainable=False, use_resource=True)
+ K.get_session().run(input_var.initializer)
+
+ # Create a placeholder for the numpy array input slices. We copy the value
+ # of the input numpy array to the variable in slices of size 64 MB to avoid
+ # running into memory issues or RPC message limits.
+ start_placeholder = array_ops.placeholder(dtypes.int64, ())
+ end_placeholder = array_ops.placeholder(dtypes.int64, ())
+ slice_placeholder = array_ops.placeholder(input_var.dtype)
+ assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
+ slice_placeholder)
+
+ # If each batch element is > 64 MB, then we copy each batch element
+ # individually. Otherwise, the slices will be < 128 MB. There might be padding
+ # which might mean that the slices are 128 MB even if the size of the
+ # tensor allocated is less than 128 MB.
+ # This formula gives slices with size:
+ # ceil(64 MB / byte size per batch element) bytes.
+ # Using ceil() guarantees we get a number >= 1.
+
+ # Calculate the size of each batch element.
+ byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \
+ input_var.dtype.size
+
+ # Calculate number of elements we want to copy per slice.
+ batch_size_per_slice = np.ceil((64 << 20) / byte_size_per_batch_element)
+
+ # Copy slices of the above size starting at 0, except the last slice will be
+ # smaller.
+ start = 0
+ limit = input_array.shape[0]
+ while start < limit:
+ end = min(start + batch_size_per_slice, limit)
+ K.get_session().run(assign_slice_op, feed_dict={
+ start_placeholder: start,
+ end_placeholder: end,
+ slice_placeholder: input_array[start:end]})
+ start = end
+
+ return input_var
diff --git a/tensorflow/python/keras/engine/input_layer.py b/tensorflow/python/keras/engine/input_layer.py
index 8a4018a0df..6a69d0ed90 100644
--- a/tensorflow/python/keras/engine/input_layer.py
+++ b/tensorflow/python/keras/engine/input_layer.py
@@ -82,6 +82,7 @@ class InputLayer(base_layer.Layer):
self.built = True
self.sparse = sparse
self.batch_size = batch_size
+ self.supports_masking = True
if isinstance(input_shape, tensor_shape.TensorShape):
input_shape = tuple(input_shape.as_list())
diff --git a/tensorflow/python/keras/engine/network.py b/tensorflow/python/keras/engine/network.py
index 5ef8d13487..5969fea2b2 100644
--- a/tensorflow/python/keras/engine/network.py
+++ b/tensorflow/python/keras/engine/network.py
@@ -1028,7 +1028,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensor, **kwargs)
else:
- output_tensors = layer.call(computed_tensor, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensor, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensor, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensor,
computed_mask)
@@ -1049,7 +1052,10 @@ class Network(base_layer.Layer):
output_tensors, output_masks = layer._call_and_compute_mask(
computed_tensors, **kwargs)
else:
- output_tensors = layer.call(computed_tensors, **kwargs)
+ if context.executing_eagerly():
+ output_tensors = layer(computed_tensors, **kwargs)
+ else:
+ output_tensors = layer.call(computed_tensors, **kwargs)
if hasattr(layer, 'compute_mask'):
output_masks = layer.compute_mask(computed_tensors,
computed_masks)
@@ -1526,6 +1532,7 @@ class Network(base_layer.Layer):
# Restore existing variables (if any) immediately, and set up a
# streaming restore for any variables created in the future.
checkpointable_utils.streaming_restore(status=status, session=session)
+ status.assert_nontrivial_match()
return status
if h5py is None:
raise ImportError(
@@ -1634,10 +1641,11 @@ class Network(base_layer.Layer):
ValueError: if `summary()` is called before the model is built.
"""
if not self.built:
- raise ValueError('This model has never been called, thus its weights '
- 'have not yet been created, so no summary can be '
- 'displayed. Build the model first '
- '(e.g. by calling it on some data).')
+ raise ValueError('This model has not yet been built. '
+ 'Build the model first by calling `build()` or calling '
+ '`fit()` with some data, or specify '
+ 'an `input_shape` argument in the first layer(s) for '
+ 'automatic build.')
layer_utils.print_summary(self,
line_length=line_length,
positions=positions,
diff --git a/tensorflow/python/keras/engine/saving_test.py b/tensorflow/python/keras/engine/saving_test.py
index 02d99d5d69..f5045be907 100644
--- a/tensorflow/python/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/engine/saving_test.py
@@ -38,6 +38,7 @@ from tensorflow.python.platform import test
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
from tensorflow.python.training import training as training_module
+from tensorflow.python.training.checkpointable import util as checkpointable
try:
import h5py # pylint:disable=g-import-not-at-top
@@ -922,6 +923,18 @@ class TestWeightSavingAndLoadingTFFormat(test.TestCase):
SubclassedModel, SubclassedModelRestore,
_restore_init_fn)
+ @test_util.run_in_graph_and_eager_modes
+ def test_incompatible_checkpoint(self):
+ save_path = checkpointable.Checkpoint().save(
+ os.path.join(self.get_temp_dir(), 'ckpt'))
+ m = keras.Model()
+ with self.assertRaisesRegexp(AssertionError, 'Nothing to load'):
+ m.load_weights(save_path)
+ m.dense = keras.layers.Dense(2)
+ m.dense(constant_op.constant([[1.]]))
+ with self.assertRaisesRegexp(
+ AssertionError, 'Nothing except the root object matched'):
+ m.load_weights(save_path)
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/topology_test.py b/tensorflow/python/keras/engine/topology_test.py
index a0da96334b..b4488033cd 100644
--- a/tensorflow/python/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/engine/topology_test.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
try:
import yaml # pylint:disable=g-import-not-at-top
@@ -1182,6 +1183,36 @@ class DefaultShapeInferenceBehaviorTest(test.TestCase):
output = model(sample_input)
self.assertEqual(output.shape, (1, 3))
+ @test_util.run_in_graph_and_eager_modes()
+ def test_sequential_as_downstream_of_masking_layer(self):
+ inputs = keras.layers.Input(shape=(3, 4))
+ x = keras.layers.Masking(mask_value=0., input_shape=(3, 4))(inputs)
+
+ s = keras.Sequential()
+ s.add(keras.layers.Dense(5, input_shape=(4,)))
+
+ x = keras.layers.wrappers.TimeDistributed(s)(x)
+ model = keras.Model(inputs=inputs, outputs=x)
+ model.compile(optimizer=rmsprop.RMSPropOptimizer(1e-3), loss='mse')
+
+ model_input = np.random.randint(
+ low=1, high=5, size=(10, 3, 4)).astype('float32')
+ for i in range(4):
+ model_input[i, i:, :] = 0.
+ model.fit(model_input,
+ np.random.random((10, 3, 5)), epochs=1, batch_size=6)
+
+ if not context.executing_eagerly():
+ # Note: this doesn't work in eager due to DeferredTensor/ops compatibility
+ # issue.
+ mask_outputs = [model.layers[1].compute_mask(model.layers[1].input)]
+ mask_outputs += [model.layers[2].compute_mask(
+ model.layers[2].input, mask_outputs[-1])]
+ func = keras.backend.function([model.input], mask_outputs)
+ mask_outputs_val = func([model_input])
+ self.assertAllClose(mask_outputs_val[0], np.any(model_input, axis=-1))
+ self.assertAllClose(mask_outputs_val[1], np.any(model_input, axis=-1))
+
class GraphUtilsTest(test.TestCase):
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index ade8a4b32d..ff2ae54ad4 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,11 +20,9 @@ from __future__ import print_function
import weakref
import numpy as np
-import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -565,9 +563,11 @@ class Model(Network):
for name in self.output_names:
tmp_target_tensors.append(target_tensors.get(name, None))
target_tensors = tmp_target_tensors
+ elif tensor_util.is_tensor(target_tensors):
+ target_tensors = [target_tensors]
else:
- raise TypeError('Expected `target_tensors` to be '
- 'a list or dict, but got:', target_tensors)
+ raise TypeError('Expected `target_tensors` to be a list or tuple or '
+ 'dict or a single tensor, but got:', target_tensors)
for i in range(len(self.outputs)):
if i in skip_target_indices:
@@ -647,12 +647,6 @@ class Model(Network):
skip_target_indices=skip_target_indices,
sample_weights=self.sample_weights)
- # If using distribution strategy and stateful_metrics, raise an error
- # since we currently don't support stateful metrics.
- if self._distribution_strategy is not None and self.stateful_metric_names:
- raise NotImplementedError('Stateful metrics are not supported with '
- 'DistributionStrategy.')
-
# Prepare gradient updates and state updates.
self.total_loss = total_loss
@@ -820,19 +814,22 @@ class Model(Network):
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
x_shape = first_x_value.shape
- x_dtype = first_x_value.dtype
if batch_size is None:
batch_size = x_shape[0] // steps
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ drop_remainder = self._distribution_strategy.require_static_shapes
if y is not None:
- first_y_value = nest.flatten(y)[0]
- x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
- output_types=(x_dtype, first_y_value.dtype),
- output_shapes=(x_shape[1:],
- first_y_value.shape[1:]))
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ var_y = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, y)
+
+ x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y))
# TODO(anjalisridhar): What should the buffer size be?
x = x.shuffle(10000)
x = x.repeat()
- x = x.batch(batch_size)
+ x = x.batch(batch_size, drop_remainder=drop_remainder)
y = None
else:
# This case is for the predict call where the dataset only contains
@@ -840,11 +837,11 @@ class Model(Network):
# TODO(anjalisridhar): Raise an error if we are not able to process
# all the predict samples. This can happen if the number of batches is
# not evenly divisible by the number of worker devices.
- x = Dataset.from_generator(lambda x=x: x,
- output_types=x_dtype,
- output_shapes=x_shape[1:])
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.repeat()
- x = x.batch(batch_size)
+ x = x.batch(batch_size, drop_remainder=drop_remainder)
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
@@ -857,7 +854,8 @@ class Model(Network):
# able to clone a Dataset on multiple workers we can remove this lambda.
result = self._distribution_strategy.distribute_dataset(lambda: x)
iterator = result.make_initializable_iterator()
- K.get_session().run(iterator.initializer)
+ with self._distribution_strategy.scope():
+ K.get_session().run(iterator.initializer)
training_utils.validate_iterator_input(x, y, sample_weight,
validation_split)
@@ -983,16 +981,18 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if (not isinstance(next_element, (list, tuple)) or
- len(next_element) not in [2, 3]):
- raise ValueError(
- 'Please provide model inputs as a list or tuple of 2 or 3'
- 'elements: (input, target) or (input, target, sample_weights)'
- 'Received %s' % next_element)
- if len(next_element) == 2:
- x, y = next_element
+ if isinstance(next_element, (list, tuple)):
+ if len(next_element) not in [2, 3]:
+ raise ValueError(
+ 'Please provide model inputs as a list or tuple of 2 or 3'
+ 'elements: (input, target) or (input, target, sample_weights)'
+ 'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
else:
- x, y, sample_weight = next_element
+ x = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
@@ -1420,6 +1420,8 @@ class Model(Network):
- tuple `(x_val, y_val)` of Numpy arrays or tensors
- tuple `(x_val, y_val, val_sample_weights)` of Numpy arrays
- dataset or a dataset iterator
+ For the first two cases, `batch_size` must be provided.
+ For the last case, `validation_steps` must be provided.
shuffle: Boolean (whether to shuffle the training data
before each epoch) or str (for 'batch').
'batch' is a special option for dealing with the
@@ -1455,9 +1457,10 @@ class Model(Network):
TensorFlow data tensors, the default `None` is equal to
the number of samples in your dataset divided by
the batch size, or 1 if that cannot be determined.
- validation_steps: Only relevant if `steps_per_epoch`
- is specified. Total number of steps (batches of samples)
- to validate before stopping.
+ validation_steps: Only relevant if `validation_data` is provided and
+ is a dataset or dataset iterator. Total number of steps (batches of
+ samples) to draw before stopping when performing validation
+ at the end of every epoch.
max_queue_size: Integer. Used for generator or `keras.utils.Sequence`
input only. Maximum size for the generator queue.
If unspecified, `max_queue_size` will default to 10.
@@ -2361,6 +2364,6 @@ class DistributedCallbackModel(Model):
# Whitelisted atttributes of the model that can be accessed by the user
# during a callback.
if item not in ['_setattr_tracking']:
- logging.warning('You are accessing attribute ' + item + 'of the '
+ logging.warning('You are accessing attribute ' + item + ' of the '
'DistributedCallbackModel that may not have been set '
'correctly.')
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index 8b434ca444..ac759ef3aa 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -26,11 +26,13 @@ from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import distributed_training_utils
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras.utils.generic_utils import Progbar
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import nest
# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
@@ -111,96 +113,99 @@ def fit_loop(
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- # Create a train function that is composed of all the parameters above.
- distributed_train_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_train_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [1]
- else:
- ins = dataset_inputs + dataset_targets
+ # Create a train function that is composed of all the parameters above.
+ distributed_train_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_train_function',
+ **all_session_args)
- do_validation = False
- if validation_steps:
- do_validation = True
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [1]
+ else:
+ ins = dataset_inputs + dataset_targets
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ do_validation = False
+ if validation_steps:
+ do_validation = True
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- callbacks = cbks.configure_callbacks(
- callbacks,
- model,
- do_validation=do_validation,
- val_inputs=None,
- val_targets=None,
- epochs=epochs,
- steps_per_epoch=steps_per_epoch,
- verbose=verbose)
- out_labels = model.metrics_names or []
- callbacks.on_train_begin()
-
- assert steps_per_epoch is not None
-
- for epoch in range(initial_epoch, epochs):
- callbacks.on_epoch_begin(epoch)
- epoch_logs = {}
- for step_index in range(steps_per_epoch):
- batch_logs = {'batch': step_index, 'size': 1}
- callbacks.on_batch_begin(step_index, batch_logs)
- try:
- outs = distributed_train_function(ins)
- except errors.OutOfRangeError:
- logging.warning('Your dataset iterator ran out of data; '
- 'interrupting training. Make sure that your dataset '
- 'can generate at least `steps_per_epoch * epochs` '
- 'batches (in this case, %d batches).' %
- steps_per_epoch * epochs)
- break
-
- if not isinstance(outs, list):
- outs = [outs]
-
- outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, out_labels, outs)
- for l, o in zip(out_labels, outs):
- batch_logs[l] = o
- callbacks.on_batch_end(step_index, batch_logs)
+ callbacks = cbks.configure_callbacks(
+ callbacks,
+ model,
+ do_validation=do_validation,
+ val_inputs=None,
+ val_targets=None,
+ epochs=epochs,
+ steps_per_epoch=steps_per_epoch,
+ verbose=verbose)
+ out_labels = model.metrics_names or []
+ callbacks.on_train_begin()
+
+ assert steps_per_epoch is not None
+
+ for epoch in range(initial_epoch, epochs):
+ # Reset stateful metrics
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ callbacks.on_epoch_begin(epoch)
+ epoch_logs = {}
+ for step_index in range(steps_per_epoch):
+ batch_logs = {'batch': step_index, 'size': 1}
+ callbacks.on_batch_begin(step_index, batch_logs)
+ try:
+ outs = distributed_train_function(ins)
+ except errors.OutOfRangeError:
+ logging.warning('Your dataset iterator ran out of data; '
+ 'interrupting training. Make sure that your dataset '
+ 'can generate at least `steps_per_epoch * epochs` '
+ 'batches (in this case, %d batches).' %
+ steps_per_epoch * epochs)
+ break
+
+ if not isinstance(outs, list):
+ outs = [outs]
+
+ outs = _aggregate_metrics_across_towers(current_strategy.num_towers,
+ out_labels,
+ model.stateful_metric_names,
+ outs)
+ for l, o in zip(out_labels, outs):
+ batch_logs[l] = o
+ callbacks.on_batch_end(step_index, batch_logs)
+ if callbacks.model.stop_training:
+ break
+ if do_validation:
+ val_outs = test_loop(
+ model,
+ val_iterator,
+ steps=validation_steps,
+ verbose=0)
+ if not isinstance(val_outs, list):
+ val_outs = [val_outs]
+ # Same labels assumed.
+ for l, o in zip(out_labels, val_outs):
+ epoch_logs['val_' + l] = o
+
+ callbacks.on_epoch_end(epoch, epoch_logs)
if callbacks.model.stop_training:
break
- if do_validation:
- val_outs = test_loop(
- model,
- val_iterator,
- steps=validation_steps,
- verbose=0)
- if not isinstance(val_outs, list):
- val_outs = [val_outs]
- # Same labels assumed.
- for l, o in zip(out_labels, val_outs):
- epoch_logs['val_' + l] = o
-
- callbacks.on_epoch_end(epoch, epoch_logs)
- if callbacks.model.stop_training:
- break
- callbacks.on_train_end()
+ callbacks.on_train_end()
- # Copy the weights back from the replicated model to the original model.
- with current_strategy.scope():
+ # Copy the weights back from the replicated model to the original model.
updated_weights = current_strategy.unwrap(
model._grouped_model)[0].get_weights()
model.set_weights(updated_weights)
- return model.history
+ return model.history
def _experimental_fit_loop(
@@ -292,15 +297,16 @@ def _experimental_fit_loop(
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
if steps_per_epoch is None:
- raise ValueError('steps_per_epoch should be specified in the fit call.')
- steps_per_run_var = K.variable(
+ raise ValueError('`steps_per_epoch` should be specified when calling '
+ '`fit` on the model.')
+ steps_per_run = K.variable(
value=min(steps_per_epoch, current_strategy.steps_per_run),
dtype='int32',
- name='steps_per_run_var')
+ name='steps_per_run')
with current_strategy.scope():
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=steps_per_run_var,
+ step_fn, iterator, iterations=steps_per_run,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -340,7 +346,7 @@ def _experimental_fit_loop(
batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
if prev_step_count is None or step_count != prev_step_count:
- steps_per_run_var.load(step_count, K.get_session())
+ steps_per_run.load(step_count, K.get_session())
prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
@@ -422,54 +428,65 @@ def test_loop(model, iterator, verbose=0, steps=None):
dataset_targets = distributed_training_utils.flatten_perdevice_values(
current_strategy, targets)
- distributed_test_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_test_function',
- **all_session_args)
-
- # We need to set sample_weights to None since there are sample weight
- # placeholders that are created with default values.
- sample_weights = [None for _ in range(len(model.outputs) *
- current_strategy.num_towers)]
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + dataset_targets + sample_weights + [0]
- else:
- ins = dataset_inputs + dataset_targets
+ distributed_test_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_test_function',
+ **all_session_args)
- outs = []
- if verbose == 1:
- progbar = Progbar(target=steps)
+ # We need to set sample_weights to None since there are sample weight
+ # placeholders that are created with default values.
+ sample_weights = [None for _ in range(len(model.outputs) *
+ current_strategy.num_towers)]
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + dataset_targets + sample_weights + [0]
+ else:
+ ins = dataset_inputs + dataset_targets
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ for m in model.stateful_metric_functions:
+ m.reset_states()
+ stateful_metric_indices = [
+ i for i, name in enumerate(model.metrics_names)
+ if str(name) in model.stateful_metric_names
+ ]
+
+ outs = []
+ if verbose == 1:
+ progbar = Progbar(target=steps)
+
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- assert steps is not None
- for step in range(steps):
- batch_outs = distributed_test_function(ins)
- batch_outs = _aggregate_metrics_across_towers(
- current_strategy.num_towers, model.metrics_names, batch_outs)
- if isinstance(batch_outs, list):
- if step == 0:
- outs = [0.] * len(batch_outs)
- for i, batch_out in enumerate(batch_outs):
- outs[i] += batch_out
- else:
- if step == 0:
- outs.append(0.)
- outs[0] += batch_outs
- if verbose >= 1:
- progbar.update(step + 1)
- for i in range(len(outs)):
- outs[i] /= steps
+ assert steps is not None
+ for step in range(steps):
+ batch_outs = distributed_test_function(ins)
+ batch_outs = _aggregate_metrics_across_towers(
+ current_strategy.num_towers, model.metrics_names,
+ model.stateful_metric_names, batch_outs)
+ if isinstance(batch_outs, list):
+ if step == 0:
+ outs = [0.] * len(batch_outs)
+ for i, batch_out in enumerate(batch_outs):
+ if i in stateful_metric_indices:
+ outs[i] = batch_out
+ else:
+ outs[i] += batch_out
+ else:
+ if step == 0:
+ outs.append(0.)
+ outs[0] += batch_outs
+ if verbose >= 1:
+ progbar.update(step + 1)
+ for i in range(len(outs)):
+ if i not in stateful_metric_indices:
+ outs[i] /= steps
- if len(outs) == 1:
- return outs[0]
- return outs
+ if len(outs) == 1:
+ return outs[0]
+ return outs
def _experimental_test_loop(model, iterator, verbose=0, steps=None):
@@ -630,51 +647,50 @@ def predict_loop(model, iterator, verbose=0, steps=None):
dataset_inputs = distributed_training_utils.flatten_perdevice_values(
current_strategy, inputs)
- distributed_predict_function = K.Function(
- all_inputs, all_outputs,
- updates=all_updates,
- name='distributed_predict_function',
- **all_session_args)
+ distributed_predict_function = K.Function(
+ all_inputs, all_outputs,
+ updates=all_updates,
+ name='distributed_predict_function',
+ **all_session_args)
- if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
- ins = dataset_inputs + [0]
- else:
- ins = dataset_inputs
+ if model.uses_learning_phase and not isinstance(K.learning_phase(), int):
+ ins = dataset_inputs + [0]
+ else:
+ ins = dataset_inputs
- if verbose == 1:
- progbar = Progbar(target=steps)
+ if verbose == 1:
+ progbar = Progbar(target=steps)
- # Copy the weights from the original model to each of the replicated models.
- orig_model_weights = model.get_weights()
- with current_strategy.scope():
+ # Copy the weights from the original model to each of the replicated models.
+ orig_model_weights = model.get_weights()
distributed_model = current_strategy.unwrap(model._grouped_model)[0]
distributed_training_utils.set_weights(
current_strategy, distributed_model, orig_model_weights)
- if steps is not None:
- # Since we do not know how many samples we will see, we cannot pre-allocate
- # the returned Numpy arrays. Instead, we store one array per batch seen
- # and concatenate them upon returning.
- unconcatenated_outs = []
- for step in range(steps):
- batch_outs = distributed_predict_function(ins)
- if not isinstance(batch_outs, list):
- batch_outs = [batch_outs]
- if step == 0:
- for _ in batch_outs:
- unconcatenated_outs.append([])
- # TODO(anjalisridhar): Should combine the outputs from multiple towers
- # correctly here.
- for i, batch_out in enumerate(batch_outs):
- unconcatenated_outs[i].append(batch_out)
- if verbose >= 1:
- progbar.update(step + 1)
- if len(unconcatenated_outs) == 1:
- return np.concatenate(unconcatenated_outs[0], axis=0)
- return [
- np.concatenate(unconcatenated_outs[i], axis=0)
- for i in range(len(unconcatenated_outs))
- ]
+ if steps is not None:
+ # Since we do not know how many samples we will see, we cannot
+ # pre-allocate the returned Numpy arrays. Instead, we store one array per
+ # batch seen and concatenate them upon returning.
+ unconcatenated_outs = []
+ for step in range(steps):
+ batch_outs = distributed_predict_function(ins)
+ if not isinstance(batch_outs, list):
+ batch_outs = [batch_outs]
+ if step == 0:
+ for _ in batch_outs:
+ unconcatenated_outs.append([])
+ # TODO(anjalisridhar): Should combine the outputs from multiple towers
+ # correctly here.
+ for i, batch_out in enumerate(batch_outs):
+ unconcatenated_outs[i].append(batch_out)
+ if verbose >= 1:
+ progbar.update(step + 1)
+ if len(unconcatenated_outs) == 1:
+ return np.concatenate(unconcatenated_outs[0], axis=0)
+ return [
+ np.concatenate(unconcatenated_outs[i], axis=0)
+ for i in range(len(unconcatenated_outs))
+ ]
def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
@@ -706,13 +722,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
model.predict_function.updates_op,
model.predict_function.session_kwargs)
- def step_fn(ctx, inputs, targets):
+ def step_fn(ctx, *inputs):
"""Clones the model and calls make_predict_function."""
- # TODO(anjalisridhar): Support predict input correctly as it will not
- # contain targets, only inputs.
- del targets
-
# TODO(priyag, sourabhbajaj): The model gets cloned every time
# fit/test/predict is called. We should look into caching this keyed on
# input shapes.
@@ -808,18 +820,15 @@ def _clone_and_build_model(model, inputs=None, targets=None):
optimizer_config = model.optimizer.get_config()
optimizer = model.optimizer.__class__.from_config(optimizer_config)
- # TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
- # single tensor should be OK but it throws an error in that case.
- if (targets is not None and not isinstance(targets, list) and
- not isinstance(targets, dict)):
- targets = [targets]
+ if isinstance(targets, tuple):
+ targets = nest.flatten(targets)
cloned_model.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=targets)
return cloned_model
@@ -834,8 +843,9 @@ def clone_model_on_towers(
model._make_callback_model()
-def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
- """Aggregate metrics values across all towers.
+def _aggregate_metrics_across_towers(num_devices, out_labels,
+ stateful_metric_names, outs):
+ """Aggregates stateless metrics values across towers.
When using `MirroredStrategy`, the number of towers is equal to the
number of devices over which training is distributed. This may not always be
@@ -844,6 +854,7 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
Args:
num_devices: Number of devices over which the model is being distributed.
out_labels: The list of metric names passed to `compile`.
+ stateful_metric_names: List of stateful metric names on the model.
outs: The output from all the towers.
Returns:
@@ -858,10 +869,16 @@ def _aggregate_metrics_across_towers(num_devices, out_labels, outs):
# Each label in `out_labels` corresponds to one set of metrics. The
# number of metric values corresponds to the number of devices. We
# currently take the mean of the values.
- for _ in out_labels[1:]:
- m = np.mean(outs[current_index:current_index + num_devices])
- merged_output.append(m)
- current_index += num_devices
+ for metric_name in out_labels[1:]:
+ if metric_name in stateful_metric_names:
+ # For stateful metrics, we get one aggregated result value.
+ merged_output.append(outs[current_index])
+ current_index += 1
+ else:
+ m = np.mean(outs[current_index:current_index + num_devices])
+ merged_output.append(m)
+ current_index += num_devices
+
return merged_output
@@ -869,11 +886,12 @@ def _get_input_from_iterator(iterator, model):
"""Get elements from the iterator and verify the input shape and type."""
next_element = iterator.get_next()
- if isinstance(next_element, tuple):
- x, y = next_element
- else:
+ if len(nest.flatten(next_element)) == len(model.inputs):
x = next_element
y = None
+ else:
+ x, y = next_element
+
# Validate that all the elements in x and y are of the same type and shape.
# We can then pass the first element of x and y to `_standardize_weights`
# below and be confident of the output.
diff --git a/tensorflow/python/keras/engine/training_eager.py b/tensorflow/python/keras/engine/training_eager.py
index fb71bf2596..2a62edd698 100644
--- a/tensorflow/python/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/engine/training_eager.py
@@ -739,7 +739,8 @@ def test_loop(model, inputs, targets,
y=targets,
sample_weights=sample_weights,
batch_size=batch_size,
- steps_per_epoch=steps)
+ steps_per_epoch=steps,
+ is_validation=True)
with backend.learning_phase_scope(0):
return iterator_test_loop(model, inputs, steps, verbose=verbose)
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..943ede1be9 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -125,6 +125,36 @@ class TrainingTest(test.TestCase):
model.train_on_batch(inputs, targets)
model.test_on_batch(inputs, targets)
+ def test_model_fit_and_validation_with_missing_arg_errors(self):
+ x = keras.layers.Input(shape=(3,), name='input')
+ y = keras.layers.Dense(4, name='dense')(x)
+ model = keras.Model(x, y)
+ model.compile(optimizer=RMSPropOptimizer(learning_rate=0.001), loss='mse')
+
+ x = keras.backend.zeros(shape=(10, 3))
+ y = keras.backend.zeros(shape=(10, 4))
+ dataset = dataset_ops.Dataset.from_tensor_slices((x, y)).repeat(10).batch(5)
+ iterator = dataset.make_one_shot_iterator()
+ validation_dataset = dataset_ops.Dataset.from_tensor_slices(
+ (x, y)).repeat(10).batch(5)
+ validation_iterator = validation_dataset.make_one_shot_iterator()
+
+ with self.assertRaisesRegexp(
+ ValueError, r'specify .* `steps_per_epoch`'):
+ model.fit(iterator, epochs=1, verbose=0)
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=(x, y))
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=validation_dataset)
+ with self.assertRaisesRegexp(
+ ValueError, r'provide either `batch_size` or `validation_steps`'):
+ model.fit(iterator, steps_per_epoch=2, epochs=1, verbose=0,
+ validation_data=validation_iterator)
+
def test_generator_methods(self):
model = keras.Sequential()
model.add(keras.layers.Dense(4, input_shape=(3,)))
@@ -192,6 +222,20 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ def test_no_loss_in_call(self):
+
+ class HasLoss(keras.layers.Layer):
+
+ def call(self, x):
+ self.add_loss(x)
+ return x
+
+ layer = HasLoss()
+ with self.assertRaises(RuntimeError):
+ layer(1.)
+
+ with ops.Graph().as_default():
+ layer(1.)
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_generator.py b/tensorflow/python/keras/engine/training_generator.py
index 413c1f4fba..2e074699da 100644
--- a/tensorflow/python/keras/engine/training_generator.py
+++ b/tensorflow/python/keras/engine/training_generator.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import numpy as np
+from tensorflow.python.eager import context
from tensorflow.python.keras import callbacks as cbks
from tensorflow.python.keras.utils.data_utils import GeneratorEnqueuer
from tensorflow.python.keras.utils.data_utils import OrderedEnqueuer
@@ -48,6 +49,10 @@ def fit_generator(model,
epoch = initial_epoch
do_validation = bool(validation_data)
+ if not context.executing_eagerly():
+ model._make_train_function()
+ if do_validation:
+ model._make_test_function()
is_sequence = isinstance(generator, Sequence)
if not is_sequence and use_multiprocessing and workers > 1:
@@ -233,6 +238,9 @@ def evaluate_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.evaluate_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
if hasattr(model, 'metrics'):
for m in model.stateful_metric_functions:
m.reset_states()
@@ -342,6 +350,9 @@ def predict_generator(model,
use_multiprocessing=False,
verbose=0):
"""See docstring for `Model.predict_generator`."""
+ if not context.executing_eagerly():
+ model._make_test_function()
+
steps_done = 0
wait_time = 0.01
all_outs = []
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..868fd1dc69 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -1864,6 +1865,10 @@ class TestTrainingWithDataTensors(test.TestCase):
model.compile(optimizer='rmsprop', loss='mse', target_tensors=[target])
model.train_on_batch(input_val, None)
+ # single-output, as single tensor
+ model.compile(optimizer='rmsprop', loss='mse', target_tensors=target)
+ model.train_on_batch(input_val, None)
+
# single-output, as dict
model.compile(optimizer='rmsprop', loss='mse',
target_tensors={'dense': target})
@@ -2427,6 +2432,17 @@ class TestTrainingWithMetrics(test.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
+ def test_losses_in_defun(self):
+ with context.eager_mode():
+ layer = keras.layers.Dense(1, kernel_regularizer='l1')
+ layer(array_ops.ones([1, 10]))
+
+ @function.defun
+ def get_losses():
+ return layer.losses
+
+ self.assertAllEqual(self.evaluate(layer.losses),
+ self.evaluate(get_losses()))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py
index 9c303f4bed..dd2a7f16ec 100644
--- a/tensorflow/python/keras/engine/training_utils.py
+++ b/tensorflow/python/keras/engine/training_utils.py
@@ -106,7 +106,8 @@ def convert_to_iterator(x=None,
batch_size=None,
steps_per_epoch=None,
epochs=1,
- shuffle=False):
+ shuffle=False,
+ is_validation=False):
"""Converts NumPy arrays or EagerTensors to an EagerIterator.
Combines all provided data into a single EagerIterator.
@@ -124,6 +125,9 @@ def convert_to_iterator(x=None,
epoch.
epochs: Epochs to repeat iterator for.
shuffle: Whether to shuffle data after each epoch.
+ is_validation: Whether this call is for validation during a training
+ (e.g., `fit()`) call. This info is used to construct error messages
+ (if any).
Raises:
ValueError: if steps_per_epoch cannot be calculated from the data
@@ -151,9 +155,12 @@ def convert_to_iterator(x=None,
steps_per_epoch = int(math.ceil(num_samples / batch_size))
if steps_per_epoch is None:
- raise ValueError('Could not determine steps_per_epoch.'
- 'Please provide either batch_size or'
- 'steps_per_epoch.')
+ alternative_arg_name = (
+ 'validation_steps' if is_validation else 'steps_per_epoch')
+ raise ValueError(
+ 'Could not determine how to convert EagerTensors into EagerIterator. '
+ 'Please provide either `batch_size` or '
+ '`%s`.' % alternative_arg_name)
# TODO(omalleyt) for NumPy arrays in graph mode
# placeholder ops should be used
diff --git a/tensorflow/python/keras/layers/convolutional.py b/tensorflow/python/keras/layers/convolutional.py
index d00def07bb..8f5872385c 100644
--- a/tensorflow/python/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/layers/convolutional.py
@@ -645,6 +645,14 @@ class Conv2DTranspose(Conv2D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 2 integers,
+ specifying the amount of padding along the height and width
+ of the output tensor.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -700,7 +708,9 @@ class Conv2DTranspose(Conv2D):
kernel_size,
strides=(1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
+ dilation_rate=(1, 1),
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
@@ -717,6 +727,7 @@ class Conv2DTranspose(Conv2D):
strides=strides,
padding=padding,
data_format=data_format,
+ dilation_rate=dilation_rate,
activation=activations.get(activation),
use_bias=use_bias,
kernel_initializer=initializers.get(kernel_initializer),
@@ -728,6 +739,16 @@ class Conv2DTranspose(Conv2D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 2, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 4:
@@ -769,51 +790,50 @@ class Conv2DTranspose(Conv2D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, h_axis, w_axis = 1, 2, 3
+ h_axis, w_axis = 2, 3
else:
- c_axis, h_axis, w_axis = 3, 1, 2
+ h_axis, w_axis = 1, 2
height, width = inputs_shape[h_axis], inputs_shape[w_axis]
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_height, out_width)
- strides = (1, 1, stride_h, stride_w)
else:
output_shape = (batch_size, out_height, out_width, self.filters)
- strides = (1, stride_h, stride_w, 1)
output_shape_tensor = array_ops.stack(output_shape)
- outputs = nn.conv2d_transpose(
+ outputs = backend.conv2d_transpose(
inputs,
self.kernel,
output_shape_tensor,
- strides,
- padding=self.padding.upper(),
- data_format=conv_utils.convert_data_format(self.data_format, ndim=4))
+ strides=self.strides,
+ padding=self.padding,
+ data_format=self.data_format,
+ dilation_rate=self.dilation_rate)
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -837,13 +857,33 @@ class Conv2DTranspose(Conv2D):
kernel_h, kernel_w = self.kernel_size
stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_h = out_pad_w = None
+ else:
+ out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h,
+ dilation=self.dilation_rate[0])
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w,
+ dilation=self.dilation_rate[1])
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv2DTranspose, self).get_config()
+ config['output_padding'] = self.output_padding
+ return config
+
@tf_export('keras.layers.Conv3DTranspose',
'keras.layers.Convolution3DTranspose')
@@ -878,6 +918,14 @@ class Conv3DTranspose(Conv3D):
Specifying any stride value != 1 is incompatible with specifying
any `dilation_rate` value != 1.
padding: one of `"valid"` or `"same"` (case-insensitive).
+ output_padding: An integer or tuple/list of 3 integers,
+ specifying the amount of padding along the depth, height, and
+ width.
+ Can be a single integer to specify the same value for all
+ spatial dimensions.
+ The amount of output padding along a given dimension must be
+ lower than the stride along that same dimension.
+ If set to `None` (default), the output shape is inferred.
data_format: A string,
one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
@@ -943,6 +991,7 @@ class Conv3DTranspose(Conv3D):
kernel_size,
strides=(1, 1, 1),
padding='valid',
+ output_padding=None,
data_format=None,
activation=None,
use_bias=True,
@@ -971,6 +1020,16 @@ class Conv3DTranspose(Conv3D):
bias_constraint=constraints.get(bias_constraint),
**kwargs)
+ self.output_padding = output_padding
+ if self.output_padding is not None:
+ self.output_padding = conv_utils.normalize_tuple(
+ self.output_padding, 3, 'output_padding')
+ for stride, out_pad in zip(self.strides, self.output_padding):
+ if out_pad >= stride:
+ raise ValueError('Stride ' + str(self.strides) + ' must be '
+ 'greater than output padding ' +
+ str(self.output_padding))
+
def build(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape)
if len(input_shape) != 5:
@@ -1012,11 +1071,9 @@ class Conv3DTranspose(Conv3D):
inputs_shape = array_ops.shape(inputs)
batch_size = inputs_shape[0]
if self.data_format == 'channels_first':
- c_axis, d_axis, h_axis, w_axis = 1, 2, 3, 4
+ d_axis, h_axis, w_axis = 2, 3, 4
else:
- c_axis, d_axis, h_axis, w_axis = 4, 1, 2, 3
-
- self.input_spec = InputSpec(ndim=5, axes={c_axis: inputs_shape[c_axis]})
+ d_axis, h_axis, w_axis = 1, 2, 3
depth = inputs_shape[d_axis]
height = inputs_shape[h_axis]
@@ -1025,19 +1082,27 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
# Infer the dynamic output shape:
out_depth = conv_utils.deconv_output_length(depth,
kernel_d,
- self.padding,
- stride_d)
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
out_height = conv_utils.deconv_output_length(height,
kernel_h,
- self.padding,
- stride_h)
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
out_width = conv_utils.deconv_output_length(width,
kernel_w,
- self.padding,
- stride_w)
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
if self.data_format == 'channels_first':
output_shape = (batch_size, self.filters, out_depth, out_height,
out_width)
@@ -1058,20 +1123,7 @@ class Conv3DTranspose(Conv3D):
if not context.executing_eagerly():
# Infer the static output shape:
- out_shape = inputs.get_shape().as_list()
- out_shape[c_axis] = self.filters
- out_shape[d_axis] = conv_utils.deconv_output_length(out_shape[d_axis],
- kernel_d,
- self.padding,
- stride_d)
- out_shape[h_axis] = conv_utils.deconv_output_length(out_shape[h_axis],
- kernel_h,
- self.padding,
- stride_h)
- out_shape[w_axis] = conv_utils.deconv_output_length(out_shape[w_axis],
- kernel_w,
- self.padding,
- stride_w)
+ out_shape = self.compute_output_shape(inputs.shape)
outputs.set_shape(out_shape)
if self.use_bias:
@@ -1109,15 +1161,38 @@ class Conv3DTranspose(Conv3D):
kernel_d, kernel_h, kernel_w = self.kernel_size
stride_d, stride_h, stride_w = self.strides
+ if self.output_padding is None:
+ out_pad_d = out_pad_h = out_pad_w = None
+ else:
+ out_pad_d, out_pad_h, out_pad_w = self.output_padding
+
output_shape[c_axis] = self.filters
output_shape[d_axis] = conv_utils.deconv_output_length(
- output_shape[d_axis], kernel_d, self.padding, stride_d)
+ output_shape[d_axis],
+ kernel_d,
+ padding=self.padding,
+ output_padding=out_pad_d,
+ stride=stride_d)
output_shape[h_axis] = conv_utils.deconv_output_length(
- output_shape[h_axis], kernel_h, self.padding, stride_h)
+ output_shape[h_axis],
+ kernel_h,
+ padding=self.padding,
+ output_padding=out_pad_h,
+ stride=stride_h)
output_shape[w_axis] = conv_utils.deconv_output_length(
- output_shape[w_axis], kernel_w, self.padding, stride_w)
+ output_shape[w_axis],
+ kernel_w,
+ padding=self.padding,
+ output_padding=out_pad_w,
+ stride=stride_w)
return tensor_shape.TensorShape(output_shape)
+ def get_config(self):
+ config = super(Conv3DTranspose, self).get_config()
+ config.pop('dilation_rate')
+ config['output_padding'] = self.output_padding
+ return config
+
class SeparableConv(Conv):
"""Abstract base layer for separable nD convolution.
diff --git a/tensorflow/python/keras/layers/convolutional_test.py b/tensorflow/python/keras/layers/convolutional_test.py
index 2d3d38a5ce..f88d632ab5 100644
--- a/tensorflow/python/keras/layers/convolutional_test.py
+++ b/tensorflow/python/keras/layers/convolutional_test.py
@@ -113,7 +113,7 @@ class Conv2DTest(test.TestCase):
test_kwargs[arg] = value
with self.test_session(use_gpu=True):
testing_utils.layer_test(
- keras.layers.SeparableConv2D,
+ keras.layers.Conv2D,
kwargs=test_kwargs,
input_shape=(num_samples, num_row, num_col, stack_size))
@@ -204,6 +204,9 @@ class Conv2DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1)])
+
def test_conv2dtranspose_regularizers(self):
kwargs = {
'filters': 3,
@@ -239,6 +242,31 @@ class Conv2DTransposeTest(test.TestCase):
self.assertEqual(layer.kernel.constraint, k_constraint)
self.assertEqual(layer.bias.constraint, b_constraint)
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_conv2d_transpose_dilation(self):
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ kwargs={'filters': 2,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2)},
+ input_shape=(2, 5, 6, 3))
+
+ input_data = np.arange(48).reshape((1, 4, 4, 3)).astype(np.float32)
+ expected_output = np.float32([[192, 228, 192, 228],
+ [336, 372, 336, 372],
+ [192, 228, 192, 228],
+ [336, 372, 336, 372]]).reshape((1, 4, 4, 1))
+ testing_utils.layer_test(keras.layers.Conv2DTranspose,
+ input_data=input_data,
+ kwargs={'filters': 1,
+ 'kernel_size': 3,
+ 'padding': 'same',
+ 'data_format': 'channels_last',
+ 'dilation_rate': (2, 2),
+ 'kernel_initializer': 'ones'},
+ expected_output=expected_output)
+
class Conv3DTransposeTest(test.TestCase):
@@ -270,6 +298,9 @@ class Conv3DTransposeTest(test.TestCase):
if test.is_gpu_available(cuda_only=True):
self._run_test(kwargs, 'data_format', ['channels_first'])
+ kwargs['strides'] = (2, 2, 2)
+ self._run_test(kwargs, 'output_padding', [(1, 1, 1)])
+
def test_conv3dtranspose_regularizers(self):
kwargs = {
'filters': 3,
diff --git a/tensorflow/python/keras/layers/core.py b/tensorflow/python/keras/layers/core.py
index 4032202986..efa21955e6 100644
--- a/tensorflow/python/keras/layers/core.py
+++ b/tensorflow/python/keras/layers/core.py
@@ -671,22 +671,34 @@ class Lambda(Layer):
if mask is not None:
self.supports_masking = True
self.mask = mask
- if output_shape is None:
- self._output_shape = None
- elif isinstance(output_shape, (tuple, list)):
- self._output_shape = tuple(output_shape)
- else:
- if not callable(output_shape):
- raise TypeError('In Lambda, `output_shape` '
- 'must be a list, a tuple, or a function.')
- self._output_shape = output_shape
+ if (output_shape is not None and not isinstance(output_shape,
+ (tuple, list)) and
+ not callable(output_shape)):
+ raise TypeError('In Lambda, `output_shape` '
+ 'must be a list, a tuple, or a function.')
+ # Convert a list representing a single shape into a tuple.
+ if (isinstance(output_shape, list) and isinstance(output_shape[0],
+ (int, type(None)))):
+ output_shape = tuple(output_shape)
+ self._output_shape = output_shape
@tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self._output_shape is None:
if context.executing_eagerly():
- raise NotImplementedError
- x = K.placeholder(shape=input_shape)
+ # Make use of existing autocomputation for Eager mode but provide
+ # Lambda-specific error message.
+ try:
+ return super(Lambda, self).compute_output_shape(input_shape)
+ except NotImplementedError:
+ raise NotImplementedError('We could not automatically infer '
+ 'the static shape of the Lambda\'s output.'
+ ' Please specify the `output_shape` for'
+ ' this Lambda.')
+ if isinstance(input_shape, list):
+ x = [K.placeholder(shape=shape) for shape in input_shape]
+ else:
+ x = K.placeholder(shape=input_shape)
x = self.call(x)
if isinstance(x, list):
return [tensor_shape.TensorShape(K.int_shape(x_elem)) for x_elem in x]
@@ -697,16 +709,27 @@ class Lambda(Layer):
num_samples = input_shape[0][0]
else:
num_samples = input_shape[0] if input_shape else None
- return tensor_shape.TensorShape((num_samples,) +
- tuple(self._output_shape))
+ # List here represents multiple outputs.
+ if isinstance(self._output_shape, list):
+ return [
+ tensor_shape.TensorShape((num_samples,) + tuple(single_shape))
+ for single_shape in self._output_shape
+ ]
+ return tensor_shape.TensorShape((num_samples,) + self._output_shape)
else:
shape = self._output_shape(input_shape)
if not isinstance(shape, (list, tuple)):
raise ValueError(
'`output_shape` function must return a tuple or a list of tuples.')
+ # List here can represent multiple outputs or single output.
if isinstance(shape, list):
- if isinstance(shape[0], int) or shape[0] is None:
+ # Convert list representing single output into a tuple.
+ if isinstance(shape[0], (int, type(None))):
shape = tuple(shape)
+ else:
+ return [
+ tensor_shape.TensorShape(single_shape) for single_shape in shape
+ ]
return tensor_shape.TensorShape(shape)
def call(self, inputs, mask=None):
diff --git a/tensorflow/python/keras/layers/core_test.py b/tensorflow/python/keras/layers/core_test.py
index 1df1d575b1..f0fea1f65c 100644
--- a/tensorflow/python/keras/layers/core_test.py
+++ b/tensorflow/python/keras/layers/core_test.py
@@ -252,6 +252,51 @@ class CoreLayersTest(test.TestCase):
l(keras.backend.variable(np.ones((1, 1))))
self.assertEqual('lambda', l.get_config()['output_shape_type'])
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_autocalculate_multiple_inputs(self):
+
+ def lambda_fn(x):
+ return math_ops.matmul(x[0], x[1])
+
+ l = keras.layers.Lambda(lambda_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual((10, 20), output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_list_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=[(10,), (20,)])
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_tuple_with_none(self):
+
+ def lambda_fn(x):
+ return x
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=(None, 10))
+ output_shape = l.compute_output_shape((5, 10, 20))
+ # Dimension(None) != Dimension(None), so check
+ # str representations for equality.
+ self.assertAllEqual(('5', '?', '10'), tuple([str(s) for s in output_shape]))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_lambda_output_shape_function_multiple_outputs(self):
+
+ def lambda_fn(x):
+ return x
+
+ def output_shape_fn(input_shape):
+ return input_shape
+
+ l = keras.layers.Lambda(lambda_fn, output_shape=output_shape_fn)
+ output_shape = l.compute_output_shape([(10, 10), (10, 20)])
+ self.assertAllEqual([(10, 10), (10, 20)], output_shape)
+
def test_lambda_config_serialization(self):
with self.cached_session():
# test serialization with output_shape and output_shape_type
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index cf2b0c476c..29a09a3d71 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -47,6 +47,9 @@ class _CuDNNRNN(RNN):
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
+ time_major: Boolean (default False). If true, the inputs and outputs will be
+ in shape `(timesteps, batch, ...)`, whereas in the False case, it will
+ be `(batch, timesteps, ...)`.
"""
def __init__(self,
@@ -54,6 +57,7 @@ class _CuDNNRNN(RNN):
return_state=False,
go_backwards=False,
stateful=False,
+ time_major=False,
**kwargs):
# We invoke the base layer's initializer directly here because we do not
# want to create RNN cell instance.
@@ -62,6 +66,7 @@ class _CuDNNRNN(RNN):
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
+ self.time_major = time_major
self.supports_masking = False
self.input_spec = [InputSpec(ndim=3)]
if hasattr(self.cell.state_size, '__len__'):
@@ -124,7 +129,8 @@ class _CuDNNRNN(RNN):
'return_sequences': self.return_sequences,
'return_state': self.return_state,
'go_backwards': self.go_backwards,
- 'stateful': self.stateful
+ 'stateful': self.stateful,
+ 'time_major': self.time_major,
}
base_config = super( # pylint: disable=bad-super-call
RNN, self).get_config()
@@ -267,7 +273,8 @@ class CuDNNGRU(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -301,7 +308,10 @@ class CuDNNGRU(_CuDNNRNN):
if self.stateful or self.return_state:
h = h[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h]
@@ -456,7 +466,8 @@ class CuDNNLSTM(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_c = initial_state[1]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -496,7 +507,10 @@ class CuDNNLSTM(_CuDNNRNN):
h = h[0]
c = c[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h, c]
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 2ed0aa8f26..7becbfede1 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -142,6 +143,32 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
+ def test_time_major_input(self, layer_class):
+ if test.is_gpu_available(cuda_only=True):
+ with self.test_session(use_gpu=True):
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ layer = layer_class(units, time_major=True, return_sequences=True)
+ model.add(layer)
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.fit(
+ np.ones((num_samples, timesteps, input_size)),
+ np.ones((num_samples, timesteps, units)))
+ out = model.predict(np.ones((num_samples, timesteps, input_size)))
+ self.assertEqual(out.shape, (num_samples, timesteps, units))
+
+ @parameterized.named_parameters(
+ ('cudnngru', keras.layers.CuDNNGRU),
+ ('cudnnlstm', keras.layers.CuDNNLSTM),
+ )
def test_specify_initial_state_keras_tensor(self, layer_class):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index c6df5f2e26..824a0b069e 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -159,13 +159,15 @@ class Embedding(Layer):
else:
in_lens = [self.input_length]
if len(in_lens) != len(input_shape) - 1:
- ValueError('"input_length" is %s, but received input has shape %s' %
- (str(self.input_length), str(input_shape)))
+ raise ValueError('"input_length" is %s, '
+ 'but received input has shape %s' % (str(
+ self.input_length), str(input_shape)))
else:
for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
if s1 is not None and s2 is not None and s1 != s2:
- ValueError('"input_length" is %s, but received input has shape %s' %
- (str(self.input_length), str(input_shape)))
+ raise ValueError('"input_length" is %s, '
+ 'but received input has shape %s' % (str(
+ self.input_length), str(input_shape)))
elif s1 is None:
in_lens[i] = s2
return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
diff --git a/tensorflow/python/keras/layers/pooling.py b/tensorflow/python/keras/layers/pooling.py
index 912e8bd619..72a9c1d629 100644
--- a/tensorflow/python/keras/layers/pooling.py
+++ b/tensorflow/python/keras/layers/pooling.py
@@ -18,12 +18,15 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import functools
+
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras import backend
from tensorflow.python.keras.engine.base_layer import InputSpec
from tensorflow.python.keras.engine.base_layer import Layer
from tensorflow.python.keras.utils import conv_utils
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.util.tf_export import tf_export
@@ -41,16 +44,18 @@ class Pooling1D(Layer):
strides of the pooling operation.
padding: A string. The padding method, either 'valid' or 'same'.
Case-insensitive.
- data_format: A string, one of `channels_last` (default) or `channels_first`.
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
The ordering of the dimensions in the inputs.
`channels_last` corresponds to inputs with shape
- `(batch, length, channels)` while `channels_first` corresponds to
- inputs with shape `(batch, channels, length)`.
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
name: A string, the name of the layer.
"""
def __init__(self, pool_function, pool_size, strides,
- padding='valid', data_format=None,
+ padding='valid', data_format='channels_last',
name=None, **kwargs):
super(Pooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
@@ -65,45 +70,39 @@ class Pooling1D(Layer):
self.input_spec = InputSpec(ndim=3)
def call(self, inputs):
- # There is no TF op for 1D pooling, hence we make the inputs 4D.
- if self.data_format == 'channels_last':
- # input is NWC, make it NHWC
- inputs = array_ops.expand_dims(inputs, 1)
- # pool on the W dim
- pool_shape = (1, 1) + self.pool_size + (1,)
- strides = (1, 1) + self.strides + (1,)
- data_format = 'NHWC'
- else:
- # input is NCW, make it NCHW
- inputs = array_ops.expand_dims(inputs, 2)
- # pool on the W dim
- pool_shape = (1, 1, 1) + self.pool_size
- strides = (1, 1, 1) + self.strides
- data_format = 'NCHW'
-
+ pad_axis = 2 if self.data_format == 'channels_last' else 3
+ inputs = array_ops.expand_dims(inputs, pad_axis)
outputs = self.pool_function(
inputs,
- ksize=pool_shape,
- strides=strides,
- padding=self.padding.upper(),
- data_format=data_format)
-
- if self.data_format == 'channels_last':
- return array_ops.squeeze(outputs, 1)
- else:
- return array_ops.squeeze(outputs, 2)
+ self.pool_size + (1,),
+ strides=self.strides + (1,),
+ padding=self.padding,
+ data_format=self.data_format)
+ return array_ops.squeeze(outputs, pad_axis)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- length = conv_utils.conv_output_length(input_shape[1], self.pool_size[0],
- self.padding, self.strides[0])
- return tensor_shape.TensorShape([input_shape[0], length, input_shape[2]])
+ if self.data_format == 'channels_first':
+ steps = input_shape[2]
+ features = input_shape[1]
+ else:
+ steps = input_shape[1]
+ features = input_shape[2]
+ length = conv_utils.conv_output_length(steps,
+ self.pool_size[0],
+ self.padding,
+ self.strides[0])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], features, length])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], length, features])
def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
- 'padding': self.padding
+ 'padding': self.padding,
+ 'data_format': self.data_format,
}
base_config = super(Pooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
@@ -119,19 +118,36 @@ class MaxPooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(MaxPooling1D, self).__init__(
- nn.max_pool,
+ functools.partial(backend.pool2d, pool_mode='max'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -149,18 +165,35 @@ class AveragePooling1D(Pooling1D):
E.g. 2 will halve the input.
If None, it will default to `pool_size`.
padding: One of `"valid"` or `"same"` (case-insensitive).
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
- 3D tensor with shape: `(batch_size, downsampled_steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, downsampled_steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, downsampled_steps)`
"""
def __init__(self, pool_size=2, strides=None,
- padding='valid', data_format=None, **kwargs):
+ padding='valid', data_format='channels_last', **kwargs):
super(AveragePooling1D, self).__init__(
- nn.avg_pool,
+ functools.partial(backend.pool2d, pool_mode='avg'),
pool_size=pool_size,
strides=strides,
padding=padding,
@@ -561,41 +594,96 @@ class GlobalPooling1D(Layer):
"""Abstract class for different global pooling 1D layers.
"""
- def __init__(self, **kwargs):
+ def __init__(self, data_format='channels_last', **kwargs):
super(GlobalPooling1D, self).__init__(**kwargs)
self.input_spec = InputSpec(ndim=3)
+ self.data_format = conv_utils.normalize_data_format(data_format)
def compute_output_shape(self, input_shape):
input_shape = tensor_shape.TensorShape(input_shape).as_list()
- return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
+ if self.data_format == 'channels_first':
+ return tensor_shape.TensorShape([input_shape[0], input_shape[1]])
+ else:
+ return tensor_shape.TensorShape([input_shape[0], input_shape[2]])
def call(self, inputs):
raise NotImplementedError
+ def get_config(self):
+ config = {'data_format': self.data_format}
+ base_config = super(GlobalPooling1D, self).get_config()
+ return dict(list(base_config.items()) + list(config.items()))
+
@tf_export('keras.layers.GlobalAveragePooling1D',
'keras.layers.GlobalAvgPool1D')
class GlobalAveragePooling1D(GlobalPooling1D):
"""Global average pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
`(batch_size, features)`
"""
- def call(self, inputs):
- return backend.mean(inputs, axis=1)
+ def __init__(self, data_format='channels_last', **kwargs):
+ super(GlobalAveragePooling1D, self).__init__(data_format=data_format,
+ **kwargs)
+ self.supports_masking = True
+
+ def call(self, inputs, mask=None):
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ if mask is not None:
+ mask = math_ops.cast(mask, backend.floatx())
+ input_shape = inputs.shape.as_list()
+ broadcast_shape = [-1, input_shape[steps_axis], 1]
+ mask = array_ops.reshape(mask, broadcast_shape)
+ inputs *= mask
+ return backend.sum(inputs, axis=steps_axis) / math_ops.reduce_sum(
+ mask, axis=steps_axis)
+ else:
+ return backend.mean(inputs, axis=steps_axis)
+
+ def compute_mask(self, inputs, mask=None):
+ return None
@tf_export('keras.layers.GlobalMaxPool1D', 'keras.layers.GlobalMaxPooling1D')
class GlobalMaxPooling1D(GlobalPooling1D):
"""Global max pooling operation for temporal data.
+ Arguments:
+ data_format: A string,
+ one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, steps, features)` while `channels_first`
+ corresponds to inputs with shape
+ `(batch, features, steps)`.
+
Input shape:
- 3D tensor with shape: `(batch_size, steps, features)`.
+ - If `data_format='channels_last'`:
+ 3D tensor with shape:
+ `(batch_size, steps, features)`
+ - If `data_format='channels_first'`:
+ 3D tensor with shape:
+ `(batch_size, features, steps)`
Output shape:
2D tensor with shape:
@@ -603,7 +691,8 @@ class GlobalMaxPooling1D(GlobalPooling1D):
"""
def call(self, inputs):
- return backend.max(inputs, axis=1)
+ steps_axis = 1 if self.data_format == 'channels_last' else 2
+ return backend.max(inputs, axis=steps_axis)
class GlobalPooling2D(Layer):
diff --git a/tensorflow/python/keras/layers/pooling_test.py b/tensorflow/python/keras/layers/pooling_test.py
index 2cd9939e66..936e73ecf9 100644
--- a/tensorflow/python/keras/layers/pooling_test.py
+++ b/tensorflow/python/keras/layers/pooling_test.py
@@ -18,11 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python import keras
from tensorflow.python.eager import context
from tensorflow.python.framework import test_util as tf_test_util
from tensorflow.python.keras import testing_utils
from tensorflow.python.platform import test
+from tensorflow.python.training import rmsprop
class GlobalPoolingTest(test.TestCase):
@@ -31,8 +34,26 @@ class GlobalPoolingTest(test.TestCase):
def test_globalpooling_1d(self):
testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalMaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
testing_utils.layer_test(
keras.layers.pooling.GlobalAveragePooling1D, input_shape=(3, 4, 5))
+ testing_utils.layer_test(keras.layers.pooling.GlobalAveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 4, 5))
+
+ @tf_test_util.run_in_graph_and_eager_modes
+ def test_globalpooling_1d_masking_support(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Masking(mask_value=0., input_shape=(3, 4)))
+ model.add(keras.layers.GlobalAveragePooling1D())
+ model.compile(loss='mae', optimizer=rmsprop.RMSPropOptimizer(0.001))
+
+ model_input = np.random.random((2, 3, 4))
+ model_input[0, 1:, :] = 0
+ output = model.predict(model_input)
+ self.assertAllClose(output[0], model_input[0, 0, :])
@tf_test_util.run_in_graph_and_eager_modes
def test_globalpooling_2d(self):
@@ -172,6 +193,10 @@ class Pooling1DTest(test.TestCase):
kwargs={'strides': stride,
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.MaxPooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
@tf_test_util.run_in_graph_and_eager_modes
def test_averagepooling_1d(self):
@@ -183,6 +208,11 @@ class Pooling1DTest(test.TestCase):
'padding': padding},
input_shape=(3, 5, 4))
+ testing_utils.layer_test(
+ keras.layers.AveragePooling1D,
+ kwargs={'data_format': 'channels_first'},
+ input_shape=(3, 2, 6))
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index ba7498e7e6..b07ec71178 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -336,9 +336,18 @@ class RNN(Layer):
in your model, you would need to specify the input length
at the level of the first layer
(e.g. via the `input_shape` argument)
+ time_major: The shape format of the `inputs` and `outputs` tensors.
+ If True, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
+ default this function accepts input and emits output in batch-major
+ form.
Input shape:
- N-D tensor with shape `(batch_size, timesteps, ...)`.
+ N-D tensor with shape `(batch_size, timesteps, ...)` or
+ `(timesteps, batch_size, ...)` when time_major is True.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -347,7 +356,8 @@ class RNN(Layer):
be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
`(batch_size, timesteps, output_size)`, where `output_size` could
- be a high dimension tensor shape.
+ be a high dimension tensor shape, or
+ `(timesteps, batch_size, output_size)` when `time_major` is True.
- else, N-D tensor with shape `(batch_size, output_size)`, where
`output_size` could be a high dimension tensor shape.
@@ -448,6 +458,7 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
+ time_major=False,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
@@ -468,6 +479,7 @@ class RNN(Layer):
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
+ self.time_major = time_major
self.supports_masking = True
self.input_spec = [None] # The input shape is unknown yet, at least rank 3.
@@ -503,14 +515,21 @@ class RNN(Layer):
# Note that state_size[0] could be a tensor_shape or int.
output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+ batch = input_shape[0]
+ time_step = input_shape[1]
+ if self.time_major:
+ batch, time_step = time_step, batch
if self.return_sequences:
- output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
+ if self.time_major:
+ output_shape = tuple([time_step, batch] + output_dim)
+ else:
+ output_shape = tuple([batch, time_step] + output_dim)
else:
- output_shape = tuple([input_shape[0]] + output_dim)
+ output_shape = tuple([batch] + output_dim)
if self.return_state:
state_shape = [
- tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list())
+ tuple([batch] + tensor_shape.as_shape(dim).as_list())
for dim in state_size
]
return [output_shape] + state_shape
@@ -539,13 +558,18 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- batch_size = input_shape[0] if self.stateful else None
- input_dim = input_shape[2:]
- self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim)
+ input_spec_shape = list(input_shape)
+ batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
+ if not self.stateful:
+ input_spec_shape[batch_index] = None
+ input_spec_shape[time_step_index] = None
+ self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape))
+ batch = input_shape[batch_index]
+ input_dim = input_shape[2:]
+ step_input_shape = (batch,) + input_dim
# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
- step_input_shape = (input_shape[0],) + input_dim
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
@@ -598,12 +622,16 @@ class RNN(Layer):
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+
+ input_shape = array_ops.shape(inputs)
+ batch_size = input_shape[1] if self.time_major else input_shape[0]
+ dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
- inputs=inputs, batch_size=None, dtype=None)
+ inputs=None, batch_size=batch_size, dtype=dtype)
else:
- init_state = _generate_zero_filled_state(
- array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
+ dtype)
# Keras RNN expect the states in a list, even if it's a single state tensor.
if not nest.is_sequence(init_state):
init_state = [init_state]
@@ -696,7 +724,7 @@ class RNN(Layer):
'Layer has ' + str(len(self.states)) + ' states but was passed ' +
str(len(initial_state)) + ' initial states.')
input_shape = K.int_shape(inputs)
- timesteps = input_shape[1]
+ timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps in [None, 1]:
raise ValueError('Cannot unroll a RNN if the '
'time dimension is undefined or equal to 1. \n'
@@ -747,7 +775,8 @@ class RNN(Layer):
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
- input_length=timesteps)
+ input_length=timesteps,
+ time_major=self.time_major)
if self.stateful:
updates = []
for i in range(len(states)):
@@ -777,7 +806,10 @@ class RNN(Layer):
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
+ if self.time_major:
+ batch_size = self.input_spec[0].shape[1]
+ else:
+ batch_size = self.input_spec[0].shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
@@ -839,7 +871,8 @@ class RNN(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll
+ 'unroll': self.unroll,
+ 'time_major': self.time_major
}
if self._num_constants is not None:
config['num_constants'] = self._num_constants
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index b9e90095e4..d246be6b45 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -186,6 +186,96 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
+ def test_rnn_with_time_major(self):
+ batch = 10
+ time_step = 5
+ embedding_dim = 4
+ units = 3
+
+ with self.cached_session():
+ # Test basic case.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ layer = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ self.assertEqual(
+ layer.compute_output_shape((time_step, None,
+ embedding_dim)).as_list(),
+ [time_step, None, units])
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, units))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test stacking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ cell_units = [10, 8, 6]
+ cells = [keras.layers.SimpleRNNCell(cell_units[i]) for i in range(3)]
+ layer = keras.layers.RNN(cells, time_major=True, return_sequences=True)
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, cell_units[-1]))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, cell_units[-1])))
+
+ with self.cached_session():
+ # Test masking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ mask = keras.layers.Masking()(time_major)
+ rnn = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)(mask)
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(rnn)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test layer output
+ x = keras.Input((time_step, embedding_dim))
+ rnn_1 = keras.layers.SimpleRNN(units, return_sequences=True)
+ y = rnn_1(x)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ x_np = np.random.random((batch, time_step, embedding_dim))
+ y_np_1 = model.predict(x_np)
+
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ rnn_2 = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ y_2 = rnn_2(time_major)
+ y_2 = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(y_2)
+
+ model_2 = keras.models.Model(x, y_2)
+ rnn_2.set_weights(rnn_1.get_weights())
+
+ y_np_2 = model_2.predict(x_np)
+ self.assertAllClose(y_np_1, y_np_2, atol=1e-4)
+
def test_rnn_cell_with_constants_layer(self):
class RNNCellWithConstants(keras.layers.Layer):
diff --git a/tensorflow/python/keras/layers/wrappers.py b/tensorflow/python/keras/layers/wrappers.py
index a1933c11b0..d19d0b5f8c 100644
--- a/tensorflow/python/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/layers/wrappers.py
@@ -587,6 +587,9 @@ class Bidirectional(Wrapper):
output = y * y_rev
elif self.merge_mode is None:
output = [y, y_rev]
+ else:
+ raise ValueError(
+ 'Unrecognized value for `merge_mode`: %s' % (self.merge_mode))
# Properly set learning phase
if (getattr(y, '_uses_learning_phase', False) or
diff --git a/tensorflow/python/keras/metrics.py b/tensorflow/python/keras/metrics.py
index e64241e5cf..d217244e2f 100644
--- a/tensorflow/python/keras/metrics.py
+++ b/tensorflow/python/keras/metrics.py
@@ -71,6 +71,22 @@ def check_is_tensor_or_operation(x, name):
name, x))
+def clone_metric(metric):
+ """Returns a clone of the metric if stateful, otherwise returns it as is."""
+ if isinstance(metric, Metric):
+ return metric.__class__.from_config(metric.get_config())
+ return metric
+
+
+def clone_metrics(metrics):
+ """Clones the given metric list/dict."""
+ if metrics is None:
+ return None
+ if isinstance(metrics, dict):
+ return {key: clone_metric(value) for key, value in metrics.items()}
+ return [clone_metric(metric) for metric in metrics]
+
+
def update_state_wrapper(update_state_fn):
"""Decorator to wrap metric `update_state()` with `add_update()`.
@@ -635,7 +651,9 @@ def categorical_accuracy(y_true, y_pred):
@tf_export('keras.metrics.sparse_categorical_accuracy')
def sparse_categorical_accuracy(y_true, y_pred):
- y_true = math_ops.reduce_max(y_true, axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
y_pred = math_ops.argmax(y_pred, axis=-1)
# If the expected labels are float, we need to cast the int returned by
@@ -654,11 +672,11 @@ def top_k_categorical_accuracy(y_true, y_pred, k=5):
@tf_export('keras.metrics.sparse_top_k_categorical_accuracy')
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
- return K.mean(
- nn.in_top_k(y_pred,
- math_ops.cast(math_ops.reduce_max(y_true, axis=-1), 'int32'),
- k),
- axis=-1)
+ # If the shape of y_true is (num_samples, 1), squeeze to (num_samples,)
+ if (len(K.int_shape(y_true)) == len(K.int_shape(y_pred))):
+ y_true = array_ops.squeeze(y_true, [-1])
+
+ return K.mean(nn.in_top_k(y_pred, math_ops.cast(y_true, 'int32'), k), axis=-1)
# Aliases
diff --git a/tensorflow/python/keras/metrics_test.py b/tensorflow/python/keras/metrics_test.py
index 4195ea18ad..5f5565d4d5 100644
--- a/tensorflow/python/keras/metrics_test.py
+++ b/tensorflow/python/keras/metrics_test.py
@@ -54,6 +54,18 @@ class KerasMetricsTest(test.TestCase):
y_pred = K.variable(np.random.random((6, 7)))
self.assertEqual(K.eval(metric(y_true, y_pred)).shape, (6,))
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_true = K.variable([1., 0., 0., 0.])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
+ # Test correctness if the shape of y_true is (num_samples, 1)
+ y_true = K.variable([[1.], [0.], [0.], [0.]])
+ y_pred = K.variable([[0.8, 0.2], [0.6, 0.4], [0.7, 0.3], [0.9, 0.1]])
+ print(K.eval(metric(y_true, y_pred)))
+ self.assertAllEqual(K.eval(metric(y_true, y_pred)), [0., 1., 1., 1.])
+
def test_sparse_categorical_accuracy_float(self):
with self.cached_session():
metric = metrics.sparse_categorical_accuracy
@@ -79,6 +91,7 @@ class KerasMetricsTest(test.TestCase):
def test_sparse_top_k_categorical_accuracy(self):
with self.cached_session():
+ # Test correctness if the shape of y_true is (num_samples, 1)
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
y_true = K.variable(np.array([[1], [0]]))
result = K.eval(
@@ -91,6 +104,19 @@ class KerasMetricsTest(test.TestCase):
metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
self.assertEqual(result, 0.)
+ # Test correctness if the shape of y_true is (num_samples,)
+ y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
+ y_true = K.variable(np.array([1, 0]))
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=3))
+ self.assertEqual(result, 1)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=2))
+ self.assertEqual(result, 0.5)
+ result = K.eval(
+ metrics.sparse_top_k_categorical_accuracy(y_true, y_pred, k=1))
+ self.assertEqual(result, 0.)
+
def test_top_k_categorical_accuracy(self):
with self.cached_session():
y_pred = K.variable(np.array([[0.3, 0.2, 0.1], [0.1, 0.2, 0.7]]))
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index 41c5e3cccf..2883c9ad74 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -20,6 +20,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.keras import backend as K
+from tensorflow.python.keras import metrics as metrics_module
from tensorflow.python.keras import optimizers
from tensorflow.python.keras.engine import saving
from tensorflow.python.keras.engine import sequential
@@ -95,6 +96,8 @@ def _clone_functional_model(model, input_tensors=None):
else:
# Make sure that all input tensors come from a Keras layer.
# If tensor comes from an input layer: cache the input layer.
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
input_tensors = generic_utils.to_list(input_tensors)
input_tensors_ = []
for i, x in enumerate(input_tensors):
@@ -211,6 +214,9 @@ def _clone_sequential_model(model, input_tensors=None):
raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor '
'as part of `input_tensors`.')
+
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
x = generic_utils.to_list(input_tensors)[0]
if K.is_keras_tensor(x):
origin_layer = x._keras_history[0]
@@ -290,7 +296,9 @@ def _in_place_subclassed_model_reset(model):
if isinstance(value, Layer):
attributes_cache[name] = value
assert value in model._layers
- elif isinstance(value, (list, tuple)) and name not in ('layers', '_layers'):
+ elif isinstance(
+ value, (list, tuple)) and name not in ('layers', '_layers',
+ 'stateful_metric_functions'):
# Handle case: list/tuple of layers (also tracked by the Network API).
if value and all(isinstance(val, Layer) for val in value):
raise ValueError('We do not support the use of list-of-layers '
@@ -466,10 +474,10 @@ def clone_and_build_model(
clone.compile(
optimizer,
model.loss,
- metrics=model.metrics,
+ metrics=metrics_module.clone_metrics(model.metrics),
loss_weights=model.loss_weights,
sample_weight_mode=model.sample_weight_mode,
- weighted_metrics=model.weighted_metrics,
+ weighted_metrics=metrics_module.clone_metrics(model.weighted_metrics),
target_tensors=target_tensors)
return clone
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta.py b/tensorflow/python/keras/optimizer_v2/adadelta.py
new file mode 100644
index 0000000000..d3b3c9c12e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta.py
@@ -0,0 +1,116 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adadelta for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.training import training_ops
+
+
+class Adadelta(optimizer_v2.OptimizerV2):
+ """Adadelta optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See [M. D. Zeiler](http://arxiv.org/abs/1212.5701)
+ ([pdf](http://arxiv.org/pdf/1212.5701v1.pdf))
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate. It is recommended
+ to leave it at the default value.
+ rho: float hyperparameter >= 0. The decay rate.
+ epsilon: float hyperparameter >= 0. Fuzz factor. A constant epsilon used
+ to better condition the grad update.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adadelta'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.95,
+ epsilon=1e-8,
+ name="Adadelta"):
+ super(Adadelta, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("epsilon", epsilon)
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ state.zeros_slot(v, "accum")
+ state.zeros_slot(v, "accum_update")
+
+ def _apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.sparse_apply_adadelta(
+ var,
+ accum,
+ accum_update,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ accum = state.get_slot(var, "accum")
+ accum_update = state.get_slot(var, "accum_update")
+ return training_ops.resource_sparse_apply_adadelta(
+ var.handle,
+ accum.handle,
+ accum_update.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adadelta_test.py b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
new file mode 100644
index 0000000000..6e48f92e4f
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adadelta_test.py
@@ -0,0 +1,166 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adadelta Optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.keras.optimizer_v2 import adadelta
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdadeltaOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ num_updates = 4 # number of ADADELTA steps to perform
+ for dtype in [dtypes.half, dtypes.float32]:
+ for grad in [0.2, 0.1, 0.01]:
+ for lr in [1.0, 0.5, 0.1]:
+ with self.cached_session():
+ var0_init = [1.0, 2.0]
+ var1_init = [3.0, 4.0]
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_init, dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_init, dtype=dtype)
+ else:
+ var0 = variables.Variable(var0_init, dtype=dtype)
+ var1 = variables.Variable(var1_init, dtype=dtype)
+
+ grads = constant_op.constant([grad, grad], dtype=dtype)
+
+ accum = 0.0
+ accum_update = 0.0
+
+ # ADADELTA gradient optimizer
+ rho = 0.95
+ epsilon = 1e-8
+ adadelta_opt = adadelta.Adadelta(lr, rho, epsilon)
+ adadelta_update = adadelta_opt.apply_gradients(
+ zip([grads, grads], [var0, var1]))
+
+ opt_vars = adadelta_opt.variables()
+ self.assertStartsWith(opt_vars[0].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[1].name, var0._shared_name)
+ self.assertStartsWith(opt_vars[2].name, var1._shared_name)
+ self.assertStartsWith(opt_vars[3].name, var1._shared_name)
+ self.assertEqual(4, len(opt_vars))
+
+ variables.global_variables_initializer().run()
+
+ # Assign slots
+ slot = [None] * 2
+ slot_update = [None] * 2
+ self.assertEqual(["accum", "accum_update"],
+ adadelta_opt.get_slot_names())
+ slot[0] = adadelta_opt.get_slot(var0, "accum")
+ self.assertEquals(slot[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot[0] in variables.trainable_variables())
+
+ slot_update[0] = adadelta_opt.get_slot(var0, "accum_update")
+ self.assertEquals(slot_update[0].get_shape(), var0.get_shape())
+ self.assertFalse(slot_update[0] in variables.trainable_variables())
+
+ slot[1] = adadelta_opt.get_slot(var1, "accum")
+ self.assertEquals(slot[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot[1] in variables.trainable_variables())
+
+ slot_update[1] = adadelta_opt.get_slot(var1, "accum_update")
+ self.assertEquals(slot_update[1].get_shape(), var1.get_shape())
+ self.assertFalse(slot_update[1] in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose(var0_init, var0.eval())
+ self.assertAllClose(var1_init, var1.eval())
+
+ update = [None] * num_updates
+ tot_update = 0
+ for step in range(num_updates):
+ # Run adadelta update for comparison
+ adadelta_update.run()
+
+ # Perform initial update without previous accum values
+ accum = accum * rho + (grad**2) * (1 - rho)
+ update[step] = (np.sqrt(accum_update + epsilon) *
+ (1. / np.sqrt(accum + epsilon)) * grad)
+ accum_update = (accum_update * rho + (update[step]**2) *
+ (1.0 - rho))
+ tot_update += update[step] * lr
+
+ # Check that the accumulators have been updated
+ for slot_idx in range(2):
+ self.assertAllCloseAccordingToType(
+ np.array([accum, accum], dtype=dtype.as_numpy_dtype()),
+ slot[slot_idx].eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [accum_update, accum_update],
+ dtype=dtype.as_numpy_dtype()),
+ slot_update[slot_idx].eval(),
+ rtol=1e-5)
+
+ # Check that the parameters have been updated
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var0_init[0] - tot_update, var0_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var0.eval(),
+ rtol=1e-5)
+
+ self.assertAllCloseAccordingToType(
+ np.array(
+ [var1_init[0] - tot_update, var1_init[1] - tot_update],
+ dtype=dtype.as_numpy_dtype()),
+ var1.eval(),
+ rtol=1e-5)
+
+ def testBasic(self):
+ self.doTestBasic(use_resource=False)
+
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adadelta.Adadelta(1.0, 1.0, 1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-111, -138]], var0.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad.py b/tensorflow/python/keras/optimizer_v2/adagrad.py
new file mode 100644
index 0000000000..2d8cec2300
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad.py
@@ -0,0 +1,119 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adagrad optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import training_ops
+
+
+class Adagrad(optimizer_v2.OptimizerV2):
+ """Adagrad optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values.
+
+ See this [paper](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf)
+ or this
+ [intro](https://ppasupat.github.io/a9online/uploads/proximal_notes.pdf).
+
+ The learning_rate arg below is a hyperparameter, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ initial_accumulator_value: A floating point value. Starting value for the
+ accumulators, must be positive.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'Adagrad'.
+
+ Raises:
+ ValueError: If the `initial_accumulator_value` is invalid.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ initial_accumulator_value=0.1,
+ name="Adagrad"):
+ if initial_accumulator_value <= 0.0:
+ raise ValueError("initial_accumulator_value must be positive: %s" %
+ initial_accumulator_value)
+ super(Adagrad, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+
+ self._initial_accumulator_value = initial_accumulator_value
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ dtype = v.dtype.base_dtype
+ if v.get_shape().is_fully_defined():
+ init = init_ops.constant_initializer(self._initial_accumulator_value,
+ dtype=dtype)
+ else:
+ def init(v=v, dtype=dtype):
+ # Use a Tensor instead of initializer if variable does not have
+ # static shape.
+ init_constant = gen_array_ops.fill(array_ops.shape(v),
+ self._initial_accumulator_value)
+ return math_ops.cast(init_constant, dtype)
+ state.create_slot_with_initializer(v, init, v.get_shape(), dtype,
+ "accumulator")
+
+ def _apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _resource_apply_dense(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var,
+ acc,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ acc = state.get_slot(var, "accumulator")
+ return training_ops.resource_sparse_apply_adagrad(
+ var.handle,
+ acc.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/adagrad_test.py b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
new file mode 100644
index 0000000000..fc4ef5c399
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adagrad_test.py
@@ -0,0 +1,276 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for aggregate operations."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import adagrad
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class AdagradOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testBasic(self):
+ self.doTestBasic()
+
+ def testBasicResource(self):
+ self.doTestBasic(use_resource=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = adagrad.Adagrad(1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType(
+ [[1.0, 2.0], [3.0, 4.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0, 1], [3, 4]], var0.eval(), atol=0.01)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(
+ constant_op.constant(3.0), initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ ada_opt = adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+ # Run 3 step of sgd
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([[-1.6026098728179932], [2.0]]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[3.0], [3.715679168701172]]), var1.eval())
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adagrad.Adagrad(3.0).apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def testSparseRepeatedIndicesResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var_repeated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_repeated = math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+ var_aggregated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_aggregated = 2 * math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_aggregated, [0]))
+ update_op_repeated = adagrad.Adagrad(2.0).minimize(loss_repeated)
+ update_op_aggregated = adagrad.Adagrad(2.0).minimize(loss_aggregated)
+ variables.global_variables_initializer().run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+ for _ in range(3):
+ update_op_repeated.run()
+ update_op_aggregated.run()
+ self.assertAllCloseAccordingToType(
+ var_repeated.eval(), var_aggregated.eval())
+
+ def testSparseStability(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ shape = [1, 6]
+ var0 = variables.Variable(
+ [[
+ 0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
+ -0.0105945
+ ]],
+ dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[
+ -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05
+ ]],
+ shape=shape,
+ dtype=dtype),
+ constant_op.constant([0]),
+ constant_op.constant(shape))
+ ada_opt = adagrad.Adagrad(1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = variables.global_variables_initializer()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllCloseAccordingToType(
+ np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[
+ 0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+ -0.01029443
+ ]]), var0.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = adagrad.Adagrad(3.0)
+ # Apply the optimizer twice. Both applications will use
+ # the same accums.
+ ada_update1 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 Adagrad).
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testDynamicShapeVariable_Ok(self):
+ with self.cached_session():
+ v = variable_scope.get_variable("v", initializer=constant_op.constant(1.),
+ validate_shape=False)
+ self.assertFalse(v.shape.is_fully_defined())
+ # Creating optimizer should cause no exception.
+ adagrad.Adagrad(3.0, initial_accumulator_value=0.1)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/adam.py b/tensorflow/python/keras/optimizer_v2/adam.py
new file mode 100644
index 0000000000..8367228d7a
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam.py
@@ -0,0 +1,203 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Adam optimizer for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import training_ops
+
+
+class Adam(optimizer_v2.OptimizerV2):
+ r"""Adam Optimizer.
+
+ Default parameters follow those provided in the original paper.
+
+ See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
+ ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
+
+ Some of the args below are hyperparameters where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Initialization:
+
+ $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
+ $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
+ $$t := 0 \text{(Initialize timestep)}$$
+ The update rule for `variable` with gradient `g` uses an optimization
+ described at the end of section2 of the paper:
+
+ $$t := t + 1$$
+ $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
+
+ $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
+ $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
+ $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
+
+ The default value of 1e-8 for epsilon might not be a good default in
+ general. For example, when training an Inception network on ImageNet a
+ current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
+ formulation just before Section 2.1 of the Kingma and Ba paper rather than
+ the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
+ hat" in the paper.
+
+ The sparse implementation of this algorithm (used when the gradient is an
+ IndexedSlices object, typically because of `tf.gather` or an embedding
+ lookup in the forward pass) does apply momentum to variable slices even if
+ they were not used in the forward pass (meaning they have a gradient equal
+ to zero). Momentum decay (beta1) is also applied to the entire momentum
+ accumulator. This means that the sparse behavior is equivalent to the dense
+ behavior (in contrast to some momentum implementations which ignore momentum
+ unless a variable slice was actually used).
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ beta_1: float hyperparameter, 0 < beta_1 < 1. Generally close to 1. The
+ exponential decay rate for the 1st moment estimates.
+ beta_2: float hyperparameter, 0 < beta_2 < 1. Generally close to 1. The
+ exponential decay rate for the 2nd moment estimates.
+ epsilon: float hyperparameter >= 0. Fuzz factor. This epsilon is "epsilon
+ hat" in the Kingma and Ba paper (in the formula just before Section
+ 2.1), not the epsilon in Algorithm 1 of the paper.
+ name: Optional name for the operations created when applying gradients.
+ Defaults to "Adam".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ beta_1=0.9,
+ beta_2=0.999,
+ epsilon=1e-8,
+ name="Adam"):
+ super(Adam, self).__init__(name)
+
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("beta_1", beta_1)
+ self._set_hyper("beta_2", beta_2)
+ self._set_hyper("epsilon", epsilon)
+
+ def _get_beta_accumulators(self, state=None):
+ if state is None:
+ state = self._get_per_graph_state()
+ return (state.get_non_slot("beta_1_power"),
+ state.get_non_slot("beta_2_power"))
+
+ def _create_vars(self, var_list, state):
+ # Non-slot variables end up on the same device(s).
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_1"), name="beta_1_power")
+ state.create_non_slot(
+ initial_value=lambda: state.get_hyper("beta_2"), name="beta_2_power")
+
+ # Create slots for the first and second moments.
+ for v in var_list:
+ state.zeros_slot(v, "m")
+ state.zeros_slot(v, "v")
+
+ def _apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.apply_adam(
+ var,
+ m,
+ v,
+ math_ops.cast(beta_1_power, var.dtype.base_dtype),
+ math_ops.cast(beta_2_power, var.dtype.base_dtype),
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("beta_1", var.dtype.base_dtype),
+ state.get_hyper("beta_2", var.dtype.base_dtype),
+ state.get_hyper("epsilon", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ m = state.get_slot(var, "m")
+ v = state.get_slot(var, "v")
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ return training_ops.resource_apply_adam(
+ var.handle,
+ m.handle,
+ v.handle,
+ math_ops.cast(beta_1_power, grad.dtype.base_dtype),
+ math_ops.cast(beta_2_power, grad.dtype.base_dtype),
+ state.get_hyper("learning_rate", grad.dtype.base_dtype),
+ state.get_hyper("beta_1", grad.dtype.base_dtype),
+ state.get_hyper("beta_2", grad.dtype.base_dtype),
+ state.get_hyper("epsilon", grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse_shared(self, grad, var, indices, scatter_add, state):
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ beta_1_power = math_ops.cast(beta_1_power, var.dtype.base_dtype)
+ beta_2_power = math_ops.cast(beta_2_power, var.dtype.base_dtype)
+ lr_t = state.get_hyper("learning_rate", var.dtype.base_dtype)
+ beta_1_t = state.get_hyper("beta_1", var.dtype.base_dtype)
+ beta_2_t = state.get_hyper("beta_2", var.dtype.base_dtype)
+ epsilon_t = state.get_hyper("epsilon", var.dtype.base_dtype)
+ lr = (lr_t * math_ops.sqrt(1 - beta_2_power) / (1 - beta_1_power))
+ # m_t = beta_1 * m + (1 - beta_1) * g_t
+ m = state.get_slot(var, "m")
+ m_scaled_g_values = grad * (1 - beta_1_t)
+ m_t = state_ops.assign(m, m * beta_1_t, use_locking=self._use_locking)
+ with ops.control_dependencies([m_t]):
+ m_t = scatter_add(m, indices, m_scaled_g_values)
+ # v_t = beta_2 * v + (1 - beta_2) * (g_t * g_t)
+ v = state.get_slot(var, "v")
+ v_scaled_g_values = (grad * grad) * (1 - beta_2_t)
+ v_t = state_ops.assign(v, v * beta_2_t, use_locking=self._use_locking)
+ with ops.control_dependencies([v_t]):
+ v_t = scatter_add(v, indices, v_scaled_g_values)
+ v_sqrt = math_ops.sqrt(v_t)
+ var_update = state_ops.assign_sub(var,
+ lr * m_t / (v_sqrt + epsilon_t),
+ use_locking=self._use_locking)
+ return control_flow_ops.group(*[var_update, m_t, v_t])
+
+ def _apply_sparse(self, grad, var, state):
+ return self._apply_sparse_shared(
+ grad.values, var, grad.indices,
+ lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda
+ x, i, v, use_locking=self._use_locking),
+ state)
+
+ def _resource_scatter_add(self, x, i, v):
+ with ops.control_dependencies(
+ [resource_variable_ops.resource_scatter_add(
+ x.handle, i, v)]):
+ return x.value()
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ return self._apply_sparse_shared(
+ grad, var, indices, self._resource_scatter_add, state)
+
+ def _finish(self, state):
+ # Update the power accumulators.
+ beta_1_power, beta_2_power = self._get_beta_accumulators(state)
+ update_beta_1 = beta_1_power.assign(
+ beta_1_power * state.get_hyper("beta_1"), use_locking=self._use_locking)
+ update_beta_2 = beta_2_power.assign(
+ beta_2_power * state.get_hyper("beta_2"), use_locking=self._use_locking)
+ return control_flow_ops.group(update_beta_1, update_beta_2)
diff --git a/tensorflow/python/keras/optimizer_v2/adam_test.py b/tensorflow/python/keras/optimizer_v2/adam_test.py
new file mode 100644
index 0000000000..77796317a1
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/adam_test.py
@@ -0,0 +1,333 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Adam optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+def adam_update_numpy(param,
+ g_t,
+ t,
+ m,
+ v,
+ alpha=0.001,
+ beta1=0.9,
+ beta2=0.999,
+ epsilon=1e-8):
+ alpha_t = alpha * np.sqrt(1 - beta2**t) / (1 - beta1**t)
+
+ m_t = beta1 * m + (1 - beta1) * g_t
+ v_t = beta2 * v + (1 - beta2) * g_t * g_t
+
+ param_t = param - alpha_t * m_t / (np.sqrt(v_t) + epsilon)
+ return param_t, m_t, v_t
+
+
+class AdamOptimizerTest(test.TestCase):
+
+ def doTestSparse(self, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0, 1], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([2]))
+ grads1_np_indices = np.array([0, 1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([2]))
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSparse(self):
+ self.doTestSparse(use_resource=False)
+
+ def testResourceSparse(self):
+ self.doTestSparse(use_resource=True)
+
+ def testSparseDevicePlacement(self):
+ for index_dtype in [dtypes.int32, dtypes.int64]:
+ with self.test_session(force_gpu=test.is_gpu_available()):
+ # If a GPU is available, tests that all optimizer ops can be placed on
+ # it (i.e. they have GPU kernels).
+ var = variables.Variable([[1.0], [2.0]])
+ indices = constant_op.constant([0, 1], dtype=index_dtype)
+ gathered_sum = math_ops.reduce_sum(array_ops.gather(var, indices))
+ optimizer = adam.Adam(3.0)
+ minimize_op = optimizer.minimize(gathered_sum)
+ variables.global_variables_initializer().run()
+ minimize_op.run()
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant(
+ [0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]),
+ constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant(
+ [0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([2, 1]))
+ repeated_update = adam.Adam().apply_gradients(
+ [(grad_repeated_index, repeated_index_update_var)])
+ aggregated_update = adam.Adam().apply_gradients(
+ [(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def doTestBasic(self, use_resource=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ with self.session(graph=ops.Graph()):
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ var0_np, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ var1_np, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+
+ opt = adam.Adam()
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ opt_variables = opt.variables()
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+ self.assertTrue(beta1_power is not None)
+ self.assertTrue(beta2_power is not None)
+ self.assertIn(beta1_power, opt_variables)
+ self.assertIn(beta2_power, opt_variables)
+
+ with ops.Graph().as_default():
+ # Shouldn't return non-slot variables from other graphs.
+ self.assertEqual(0, len(opt.variables()))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ if not context.executing_eagerly():
+ self.evaluate(update)
+ elif t > 1:
+ opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+
+ self.assertAllCloseAccordingToType(0.9**(t + 1),
+ self.evaluate(beta1_power))
+ self.assertAllCloseAccordingToType(0.999**(t + 1),
+ self.evaluate(beta2_power))
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, self.evaluate(var0))
+ self.assertAllCloseAccordingToType(var1_np, self.evaluate(var1))
+ if use_resource:
+ self.assertEqual("var0_%d/Adam:0" % (i,),
+ opt.get_slot(var=var0, name="m").name)
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam(constant_op.constant(0.001))
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Run 3 steps of Adam
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ update.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ # Initialize variables for numpy implementation.
+ m0, v0, m1, v1 = 0.0, 0.0, 0.0, 0.0
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = adam.Adam()
+ update1 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ update2 = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ beta1_power, beta2_power = opt._get_beta_accumulators()
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 3 steps of intertwined Adam1 and Adam2.
+ for t in range(1, 4):
+ self.assertAllCloseAccordingToType(0.9**t, beta1_power.eval())
+ self.assertAllCloseAccordingToType(0.999**t, beta2_power.eval())
+ if t % 2 == 0:
+ update1.run()
+ else:
+ update2.run()
+
+ var0_np, m0, v0 = adam_update_numpy(var0_np, grads0_np, t, m0, v0)
+ var1_np, m1, v1 = adam_update_numpy(var1_np, grads1_np, t, m1, v1)
+
+ # Validate updated params
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ def testTwoSessions(self):
+ optimizer = adam.Adam()
+ g = ops.Graph()
+ with g.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+ optimizer.apply_gradients([(grads0, var0)])
+
+ gg = ops.Graph()
+ with gg.as_default():
+ with session.Session():
+ var0 = variables.Variable(np.array([1.0, 2.0]), name="v0")
+ grads0 = constant_op.constant(np.array([0.1, 0.1]))
+
+ # If the optimizer saves any state not keyed by graph the following line
+ # fails.
+ optimizer.apply_gradients([(grads0, var0)])
+
+ def testSlotsUniqueEager(self):
+ with context.eager_mode():
+ v1 = resource_variable_ops.ResourceVariable(1.)
+ v2 = resource_variable_ops.ResourceVariable(1.)
+ opt = adam.Adam(1.)
+ opt.minimize(lambda: v1 + v2)
+ # There should be two non-slot variables, and two unique slot variables
+ # for v1 and v2 respectively.
+ self.assertEqual(6, len(set(opt.variables())))
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
new file mode 100644
index 0000000000..338c04148b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/checkpointable_utils_test.py
@@ -0,0 +1,761 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# TODO(josh11b): Forked from contrib/eager/python to test OptimizerV2 the same way
+# OptimizerV1 is tested. This file should be removed once the fork is resolved.
+
+import functools
+import os
+
+import six
+
+from tensorflow.python.client import session as session_lib
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.eager import test
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.engine import training
+from tensorflow.python.keras.layers import core
+from tensorflow.python.keras.optimizer_v2 import adam
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import template
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as core_saver
+from tensorflow.python.training import training_util
+from tensorflow.python.training.checkpointable import tracking
+from tensorflow.python.training.checkpointable import util
+
+
+class NonLayerCheckpointable(tracking.Checkpointable):
+
+ def __init__(self):
+ super(NonLayerCheckpointable, self).__init__()
+ self.a_variable = util.add_variable(
+ self, name="a_variable", shape=[])
+
+
+# pylint: disable=not-callable
+class MyModel(training.Model):
+ """A concrete Model for testing."""
+
+ def __init__(self):
+ super(MyModel, self).__init__()
+ self._named_dense = core.Dense(1, use_bias=True)
+ self._second = core.Dense(1, use_bias=False)
+ # We can still track Checkpointables which aren't Layers.
+ self._non_layer = NonLayerCheckpointable()
+
+ def call(self, values):
+ ret = self._second(self._named_dense(values))
+ return ret
+
+
+class _MirroringSaveable(
+ core_saver.BaseSaverBuilder.ResourceVariableSaveable):
+
+ def __init__(self, primary_variable, mirrored_variable, name):
+ self._primary_variable = primary_variable
+ self._mirrored_variable = mirrored_variable
+ super(_MirroringSaveable, self).__init__(
+ self._primary_variable, "", name)
+
+ def restore(self, restored_tensors, restored_shapes):
+ """Restore the same value into both variables."""
+ tensor, = restored_tensors
+ return control_flow_ops.group(
+ self._primary_variable.assign(tensor),
+ self._mirrored_variable.assign(tensor))
+
+
+class CheckpointingTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testNamingWithOptimizer(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ # A nuisance Model using the same optimizer. Its slot variables should not
+ # go in the checkpoint, since it is never depended on.
+ other_model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value),
+ global_step=optimizer_step)
+ optimizer.minimize(
+ lambda: other_model(input_value),
+ global_step=optimizer_step)
+ else:
+ train_op = optimizer.minimize(
+ model(input_value), global_step=optimizer_step)
+ optimizer.minimize(
+ other_model(input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ named_variables, serialized_graph, _ = (
+ util._serialize_object_graph(
+ root_checkpointable, saveables_cache=None))
+ expected_checkpoint_names = (
+ # Created in the root node, so no prefix.
+ "optimizer_step",
+ "model/_second/kernel",
+ "model/_named_dense/kernel",
+ "model/_named_dense/bias",
+ # non-Layer dependency of the model
+ "model/_non_layer/a_variable",
+ # The optimizer creates two non-slot variables
+ "optimizer/beta_1_power",
+ "optimizer/beta_2_power",
+ # Slot variables
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_second/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/v",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/m",
+ "model/_named_dense/bias/.OPTIMIZER_SLOT/optimizer/v",
+ )
+ suffix = "/.ATTRIBUTES/VARIABLE_VALUE"
+ expected_checkpoint_names = [
+ name + suffix for name in expected_checkpoint_names]
+ # The Dense layers also save get_config() JSON
+ expected_checkpoint_names.extend(
+ ["model/_second/.ATTRIBUTES/OBJECT_CONFIG_JSON",
+ "model/_named_dense/.ATTRIBUTES/OBJECT_CONFIG_JSON"])
+ named_variables = {v.name: v for v in named_variables}
+ six.assertCountEqual(self, expected_checkpoint_names,
+ named_variables.keys())
+ # Check that we've mapped to the right variable objects (not exhaustive)
+ self.assertEqual(
+ "global_step",
+ named_variables["optimizer_step" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense_1/kernel",
+ named_variables["model/_second/kernel" + suffix].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ named_variables["model/_named_dense/kernel" + suffix].full_name)
+ self.assertEqual(
+ "beta_1_power",
+ named_variables["optimizer/beta_1_power" + suffix].full_name)
+ self.assertEqual(
+ "beta_2_power",
+ named_variables["optimizer/beta_2_power" + suffix].full_name)
+ # Spot check the generated protocol buffers.
+ self.assertEqual("optimizer",
+ serialized_graph.nodes[0].children[1].local_name)
+ optimizer_node = serialized_graph.nodes[serialized_graph.nodes[0].children[
+ 1].node_id]
+ self.assertEqual("beta_1_power", optimizer_node.children[0].local_name)
+ self.assertEqual(
+ "beta_1_power", serialized_graph.nodes[
+ optimizer_node.children[0].node_id].attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .original_variable_node_id]
+ .attributes[0].full_name)
+ # We strip off the :0 suffix, as variable.name-based saving does.
+ self.assertEqual(
+ "my_model/dense/kernel/Adam",
+ serialized_graph.nodes[optimizer_node.slot_variables[0]
+ .slot_variable_node_id]
+ .attributes[0].full_name)
+ self.assertEqual(
+ "my_model/dense/kernel/Adam:0",
+ optimizer.get_slot(
+ var=model._named_dense.kernel,
+ name="m").name)
+ self.assertEqual(
+ "model/_named_dense/kernel" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .original_variable_node_id].attributes[0].checkpoint_key)
+ self.assertEqual("m", optimizer_node.slot_variables[0].slot_name)
+ self.assertEqual(
+ "model/_named_dense/kernel/.OPTIMIZER_SLOT/optimizer/m" + suffix,
+ serialized_graph.nodes[
+ optimizer_node.slot_variables[0]
+ .slot_variable_node_id].attributes[0].checkpoint_key)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testSaveRestore(self):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model)
+ input_value = constant_op.constant([[3.]])
+ if context.executing_eagerly():
+ optimizer.minimize(
+ lambda: model(input_value))
+ else:
+ train_op = optimizer.minimize(model(input_value))
+ # TODO(allenl): Make initialization more pleasant when graph building.
+ root_checkpointable.save_counter # pylint: disable=pointless-statement
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ prefix = os.path.join(self.get_temp_dir(), "ckpt")
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [42.]))
+ m_bias_slot = optimizer.get_slot(model._named_dense.variables[1], "m")
+ self.evaluate(state_ops.assign(m_bias_slot, [1.5]))
+ save_path = root_checkpointable.save(file_prefix=prefix)
+ self.evaluate(state_ops.assign(model._named_dense.variables[1], [43.]))
+ self.evaluate(state_ops.assign(root_checkpointable.save_counter, 3))
+ optimizer_variables = self.evaluate(optimizer.variables())
+ self.evaluate(state_ops.assign(m_bias_slot, [-2.]))
+ # Immediate restoration
+ status = root_checkpointable.restore(save_path=save_path).assert_consumed()
+ status.run_restore_ops()
+ self.assertAllEqual([42.], self.evaluate(model._named_dense.variables[1]))
+ self.assertAllEqual(1, self.evaluate(root_checkpointable.save_counter))
+ self.assertAllEqual([1.5], self.evaluate(m_bias_slot))
+ if not context.executing_eagerly():
+ return # Restore-on-create is only supported when executing eagerly
+ on_create_model = MyModel()
+ on_create_optimizer = adam.Adam(
+ 0.001,
+ # Preserve beta_1_power and beta_2_power when appying gradients
+ # so we can test that they've been restored correctly.
+ beta_1=1.0,
+ beta_2=1.0)
+ on_create_root = util.Checkpoint(
+ optimizer=on_create_optimizer, model=on_create_model)
+ # Deferred restoration
+ status = on_create_root.restore(save_path=save_path)
+ on_create_model(constant_op.constant([[3.]])) # create variables
+ self.assertAllEqual(1, self.evaluate(on_create_root.save_counter))
+ self.assertAllEqual([42.],
+ self.evaluate(
+ on_create_model._named_dense.variables[1]))
+ on_create_m_bias_slot = on_create_optimizer.get_slot(
+ on_create_model._named_dense.variables[1], "m")
+ # Optimizer slot variables are created when the original variable is
+ # restored.
+ self.assertAllEqual([1.5], self.evaluate(on_create_m_bias_slot))
+ self.assertAllEqual(optimizer_variables[2:],
+ self.evaluate(on_create_optimizer.variables()))
+ dummy_var = resource_variable_ops.ResourceVariable([1.])
+ on_create_optimizer.minimize(loss=dummy_var.read_value)
+ status.assert_consumed()
+ beta_1_power, beta_2_power = on_create_optimizer._get_beta_accumulators()
+ self.assertAllEqual(optimizer_variables[0], self.evaluate(beta_1_power))
+ self.assertAllEqual(optimizer_variables[1], self.evaluate(beta_2_power))
+
+ # TODO(allenl): Debug garbage created by this test in python3.
+ def testDeferredRestorationUsageEager(self):
+ """An idiomatic eager execution example."""
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ optimizer_step=training_util.get_or_create_global_step())
+ root.restore(checkpoint_management.latest_checkpoint(
+ checkpoint_directory))
+ for _ in range(num_training_steps):
+ # TODO(allenl): Use a Dataset and serialize/checkpoint it.
+ input_value = constant_op.constant([[3.]])
+ optimizer.minimize(
+ lambda: model(input_value), # pylint: disable=cell-var-from-loop
+ global_step=root.optimizer_step)
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ root.optimizer_step.numpy())
+
+ def testUsageGraph(self):
+ """Expected usage when graph building."""
+ with context.graph_mode():
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default():
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ input_value = constant_op.constant([[3.]])
+ train_op = optimizer.minimize(
+ model(input_value),
+ global_step=root.global_step)
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ with self.session(graph=ops.get_default_graph()) as session:
+ status = root.restore(save_path=checkpoint_path)
+ status.initialize_or_restore(session=session)
+ if checkpoint_path is None:
+ self.assertEqual(0, training_continuation)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+ else:
+ status.assert_consumed()
+ for _ in range(num_training_steps):
+ session.run(train_op)
+ root.save(file_prefix=checkpoint_prefix, session=session)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ session.run(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ session.run(root.save_counter))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testAgnosticUsage(self):
+ """Graph/eager agnostic usage."""
+ # Does create garbage when executing eagerly due to ops.Graph() creation.
+ num_training_steps = 10
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+
+ # pylint: disable=cell-var-from-loop
+ @test_util.run_in_graph_and_eager_modes
+ def testWithDefun(self):
+ num_training_steps = 2
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ for training_continuation in range(3):
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ # Don't actually train so we can test variable values
+ optimizer = adam.Adam(0.)
+ root = util.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ checkpoint_path = checkpoint_management.latest_checkpoint(
+ checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ def train_fn():
+ @function.defun
+ def _call_model(x):
+ return model(x)
+ with backprop.GradientTape() as tape:
+ loss = _call_model(constant_op.constant([[3.]]))
+ gradients = tape.gradient(loss, model.variables)
+ return optimizer.apply_gradients(zip(gradients, model.variables),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(
+ self.evaluate, train_fn())
+ status.initialize_or_restore()
+ for _ in range(num_training_steps):
+ train_fn()
+ if training_continuation > 0:
+ status.assert_consumed()
+ self.assertAllClose([[42.]], self.evaluate(model.variables[0]))
+ else:
+ self.evaluate(model.variables[0].assign([[42.]]))
+ root.save(file_prefix=checkpoint_prefix)
+ self.assertEqual((training_continuation + 1) * num_training_steps,
+ self.evaluate(root.global_step))
+ self.assertEqual(training_continuation + 1,
+ self.evaluate(root.save_counter))
+ # pylint: enable=cell-var-from-loop
+
+ def testAnonymousVarsInInit(self):
+
+ class Model(training.Model):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.w = resource_variable_ops.ResourceVariable(0.0)
+ self.b = resource_variable_ops.ResourceVariable(0.0)
+ self.vars = [self.w, self.b]
+
+ def call(self, x):
+ return x * self.w + self.b
+
+ with context.eager_mode():
+ model = Model()
+ optimizer = adam.Adam(learning_rate=0.05)
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ checkpoint = util.Checkpoint(
+ model=model, optimizer=optimizer)
+ for _ in range(2):
+ checkpoint.save(checkpoint_prefix)
+ with backprop.GradientTape() as tape:
+ loss = (constant_op.constant(1.)
+ - model(constant_op.constant(1.))) ** 2
+ grad = tape.gradient(loss, model.vars)
+ optimizer.apply_gradients(
+ [(g, v) for g, v in zip(grad, model.vars)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testDeferredSlotRestoration(self):
+ checkpoint_directory = self.get_temp_dir()
+
+ root = tracking.Checkpointable()
+ root.var = util.add_variable(
+ root, name="var", initializer=0.)
+ optimizer = adam.Adam(0.1)
+ if context.executing_eagerly():
+ optimizer.minimize(root.var.read_value)
+ else:
+ train_op = optimizer.minimize(root.var)
+ # Note that `optimizer` has not been added as a dependency of
+ # `root`. Create a one-off grouping so that slot variables for `root.var`
+ # get initialized too.
+ self.evaluate(util.gather_initializers(
+ util.Checkpoint(root=root, optimizer=optimizer)))
+ self.evaluate(train_op)
+ self.evaluate(state_ops.assign(root.var, 12.))
+ no_slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "no_slots"))
+ root.optimizer = optimizer
+ self.evaluate(state_ops.assign(root.var, 13.))
+ self.evaluate(state_ops.assign(optimizer.get_slot(name="m", var=root.var),
+ 14.))
+ slots_path = util.CheckpointableSaver(root).save(
+ os.path.join(checkpoint_directory, "with_slots"))
+ new_root = tracking.Checkpointable()
+ # Load the slot-containing checkpoint (deferred), then immediately overwrite
+ # the non-slot variable (also deferred).
+ slot_status = util.CheckpointableSaver(
+ new_root).restore(slots_path)
+ no_slot_status = util.CheckpointableSaver(
+ new_root).restore(no_slots_path)
+ with self.assertRaises(AssertionError):
+ no_slot_status.assert_consumed()
+ new_root.var = util.add_variable(
+ new_root, name="var", shape=[])
+ no_slot_status.assert_consumed()
+ no_slot_status.run_restore_ops()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ new_root.optimizer = adam.Adam(0.1)
+ with self.assertRaisesRegexp(AssertionError, "beta_1_power"):
+ slot_status.assert_consumed()
+ self.assertEqual(12., self.evaluate(new_root.var))
+ if context.executing_eagerly():
+ # Slot variables are only created with restoring initializers when
+ # executing eagerly.
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ else:
+ self.assertIs(new_root.optimizer.get_slot(name="m", var=new_root.var),
+ None)
+ if context.executing_eagerly():
+ new_root.optimizer.minimize(new_root.var.read_value)
+ else:
+ train_op = new_root.optimizer.minimize(new_root.var)
+ # The slot variable now exists; restore() didn't create it, but we should
+ # now have a restore op for it.
+ slot_status.run_restore_ops()
+ self.assertEqual(14., self.evaluate(
+ new_root.optimizer.get_slot(name="m", var=new_root.var)))
+ self.evaluate(train_op)
+ slot_status.assert_consumed()
+
+ def testManySavesGraph(self):
+ """Saves after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ saver.save(checkpoint_prefix)
+ before_ops = graph.get_operations()
+ saver.save(checkpoint_prefix)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testManyRestoresGraph(self):
+ """Restores after the first should not modify the graph."""
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.session(graph):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ obj = tracking.Checkpointable()
+ obj.var = variable_scope.get_variable(name="v", initializer=0.)
+ obj.opt = adam.Adam(0.1)
+ obj.opt.minimize(obj.var.read_value())
+ self.evaluate(util.gather_initializers(obj))
+ saver = util.CheckpointableSaver(obj)
+ save_path = saver.save(checkpoint_prefix)
+ saver.restore(save_path)
+ before_ops = graph.get_operations()
+ saver.restore(save_path)
+ self.assertEqual(before_ops, graph.get_operations())
+
+ def testMultipleGraphsNonSlotVariables(self):
+ with context.graph_mode():
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ optimizer = adam.Adam(0.001)
+ # Construct a model in one graph
+ first_graph = ops.Graph()
+ first_session = session_lib.Session(graph=first_graph)
+ with first_graph.as_default(), first_session.as_default():
+ first_variable = resource_variable_ops.ResourceVariable([1.])
+ first_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=first_variable)
+ train_op = optimizer.minimize(first_variable.read_value)
+ self.evaluate(util.gather_initializers(
+ first_root_checkpointable))
+ self.evaluate(train_op)
+ self.evaluate(first_variable.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+
+ # Save and load in a second graph
+ second_graph = ops.Graph()
+ with second_graph.as_default(), session_lib.Session(graph=second_graph):
+ second_variable = resource_variable_ops.ResourceVariable([1.])
+ second_root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, variable=second_variable)
+ train_op = optimizer.minimize(second_variable.read_value)
+ second_root_checkpointable.restore(None).initialize_or_restore()
+ self.evaluate(train_op)
+ self.evaluate(second_variable.assign([4.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([5.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(6.))
+ save_path = second_root_checkpointable.save(checkpoint_prefix)
+ self.evaluate(second_variable.assign([7.]))
+ self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m").assign([8.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+ status = second_root_checkpointable.restore(save_path)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([4.], self.evaluate(second_variable))
+ self.assertAllEqual([5.], self.evaluate(optimizer.get_slot(
+ var=second_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(6., self.evaluate(beta_1_power))
+
+ # Check that the first graph is unmolested
+ with first_graph.as_default(), first_session.as_default():
+ self.assertAllEqual([1.], self.evaluate(first_variable))
+ self.assertAllEqual([2.], self.evaluate(optimizer.get_slot(
+ var=first_variable, name="m")))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+
+class TemplateTests(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def test_checkpointable_save_restore(self):
+
+ def _templated():
+ v = variable_scope.get_variable(
+ "v", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ v2 = variable_scope.get_variable(
+ "v2", shape=[1], initializer=init_ops.zeros_initializer(),
+ use_resource=True)
+ return v, v + 1., v2
+
+ save_template = template.make_template("s1", _templated)
+ v1_save, _, v2_save = save_template()
+ optimizer = adam.Adam(0.0)
+ save_root = util.Checkpoint(
+ my_template=save_template, optimizer=optimizer)
+ optimizer.minimize(v1_save.read_value)
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ self.evaluate(v1_save.assign([12.]))
+ self.evaluate(v2_save.assign([14.]))
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ save_path = save_root.save(checkpoint_prefix)
+
+ load_template = template.make_template("s2", _templated)
+ load_optimizer = adam.Adam(0.0)
+ load_root = util.Checkpoint(
+ my_template=load_template, optimizer=load_optimizer)
+ status = load_root.restore(save_path)
+ var, var_plus_one, var2 = load_template()
+ load_optimizer.minimize(var.read_value)
+ self.assertEqual(2, len(load_template._checkpoint_dependencies))
+ self.assertEqual("v", load_template._checkpoint_dependencies[0].name)
+ self.assertEqual("v2", load_template._checkpoint_dependencies[1].name)
+ status.assert_consumed().run_restore_ops()
+ self.assertAllEqual([12.], self.evaluate(var))
+ self.assertAllEqual([13.], self.evaluate(var_plus_one))
+ self.assertAllEqual([14.], self.evaluate(var2))
+
+
+class CheckpointCompatibilityTests(test.TestCase):
+
+ def _initialized_model(self):
+ input_value = constant_op.constant([[3.]])
+ model = MyModel()
+ optimizer = adam.Adam(0.001)
+ optimizer_step = training_util.get_or_create_global_step()
+ root_checkpointable = util.Checkpoint(
+ optimizer=optimizer, model=model, optimizer_step=optimizer_step)
+ train_op = optimizer.minimize(
+ functools.partial(model, input_value),
+ global_step=optimizer_step)
+ self.evaluate(util.gather_initializers(
+ root_checkpointable))
+ self.evaluate(train_op)
+ # A regular variable, a slot variable, and a non-slot Optimizer variable
+ # with known values to check when loading.
+ self.evaluate(model._named_dense.bias.assign([1.]))
+ self.evaluate(optimizer.get_slot(
+ var=model._named_dense.bias, name="m").assign([2.]))
+ beta_1_power, _ = optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(3.))
+ return root_checkpointable
+
+ def _set_sentinels(self, root_checkpointable):
+ self.evaluate(root_checkpointable.model._named_dense.bias.assign([101.]))
+ self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")
+ .assign([102.]))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.evaluate(beta_1_power.assign(103.))
+
+ def _check_sentinels(self, root_checkpointable):
+ self.assertAllEqual(
+ [1.], self.evaluate(root_checkpointable.model._named_dense.bias))
+ self.assertAllEqual([2.], self.evaluate(
+ root_checkpointable.optimizer.get_slot(
+ var=root_checkpointable.model._named_dense.bias, name="m")))
+ beta_1_power, _ = root_checkpointable.optimizer._get_beta_accumulators()
+ self.assertAllEqual(3., self.evaluate(beta_1_power))
+
+ def _write_name_based_checkpoint(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ name_saver = core_saver.Saver()
+ return name_saver.save(
+ sess=session, save_path=checkpoint_prefix,
+ global_step=root.optimizer_step)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testLoadFromNameBasedSaver(self):
+ """Save a name-based checkpoint, load it using the object-based API."""
+ with test_util.device(use_gpu=True):
+ save_path = self._write_name_based_checkpoint()
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ with self.assertRaises(AssertionError):
+ self._check_sentinels(root)
+ object_saver = util.CheckpointableSaver(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ if context.executing_eagerly():
+ self._check_sentinels(root)
+ if context.executing_eagerly():
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_consumed()
+ else:
+ # When graph building, we haven't read any keys, so we don't know
+ # whether the restore will be complete.
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_consumed()
+ status.run_restore_ops()
+ self._check_sentinels(root)
+ self._set_sentinels(root)
+ status = object_saver.restore(save_path)
+ status.initialize_or_restore()
+ self._check_sentinels(root)
+
+ # TODO(allenl): Test for the core name-based saver loading object-based
+ # checkpoints once object-based checkpointing is in core.
+
+ def testSaveGraphLoadEager(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph) as session:
+ root = self._initialized_model()
+ save_path = root.save(
+ session=session, file_prefix=checkpoint_prefix)
+ with context.eager_mode():
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed()
+ self._check_sentinels(root)
+
+ def testSaveEagerLoadGraph(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ with context.eager_mode():
+ root = self._initialized_model()
+ save_path = root.save(file_prefix=checkpoint_prefix)
+ with context.graph_mode():
+ save_graph = ops.Graph()
+ with save_graph.as_default(), self.test_session(
+ graph=save_graph):
+ root = self._initialized_model()
+ self._set_sentinels(root)
+ root.restore(save_path).assert_consumed().run_restore_ops()
+ self._check_sentinels(root)
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
new file mode 100644
index 0000000000..bd5557f4fd
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2.py
@@ -0,0 +1,1349 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Version 2 of class Optimizer."""
+# pylint: disable=g-bad-name
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gradients
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.training import distribution_strategy_context
+from tensorflow.python.training import optimizer as optimizer_v1
+from tensorflow.python.training import slot_creator
+from tensorflow.python.training.checkpointable import base as checkpointable
+from tensorflow.python.util import nest
+
+
+class _OptimizableVariable(object):
+ """Interface for abstracting over variables in the optimizers."""
+
+ @abc.abstractmethod
+ def target(self):
+ """Returns the optimization target for this variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+ @abc.abstractmethod
+ def update_op(self, optimizer, g, *args):
+ """Returns the update ops for updating the variable."""
+ raise NotImplementedError("Calling an abstract method.")
+
+
+class _RefVariableProcessor(_OptimizableVariable):
+ """Processor for Variable."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v._ref() # pylint: disable=protected-access
+
+ def update_op(self, optimizer, g, *args):
+ if isinstance(g, ops.Tensor):
+ update_op = optimizer._apply_dense(g, self._v, *args) # pylint: disable=protected-access
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+ else:
+ assert isinstance(g, ops.IndexedSlices), ("Gradient ", g, " is neither a "
+ "tensor nor IndexedSlices.")
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ # pylint: disable=protected-access
+ return optimizer._apply_sparse_duplicate_indices(g, self._v, *args)
+
+
+class _DenseReadResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ update_op = optimizer._resource_apply_dense(g, self._v.op.inputs[0], *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _DenseResourceVariableProcessor(_OptimizableVariable):
+ """Processor for dense ResourceVariables."""
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ # pylint: disable=protected-access
+ if isinstance(g, ops.IndexedSlices):
+ if self._v.constraint is not None:
+ raise RuntimeError(
+ "Cannot use a constraint function on a sparse variable.")
+ return optimizer._resource_apply_sparse_duplicate_indices(
+ g.values, self._v, g.indices, *args)
+ update_op = optimizer._resource_apply_dense(g, self._v, *args)
+ if self._v.constraint is not None:
+ with ops.control_dependencies([update_op]):
+ return self._v.assign(self._v.constraint(self._v))
+ else:
+ return update_op
+
+
+class _TensorProcessor(_OptimizableVariable):
+ """Processor for ordinary Tensors.
+
+ Even though a Tensor can't really be updated, sometimes it is useful to
+ compute the gradients with respect to a Tensor using the optimizer. Updating
+ the Tensor is, of course, unsupported.
+ """
+
+ def __init__(self, v):
+ self._v = v
+
+ def target(self):
+ return self._v
+
+ def update_op(self, optimizer, g, *args):
+ raise NotImplementedError("Trying to update a Tensor ", self._v)
+
+
+def _get_processor(v):
+ """The processor of v."""
+ if context.executing_eagerly():
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ else:
+ return _DenseResourceVariableProcessor(v)
+ if v.op.type == "VarHandleOp":
+ return _DenseResourceVariableProcessor(v)
+ if isinstance(v, variables.Variable):
+ return _RefVariableProcessor(v)
+ if isinstance(v, ops.Tensor):
+ return _TensorProcessor(v)
+ raise NotImplementedError("Trying to optimize unsupported type ", v)
+
+
+def _var_key_v2(var):
+ """Key for representing a primary variable, for looking up slots."""
+ # pylint: disable=protected-access
+ if hasattr(var, "_distributed_container"):
+ distributed_container = var._distributed_container()
+ assert distributed_container is not None
+ if context.executing_eagerly():
+ return distributed_container._unique_id
+ return distributed_container._shared_name
+ if context.executing_eagerly():
+ return var._unique_id
+ return var.op.name
+
+
+def _resolve(value, name):
+ if callable(value):
+ value = value()
+ return ops.convert_to_tensor(value, name=name)
+
+
+def _is_dynamic(value):
+ """Returns true if __init__ arg `value` should be re-evaluated each step."""
+ if callable(value): return True
+ # Don't need to do anything special in graph mode, since dynamic values
+ # will propagate correctly automatically.
+ # TODO(josh11b): Add per-device caching across steps using variables for
+ # truly static values once we add distributed support.
+ if context.executing_eagerly() and isinstance(
+ value, resource_variable_ops.ResourceVariable):
+ return True
+ return False
+
+
+class _OptimizerV2State(object):
+ """Holds per-graph and per-step optimizer state.
+
+ Use _init_with_static_hyper() to create the state for a graph, and then
+ _copy_with_dynamic_hyper() to convert that to state for a particular step.
+ The difference between the two is that the former only has hyper
+ parameter values that are static and the latter also has values that
+ can change every step (according to _is_dynamic()).
+ """
+
+ def __init__(self, op_name):
+ self._op_name = op_name
+
+ def _init_with_static_hyper(self, hyper):
+ """Initialize a fresh state object from hyper dict."""
+ # self._hyper contains a dict from name to a dict with the Tensor values.
+ # This dict starts with a single item with key "None" with the hyper
+ # parameter value converted to a Tensor. Other items have dtype keys
+ # with that Tensor cast to that dtype.
+ with ops.init_scope():
+ self._hyper = {name: {None: ops.convert_to_tensor(value, name=name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if not dynamic}
+ self._slots = {}
+ self._non_slot_dict = {}
+ # Extra state to help Optimizers implement Checkpointable. Holds information
+ # about variables which will be restored as soon as they're created.
+ self._deferred_dependencies = {} # Non-slot variables
+ self._deferred_slot_restorations = {} # Slot variables
+
+ def _copy_with_dynamic_hyper(self, hyper, distribution, non_slot_devices):
+ """Create a new state object for a particular step."""
+ ret = _OptimizerV2State(self._op_name)
+ # pylint: disable=protected-access
+ ret._slots = self._slots
+ ret._non_slot_dict = self._non_slot_dict
+ ret._deferred_dependencies = self._deferred_dependencies
+ ret._deferred_slot_restorations = self._deferred_slot_restorations
+ ret._hyper = {name: {None: _resolve(value, name)}
+ for name, (dynamic, value) in sorted(hyper.items())
+ if dynamic}
+ ret._hyper.update(self._hyper)
+ ret._non_slot_devices = non_slot_devices
+ ret._distribution = distribution
+ return ret
+
+ def _variables(self):
+ """Returns a list of all variables held by self."""
+ optimizer_variables = list(self._non_slot_dict.values())
+ for variable_dict in self._slots.values():
+ for slot_for_variable in variable_dict.values():
+ optimizer_variables.append(slot_for_variable)
+ # Sort variables by name so that the return is deterministic.
+ return sorted(optimizer_variables, key=lambda v: v.name)
+
+ def _slot_dict(self, slot_name):
+ """Returns a dict for caching slots created under the given name.
+
+ Args:
+ slot_name: Name for the slot.
+
+ Returns:
+ A dict that maps primary `Variable` objects to the slot created
+ for that variable, under the given slot name.
+ """
+ named_slots = self._slots.get(slot_name, None)
+ if named_slots is None:
+ named_slots = {}
+ self._slots[slot_name] = named_slots
+ return named_slots
+
+ def create_slot(self, var, val, slot_name, optional_op_name=None):
+ """Find or create a slot for a variable.
+
+ Args:
+ var: A `Variable` object.
+ val: A `Tensor`. The initial value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot(
+ var, val, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def create_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, optional_op_name=None):
+ """Find or create a slot for a variable, using an Initializer.
+
+ Args:
+ var: A `Variable` object.
+ initializer: An `Initializer`. The initial value of the slot.
+ shape: Shape of the initial value of the slot.
+ dtype: Type of the value of the slot.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_slot_with_initializer(
+ var, initializer, shape, dtype, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def zeros_slot(self, var, slot_name, optional_op_name=None):
+ """Find or create a slot initialized with 0.0.
+
+ Args:
+ var: A `Variable` object.
+ slot_name: Name for the slot.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+
+ Returns:
+ A `Variable` object.
+ """
+ named_slots = self._slot_dict(slot_name)
+ var_key = _var_key_v2(var)
+ if var_key not in named_slots:
+ new_slot_variable = slot_creator.create_zeros_slot(
+ var, optional_op_name or self._op_name)
+ self._restore_slot_variable(
+ slot_name=slot_name, variable=var,
+ slot_variable=new_slot_variable)
+ named_slots[var_key] = new_slot_variable
+ return named_slots[var_key]
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable,
+ optional_op_name=None):
+ """Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored. When executing eagerly, we create the slot variable with a
+ restoring initializer.
+
+ No new variables are created when graph building. Instead,
+ _restore_slot_variable catches these after normal creation and adds restore
+ ops to the graph. This method is nonetheless important when graph building
+ for the case when a slot variable has already been created but `variable`
+ has just been added to a dependency graph (causing us to realize that the
+ slot variable needs to be restored).
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ optional_op_name: Name to use when scoping the Variable that
+ needs to be created for the slot.
+ """
+ slot_variable = self.get_slot(var=variable, name=slot_name)
+ if (slot_variable is None and context.executing_eagerly() and
+ slot_variable_position.is_simple_variable()
+ # Defer slot variable creation if there is an active variable creator
+ # scope. Generally we'd like to eagerly create/restore slot variables
+ # when possible, but this may mean that scopes intended to catch
+ # `variable` also catch its eagerly created slot variable
+ # unintentionally (specifically make_template would add a dependency on
+ # a slot variable if not for this case). Deferring is mostly harmless
+ # (aside from double initialization), and makes variable creator scopes
+ # behave the same way they do when graph building.
+ and not ops.get_default_graph()._variable_creator_stack): # pylint: disable=protected-access
+ initializer = checkpointable.CheckpointInitialValue(
+ checkpoint_position=slot_variable_position)
+ slot_variable = self.create_slot(
+ var=variable,
+ val=initializer,
+ slot_name=slot_name,
+ optional_op_name=optional_op_name)
+ # Optimizers do not have unconditional dependencies on their slot
+ # variables (nor do any other objects). They are only saved if the
+ # variables they were created for are also saved.
+ if slot_variable is not None:
+ # If we've either made this slot variable, or if we've pulled out an
+ # existing slot variable, we should restore it.
+ slot_variable_position.restore(slot_variable)
+ else:
+ # We didn't make the slot variable. Defer restoring until it gets created
+ # normally. We keep a list rather than the one with the highest restore
+ # UID in case slot variables have their own dependencies, in which case
+ # those could differ between restores.
+ variable_key = _var_key_v2(variable)
+ self._deferred_slot_restorations.setdefault(
+ slot_name, {}).setdefault(variable_key, []).append(
+ slot_variable_position)
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ named_slots = self._slots.get(name, None)
+ if not named_slots:
+ return None
+ return named_slots.get(_var_key_v2(var), None)
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ return sorted(self._slots.keys())
+
+ def create_non_slot(self, initial_value, name, colocate_with=None):
+ """Add an extra variable, not associated with a slot."""
+ v = self._non_slot_dict.get(name, None)
+ if v is None:
+ if colocate_with is None: colocate_with = self._non_slot_devices
+ with self._distribution.colocate_vars_with(colocate_with):
+ # TODO(josh11b): Use get_variable() except for the legacy Adam use case.
+ v = variable_scope.variable(initial_value, name=name, trainable=False)
+ self._non_slot_dict[name] = v
+ deferred_dependencies_list = self._deferred_dependencies.pop(name, ())
+ for checkpoint_position in sorted(
+ deferred_dependencies_list,
+ key=lambda restore: restore.checkpoint.restore_uid,
+ reverse=True):
+ checkpoint_position.restore(v)
+ return v
+
+ def _restore_slot_variable(self, slot_name, variable, slot_variable):
+ """Restore a newly created slot variable's value."""
+ variable_key = _var_key_v2(variable)
+ deferred_restorations = self._deferred_slot_restorations.get(
+ slot_name, {}).pop(variable_key, [])
+ # Iterate over restores, highest restore UID first to minimize the number
+ # of assignments.
+ deferred_restorations.sort(key=lambda position: position.restore_uid,
+ reverse=True)
+ for checkpoint_position in deferred_restorations:
+ checkpoint_position.restore(slot_variable)
+
+ def get_non_slot(self, name):
+ """Returns the non-slot variable identified by `name`."""
+ return self._non_slot_dict.get(name, None)
+
+ def get_hyper(self, name, dtype=None):
+ """Returns the `name` hyper parameter, optionally cast to `dtype`."""
+ dtype_dict = self._hyper[name]
+ # Do we have the value cast to dtype already cached? This should always
+ # succeed when dtype is None.
+ if dtype in dtype_dict:
+ return dtype_dict[dtype]
+ # Not cached, cast to dtype and save the result in the cache.
+ result = math_ops.cast(dtype_dict[None], dtype)
+ dtype_dict[dtype] = result
+ return result
+
+
+class OptimizerV2(optimizer_v1.Optimizer):
+ """Updated base class for optimizers.
+
+ This class defines the API to add Ops to train a model. You never use this
+ class directly, but instead instantiate one of its subclasses such as
+ `GradientDescentOptimizer`, `AdagradOptimizer`, or `MomentumOptimizer`.
+
+ ### Usage
+
+ ```python
+ # Create an optimizer with the desired parameters.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+ # Add Ops to the graph to minimize a cost by updating a list of variables.
+ # "cost" is a Tensor, and the list of variables contains tf.Variable
+ # objects.
+ opt_op = opt.minimize(cost, var_list=<list of variables>)
+ ```
+
+ In the training program you will just have to run the returned Op.
+
+ ```python
+ # Execute opt_op to do one step of training:
+ opt_op.run()
+ ```
+
+ ### Processing gradients before applying them.
+
+ Calling `minimize()` takes care of both computing the gradients and
+ applying them to the variables. If you want to process the gradients
+ before applying them you can instead use the optimizer in three steps:
+
+ 1. Compute the gradients with `compute_gradients()`.
+ 2. Process the gradients as you wish.
+ 3. Apply the processed gradients with `apply_gradients()`.
+
+ Example:
+
+ ```python
+ # Create an optimizer.
+ opt = GradientDescentOptimizer(learning_rate=0.1)
+
+ # Compute the gradients for a list of variables.
+ grads_and_vars = opt.compute_gradients(loss, <list of variables>)
+
+ # grads_and_vars is a list of tuples (gradient, variable). Do whatever you
+ # need to the 'gradient' part, for example cap them, etc.
+ capped_grads_and_vars = [(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
+
+ # Ask the optimizer to apply the capped gradients.
+ opt.apply_gradients(capped_grads_and_vars)
+ ```
+
+ ### Gating Gradients
+
+ Both `minimize()` and `compute_gradients()` accept a `gate_gradients`
+ argument that controls the degree of parallelism during the application of
+ the gradients.
+
+ The possible values are: `GATE_NONE`, `GATE_OP`, and `GATE_GRAPH`.
+
+ <b>`GATE_NONE`</b>: Compute and apply gradients in parallel. This provides
+ the maximum parallelism in execution, at the cost of some non-reproducibility
+ in the results. For example the two gradients of `matmul` depend on the input
+ values: With `GATE_NONE` one of the gradients could be applied to one of the
+ inputs _before_ the other gradient is computed resulting in non-reproducible
+ results.
+
+ <b>`GATE_OP`</b>: For each Op, make sure all gradients are computed before
+ they are used. This prevents race conditions for Ops that generate gradients
+ for multiple inputs where the gradients depend on the inputs.
+
+ <b>`GATE_GRAPH`</b>: Make sure all gradients for all variables are computed
+ before any one of them is used. This provides the least parallelism but can
+ be useful if you want to process all gradients before applying any of them.
+
+ ### Slots
+
+ Some optimizer subclasses, such as `MomentumOptimizer` and `AdagradOptimizer`
+ allocate and manage additional variables associated with the variables to
+ train. These are called <i>Slots</i>. Slots have names and you can ask the
+ optimizer for the names of the slots that it uses. Once you have a slot name
+ you can ask the optimizer for the variable it created to hold the slot value.
+
+ This can be useful if you want to log debug a training algorithm, report stats
+ about the slots, etc.
+
+ ### Non-slot variables
+
+ Some optimizer subclasses, such as `AdamOptimizer` have variables that
+ are not associated with the variables to train, just the step itself.
+
+ ### Hyper parameters
+
+ These are arguments passed to the optimizer subclass constructor
+ (the `__init__` method), and then passed to `self._set_hyper()`.
+ They can be either regular Python values (like 1.0), tensors, or
+ callables. If they are callable, the callable will be called during
+ `apply_gradients()` to get the value for the hyper parameter.
+
+ ### State
+
+ Internal methods are passed a `state` argument with the correct
+ values to use for the slot and non-slot variables, and the hyper
+ parameters.
+ """
+
+ # Values for gate_gradients.
+ GATE_NONE = 0
+ GATE_OP = 1
+ GATE_GRAPH = 2
+
+ def __init__(self, name):
+ """Create a new Optimizer.
+
+ This must be called by the constructors of subclasses.
+ Note that Optimizer instances should not bind to a single graph,
+ and so shouldn't keep Tensors as member variables. Generally
+ you should be able to use the _set_hyper()/state.get_hyper()
+ facility instead.
+
+ Args:
+ name: A non-empty string. The name to use for accumulators created
+ for the optimizer.
+
+ Raises:
+ ValueError: If name is malformed.
+ RuntimeError: If _create_slots has been overridden instead of
+ _create_vars.
+ """
+ # Note: We intentionally don't call parent __init__.
+
+ # Optimizer._create_slots was replaced by _create_vars in OptimizerV2.
+ if (self.__class__._create_slots.__code__ is not # pylint: disable=protected-access
+ OptimizerV2._create_slots.__code__):
+ raise RuntimeError("Override _create_vars instead of _create_slots when "
+ "descending from OptimizerV2 (class %s)" %
+ self.__class__.__name__)
+ if not name:
+ raise ValueError("Must specify the optimizer name")
+
+ self._use_locking = False
+ self._name = name
+ # Map from graph_key to state for that graph. We use the graph_key
+ # since it works in both eager and graph mode, and gives the outer
+ # graph inside functions.
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context is None:
+ # In a cross-tower context for a DistributionStrategy, which means
+ # only one Optimizer will be created, not one per tower.
+ self._per_graph_state = {}
+ else:
+ # We use get_tower_context().merge_call() to get a single dict
+ # shared across all model replicas when running with a
+ # DistributionStrategy.
+ self._per_graph_state = tower_context.merge_call(lambda _: {})
+
+ # Hyper parameters, and whether they should be re-evaluated every step.
+ self._hyper = {}
+
+ def _set_hyper(self, name, value):
+ self._hyper[name] = (_is_dynamic(value), value)
+
+ def minimize(self, loss, global_step=None, var_list=None,
+ gate_gradients=GATE_OP, aggregation_method=None,
+ colocate_gradients_with_ops=False, name=None,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Add operations to minimize `loss` by updating `var_list`.
+
+ This method simply combines calls `compute_gradients()` and
+ `apply_gradients()`. If you want to process the gradient before applying
+ them call `compute_gradients()` and `apply_gradients()` explicitly instead
+ of using this function.
+
+ Args:
+ loss: A `Tensor` containing the value to minimize.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ var_list: Optional list or tuple of `Variable` objects to update to
+ minimize `loss`. Defaults to the list of variables collected in
+ the graph under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ name: Optional name for the returned operation.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ An Operation that updates the variables in `var_list`. If `global_step`
+ was not `None`, that operation also increments `global_step`.
+
+ Raises:
+ ValueError: If some of the variables are not `Variable` objects.
+
+ @compatibility(eager)
+ When eager execution is enabled, `loss` should be a Python function that
+ takes elements of `var_list` as arguments and computes the value to be
+ minimized. If `var_list` is None, `loss` should take no arguments.
+ Minimization (and gradient computation) is done with respect to the
+ elements of `var_list` if not None, else with respect to any trainable
+ variables created during the execution of the `loss` function.
+ `gate_gradients`, `aggregation_method`, `colocate_gradients_with_ops` and
+ `grad_loss` are ignored when eager execution is enabled.
+ @end_compatibility
+ """
+ grads_and_vars = self.compute_gradients(
+ loss, var_list=var_list, gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss, stop_gradients=stop_gradients,
+ scale_loss_by_num_towers=scale_loss_by_num_towers)
+
+ vars_with_grad = [v for g, v in grads_and_vars if g is not None]
+ if not vars_with_grad:
+ raise ValueError(
+ "No gradients provided for any variable, check your graph for ops"
+ " that do not support gradients, between variables %s and loss %s." %
+ ([str(v) for _, v in grads_and_vars], loss))
+
+ return self.apply_gradients(grads_and_vars, global_step=global_step,
+ name=name)
+
+ def compute_gradients(self, loss, var_list=None,
+ gate_gradients=GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ grad_loss=None, stop_gradients=None,
+ scale_loss_by_num_towers=None):
+ """Compute gradients of `loss` for the variables in `var_list`.
+
+ This is the first part of `minimize()`. It returns a list
+ of (gradient, variable) pairs where "gradient" is the gradient
+ for "variable". Note that "gradient" can be a `Tensor`, an
+ `IndexedSlices`, or `None` if there is no gradient for the
+ given variable.
+
+ Args:
+ loss: A Tensor containing the value to minimize or a callable taking
+ no arguments which returns the value to minimize. When eager execution
+ is enabled it must be a callable.
+ var_list: Optional list or tuple of `tf.Variable` to update to minimize
+ `loss`. Defaults to the list of variables collected in the graph
+ under the key `GraphKeys.TRAINABLE_VARIABLES`.
+ gate_gradients: How to gate the computation of gradients. Can be
+ `GATE_NONE`, `GATE_OP`, or `GATE_GRAPH`.
+ aggregation_method: Specifies the method used to combine gradient terms.
+ Valid values are defined in the class `AggregationMethod`.
+ colocate_gradients_with_ops: If True, try colocating gradients with
+ the corresponding op.
+ grad_loss: Optional. A `Tensor` holding the gradient computed for `loss`.
+ stop_gradients: Optional. A Tensor or list of tensors not to differentiate
+ through.
+ scale_loss_by_num_towers: Optional boolean. If true, scale the loss
+ down by the number of towers. By default, auto-detects whether this
+ is needed.
+
+ Returns:
+ A list of (gradient, variable) pairs. Variable is always present, but
+ gradient can be `None`.
+
+ Raises:
+ TypeError: If `var_list` contains anything else than `Variable` objects.
+ ValueError: If some arguments are invalid.
+ RuntimeError: If called with eager execution enabled and `loss` is
+ not callable.
+
+ @compatibility(eager)
+ When eager execution is enabled, `gate_gradients`, `aggregation_method`,
+ and `colocate_gradients_with_ops` are ignored.
+ @end_compatibility
+ """
+ # TODO(josh11b): Test that we handle weight decay in a reasonable way.
+ if callable(loss):
+ with backprop.GradientTape() as tape:
+ if var_list is not None:
+ tape.watch(var_list)
+ loss_value = loss()
+
+ # Scale loss for number of towers (callable-loss case). In this case,
+ # we have to be careful to call distribute_lib.get_loss_reduction()
+ # *after* loss() is evaluated, so we know what loss reduction it uses.
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss_value *= 1. / num_towers
+
+ if var_list is None:
+ var_list = tape.watched_variables()
+ grads = tape.gradient(loss_value, var_list, grad_loss)
+ return list(zip(grads, var_list))
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "`loss` passed to Optimizer.compute_gradients should "
+ "be a function when eager execution is enabled.")
+
+ # Scale loss for number of towers (non-callable-loss case).
+ if scale_loss_by_num_towers is None:
+ scale_loss_by_num_towers = (
+ distribute_lib.get_loss_reduction() ==
+ variable_scope.VariableAggregation.MEAN)
+ if scale_loss_by_num_towers:
+ num_towers = distribution_strategy_context.get_distribution_strategy(
+ ).num_towers
+ if num_towers > 1:
+ loss *= 1. / num_towers
+
+ if gate_gradients not in [optimizer_v1.Optimizer.GATE_NONE,
+ optimizer_v1.Optimizer.GATE_OP,
+ optimizer_v1.Optimizer.GATE_GRAPH]:
+ raise ValueError("gate_gradients must be one of: Optimizer.GATE_NONE, "
+ "Optimizer.GATE_OP, Optimizer.GATE_GRAPH. Not %s" %
+ gate_gradients)
+ self._assert_valid_dtypes([loss])
+ if grad_loss is not None:
+ self._assert_valid_dtypes([grad_loss])
+ if var_list is None:
+ var_list = (
+ variables.trainable_variables() +
+ ops.get_collection(ops.GraphKeys.TRAINABLE_RESOURCE_VARIABLES))
+ else:
+ var_list = nest.flatten(var_list)
+ # pylint: disable=protected-access
+ var_list += ops.get_collection(ops.GraphKeys._STREAMING_MODEL_PORTS)
+ # pylint: enable=protected-access
+ processors = [_get_processor(v) for v in var_list]
+ if not var_list:
+ raise ValueError("No variables to optimize.")
+ var_refs = [p.target() for p in processors]
+ grads = gradients.gradients(
+ loss, var_refs, grad_ys=grad_loss,
+ gate_gradients=(gate_gradients == optimizer_v1.Optimizer.GATE_OP),
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ stop_gradients=stop_gradients)
+ if gate_gradients == optimizer_v1.Optimizer.GATE_GRAPH:
+ grads = control_flow_ops.tuple(grads)
+ grads_and_vars = list(zip(grads, var_list))
+ self._assert_valid_dtypes(
+ [v for g, v in grads_and_vars
+ if g is not None and v.dtype != dtypes.resource])
+ return grads_and_vars
+
+ def apply_gradients(self, grads_and_vars, global_step=None, name=None):
+ """Apply gradients to variables.
+
+ This is the second part of `minimize()`. It returns an `Operation` that
+ applies gradients.
+
+ Args:
+ grads_and_vars: List of (gradient, variable) pairs as returned by
+ `compute_gradients()`.
+ global_step: Optional `Variable` to increment by one after the
+ variables have been updated.
+ name: Optional name for the returned operation. Default to the
+ name passed to the `Optimizer` constructor.
+
+ Returns:
+ An `Operation` that applies the specified gradients. If `global_step`
+ was not None, that operation also increments `global_step`.
+
+ Raises:
+ TypeError: If `grads_and_vars` is malformed.
+ ValueError: If none of the variables have gradients.
+ """
+ # This is a default implementation of apply_gradients() that can be shared
+ # by most optimizers. It relies on the subclass implementing the following
+ # methods: _create_vars(), _prepare(), _apply_dense(), and _apply_sparse().
+
+ # Filter out variables with gradients of `None`.
+ grads_and_vars = tuple(grads_and_vars) # Make sure repeat iteration works.
+ if not grads_and_vars:
+ raise ValueError("No variables provided.")
+ filtered = tuple((g, v) for (g, v) in grads_and_vars if g is not None)
+ if not filtered:
+ raise ValueError("No gradients provided for any variable: %s." %
+ ([str(v) for _, v in grads_and_vars],))
+ return distribution_strategy_context.get_tower_context().merge_call(
+ self._distributed_apply, filtered, global_step=global_step, name=name)
+
+ def _get_or_create_state(self, var_list=None):
+ """Either looks up or creates `_OptimizerV2State`.
+
+ If any variables are available, they should be passed via the `var_list`
+ argument, and these will be used to determine the graph to create/retrieve
+ state for. Otherwise the returned state is for the current default graph.
+
+ Args:
+ var_list: A list of variables to extract a graph from.
+
+ Returns:
+ An `_OptimizerV2State` object.
+ """
+ # Determine the graph_key from the current graph.
+ eager_execution = context.executing_eagerly()
+ if eager_execution or var_list is None:
+ graph = ops.get_default_graph()
+ else:
+ graph = ops._get_graph_from_inputs(var_list) # pylint: disable=protected-access
+ assert graph is not None
+ graph_key = graph._graph_key # pylint: disable=protected-access
+
+ # Get the per graph state by looking up the graph_key.
+ if graph_key in self._per_graph_state:
+ per_graph_state = self._per_graph_state[graph_key]
+ else:
+ per_graph_state = _OptimizerV2State(self._name)
+ per_graph_state._init_with_static_hyper(self._hyper) # pylint: disable=protected-access
+ self._per_graph_state[graph_key] = per_graph_state
+ return per_graph_state
+
+ def _distributed_apply(self, distribution, grads_and_vars, global_step, name):
+ """`apply_gradients` for use with a `DistributionStrategy`."""
+ reduced_grads = distribution.batch_reduce(
+ variable_scope.VariableAggregation.SUM, grads_and_vars)
+ var_list = [v for _, v in grads_and_vars]
+ grads_and_vars = zip(reduced_grads, var_list)
+
+ unwrapped_var_list = [x for v in var_list for x in distribution.unwrap(v)]
+ eager_execution = context.executing_eagerly()
+ if eager_execution:
+ # Give a clear error in this case instead of "name not supported
+ # for Eager Tensors" when we compute non_slot_devices.
+ for v in unwrapped_var_list:
+ if isinstance(v, ops.Tensor):
+ raise NotImplementedError("Trying to update a Tensor ", v)
+
+ with ops.name_scope(name, self._name) as name:
+ per_graph_state = self._get_or_create_state(var_list=unwrapped_var_list)
+ # Include the current value of any dynamic hyper parameters in `state`.
+ non_slot_devices = distribution.non_slot_devices(var_list)
+ state = per_graph_state._copy_with_dynamic_hyper( # pylint: disable=protected-access
+ self._hyper, distribution, non_slot_devices)
+
+ # Create any slot and non-slot variables we need in `state`.
+ with ops.init_scope():
+ self._create_vars(var_list, state)
+
+ with ops.name_scope(name): # Re-enter name_scope created above
+ # Give the child class a chance to do something before we start
+ # applying gradients.
+ self._prepare(state)
+
+ def update(v, g):
+ """Update variable `v` using gradient `g`."""
+ assert v is not None
+
+ # Convert the grad to Tensor or IndexedSlices if necessary, and
+ # look up a processor for each variable's type.
+ try:
+ g = ops.convert_to_tensor_or_indexed_slices(g)
+ except TypeError:
+ raise TypeError(
+ "Gradient must be convertible to a Tensor"
+ " or IndexedSlices, or None: %s" % g)
+ if not isinstance(g, (ops.Tensor, ops.IndexedSlices)):
+ raise TypeError(
+ "Gradient must be a Tensor, IndexedSlices, or None: %s" % g)
+ processor = _get_processor(v)
+
+ # We colocate all ops created in _apply_dense or _apply_sparse
+ # on the same device as the variable.
+ # TODO(apassos): figure out how to get the variable name here.
+ scope_name = "" if eager_execution else v.op.name
+ # device_policy is set because non-mirrored tensors will be read in
+ # `update_op`.
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with ops.name_scope("update_" + scope_name), \
+ context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return processor.update_op(self, g, state)
+
+ # Use the processors to update the variables.
+ update_ops = []
+ for grad, var in grads_and_vars:
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
+
+ # Give the child class a chance to do something after applying
+ # gradients
+ def finish():
+ # TODO(josh11b): Make different state objects for each device to
+ # avoid needing to set the device_policy.
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ return self._finish(state)
+
+ update_ops = control_flow_ops.group(update_ops)
+ with ops.control_dependencies([update_ops]):
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
+
+ # Update `global_step` (if any).
+ if global_step is None:
+ apply_updates = distribution.group(finish_updates, name=name)
+ else:
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(global_step, update_global_step,
+ name)
+
+ # Add the training op to the TRAIN_OP graph collection in graph mode.
+ if not eager_execution:
+ if isinstance(apply_updates, ops.Tensor):
+ apply_updates = apply_updates.op
+ train_op = ops.get_collection_ref(ops.GraphKeys.TRAIN_OP)
+ if apply_updates not in train_op:
+ train_op.append(apply_updates)
+
+ return apply_updates
+
+ def get_slot(self, var, name):
+ """Return a slot named `name` created for `var` by the Optimizer.
+
+ Some `Optimizer` subclasses use additional variables. For example
+ `Momentum` and `Adagrad` use variables to accumulate updates. This method
+ gives access to these `Variable` objects if for some reason you need them.
+
+ Use `get_slot_names()` to get the list of slot names created by the
+ `Optimizer`.
+
+ Args:
+ var: A variable passed to `minimize()` or `apply_gradients()`.
+ name: A string.
+
+ Returns:
+ The `Variable` for the slot if it was created, `None` otherwise.
+ """
+ state = self._get_state_for_var(var)
+ return state.get_slot(var, name) if state is not None else None
+
+ def get_slot_names(self):
+ """Return a list of the names of slots created by the `Optimizer`.
+
+ See `get_slot()`.
+
+ Returns:
+ A list of strings.
+ """
+ state = self._get_per_graph_state()
+ return state.get_slot_names() if state is not None else []
+
+ def variables(self):
+ """A list of variables which encode the current state of `Optimizer`.
+
+ Includes slot variables and additional global variables created by the
+ optimizer in the current default graph.
+
+ Returns:
+ A list of variables.
+ """
+ state = self._get_per_graph_state()
+ return state._variables() if state is not None else [] # pylint: disable=protected-access
+
+ # --------------
+ # Methods to be implemented by subclasses if they want to use the
+ # inherited implementation of apply_gradients() or compute_gradients().
+ # --------------
+ def _create_vars(self, var_list, state):
+ """Create all slots needed by the variables and any non-slot variables.
+
+ Args:
+ var_list: A list of `Variable` objects.
+ state: An object with these methods:
+ `create_slot(var, val, slot_name, optional_op_name)`,
+ `create_slot_with_initializer(`
+ `var, initializer, shape, dtype, slot_name, optional_op_name)`,
+ `zeros_slot(var, slot_name, optional_op_name)`,
+ `create_non_slot_variable(initial_value, name, colocate_with)`,
+ `get_hyper(name)`
+ """
+ # No slots needed by default
+ pass
+
+ def _prepare(self, state):
+ """Code to execute before applying gradients.
+
+ Note that most uses of _prepare() in Optimizer have been subsumed
+ by explicit support for hyper parameters in OptimizerV2
+
+ Args:
+ state: An object with a `get_hyper(name)` method.
+
+ Returns:
+ Return value will be ignored.
+ """
+ pass
+
+ def _apply_dense(self, grad, var, state):
+ """Add ops to apply dense gradients to `var`.
+
+ Args:
+ grad: A `Tensor`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_dense(self, grad, handle, state):
+ """Add ops to apply dense gradients to the variable `handle`.
+
+ Args:
+ grad: a `Tensor` representing the gradient.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _resource_apply_sparse_duplicate_indices(
+ self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to `handle`, with repeated indices.
+
+ Optimizers which override this method must deal with repeated indices. See
+ the docstring of `_apply_sparse_duplicate_indices` for details. By default
+ the correct behavior, to sum non-unique indices and their associated
+ gradients, is enforced by first pre-processing `grad` and `indices` and
+ passing them on to `_resource_apply_sparse`. Optimizers which deal correctly
+ with duplicate indices may instead override this method to avoid the
+ overhead of summing.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices may be repeated.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ # pylint: disable=protected-access
+ summed_grad, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad, indices=indices)
+ # pylint: enable=protected-access
+ return self._resource_apply_sparse(
+ summed_grad, handle, unique_indices, state)
+
+ def _resource_apply_sparse(self, grad, handle, indices, state):
+ """Add ops to apply sparse gradients to the variable `handle`.
+
+ Similar to `_apply_sparse`, the `indices` argument to this method has been
+ de-duplicated. Optimizers which deal correctly with non-unique indices may
+ instead override `_resource_apply_sparse_duplicate_indices` to avoid this
+ overhead.
+
+ Args:
+ grad: a `Tensor` representing the gradient for the affected indices.
+ handle: a `Tensor` of dtype `resource` which points to the variable
+ to be updated.
+ indices: a `Tensor` of integral type representing the indices for
+ which the gradient is nonzero. Indices are unique.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation` which updates the value of the variable.
+ """
+ raise NotImplementedError()
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`, with repeated sparse indices.
+
+ Optimizers which override this method must deal with IndexedSlices objects
+ such as the following:
+
+ IndexedSlicesValue(values=[1, 1], indices=[0, 0], dense_shape=[1])
+
+ The correct interpretation is:
+
+ IndexedSlicesValue(values=[2], indices=[0], dense_shape=[1])
+
+ Many optimizers deal incorrectly with repeated indices when updating based
+ on sparse gradients (e.g. summing squares rather than squaring the sum, or
+ applying momentum terms multiple times). Adding first is always the correct
+ behavior, so this is enforced here by reconstructing the IndexedSlices to
+ have only unique indices, then calling _apply_sparse.
+
+ Optimizers which deal correctly with repeated indices may instead override
+ this method to avoid the overhead of summing indices.
+
+ Args:
+ grad: `IndexedSlices`.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ # pylint: disable=protected-access
+ summed_values, unique_indices = optimizer_v1._deduplicate_indexed_slices(
+ values=grad.values, indices=grad.indices)
+ # pylint: enable=protected-access
+ gradient_no_duplicate_indices = ops.IndexedSlices(
+ indices=unique_indices,
+ values=summed_values,
+ dense_shape=grad.dense_shape)
+ return self._apply_sparse(gradient_no_duplicate_indices, var, state)
+
+ def _apply_sparse(self, grad, var, state):
+ """Add ops to apply sparse gradients to `var`.
+
+ The IndexedSlices object passed to `grad` in this function is by default
+ pre-processed in `_apply_sparse_duplicate_indices` to remove duplicate
+ indices (see its docstring for details). Optimizers which can tolerate or
+ have correct special cases for duplicate sparse indices may override
+ `_apply_sparse_duplicate_indices` instead of this function, avoiding that
+ overhead.
+
+ Args:
+ grad: `IndexedSlices`, with no repeated indices.
+ var: A `Variable` object.
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ An `Operation`.
+ """
+ raise NotImplementedError()
+
+ def _finish(self, state):
+ """Do what is needed to finish the update.
+
+ This is called inside a scope colocated with any non-slot variables.
+
+ Args:
+ state: An object with `get_slot(var, name)`, `get_non_slot(self, name)`,
+ and `get_hyper(name)` methods.
+
+ Returns:
+ The operation to apply updates, or None if no updates.
+ """
+ return None
+
+ # --------------
+ # Utility methods for subclasses.
+ # --------------
+ def _get_per_graph_state(self):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(ops.get_default_graph()._graph_key, None)
+
+ def _get_state_for_var(self, var):
+ # pylint: disable=protected-access
+ return self._per_graph_state.get(var._graph_key, None)
+
+ # --------------
+ # Overridden methods from Checkpointable.
+ # --------------
+
+ def _track_checkpointable(self, *args, **kwargs):
+ """Optimizers may not track dependencies. Raises an error."""
+ raise NotImplementedError(
+ "Optimizers may not have dependencies. File a feature request if this "
+ "limitation bothers you.")
+
+ @property
+ def _checkpoint_dependencies(self):
+ """From Checkpointable. Gather graph-specific non-slot variables to save."""
+ current_graph_non_slot_variables = []
+ state = self._get_per_graph_state()
+ if state is not None:
+ for name, variable_object in sorted(
+ state._non_slot_dict.items(), # pylint: disable=protected-access
+ # Avoid comparing variables
+ key=lambda item: item[0]):
+ current_graph_non_slot_variables.append(
+ checkpointable.CheckpointableReference(
+ name=name, ref=variable_object))
+ # Note: ignores super(); Optimizers may not have any dependencies outside of
+ # state objects.
+ return current_graph_non_slot_variables
+
+ def _lookup_dependency(self, name):
+ """From Checkpointable. Find a non-slot variable in the current graph."""
+ state = self._get_per_graph_state()
+ if state is None:
+ return None
+ else:
+ return state.get_non_slot(name)
+
+ @property
+ def _deferred_dependencies(self):
+ """Lets Checkpointable know where non-slot variables are created.
+
+ If necessary, creates a new state object for the current default graph.
+ Checkpointable will then add entries to that state's deferred dependency
+ dictionary. The state object will check that dictionary when creating
+ non-slot variables, restoring their value if an entry is found.
+
+ Returns:
+ A dictionary which holds deferred dependencies for the current default
+ graph.
+ """
+ state = self._get_or_create_state()
+ return state._deferred_dependencies # pylint: disable=protected-access
+
+ def _create_or_restore_slot_variable(
+ self, slot_variable_position, slot_name, variable):
+ """Checkpointable: Restore a slot variable's value, possibly creating it.
+
+ Called when a variable which has an associated slot variable is created or
+ restored.
+
+ Args:
+ slot_variable_position: A `checkpointable._CheckpointPosition` object
+ indicating the slot variable `Checkpointable` object to be restored.
+ slot_name: The name of this `Optimizer`'s slot to restore into.
+ variable: The variable object this slot is being created for.
+ """
+ state = self._get_or_create_state(var_list=[variable])
+ state._create_or_restore_slot_variable( # pylint: disable=protected-access
+ slot_variable_position=slot_variable_position,
+ slot_name=slot_name,
+ variable=variable,
+ optional_op_name=self._name)
+
+ # --------------
+ # Unsupported parent methods
+ # --------------
+ def _slot_dict(self, slot_name):
+ raise NotImplementedError(
+ "_slot_dict() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot(self, var, val, slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot() method unsupported in OptimizerV2")
+
+ def _get_or_make_slot_with_initializer(self, var, initializer, shape, dtype,
+ slot_name, op_name):
+ raise NotImplementedError(
+ "_get_or_make_slot_with_initializer() method unsupported in "
+ "OptimizerV2")
+
+ def _create_non_slot_variable(self, initial_value, name, colocate_with):
+ raise NotImplementedError(
+ "_create_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _get_non_slot_variable(self, name, graph=None):
+ raise NotImplementedError(
+ "_get_non_slot_variable() method unsupported in OptimizerV2")
+
+ def _non_slot_variables(self):
+ raise NotImplementedError(
+ "_non_slot_variables() method unsupported in OptimizerV2")
diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
new file mode 100644
index 0000000000..a6c939393e
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py
@@ -0,0 +1,277 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional test for OptimizerV2."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import clip_ops
+from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class OptimizerTest(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testBasic(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+ # Note that for eager execution, minimize expects a function instead of a
+ # Tensor.
+ global_step = resource_variable_ops.ResourceVariable(
+ array_ops.zeros([], dtypes.int64), name='global_step_%d' % i)
+ sgd_op = sgd.SGD(3.0)
+
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+ # Run 1 step of sgd through optimizer
+ opt_op = sgd_op.minimize(loss, global_step, [var0, var1])
+ self.evaluate(opt_op)
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ def testAggregationMethod(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost,
+ global_step, [var0, var1],
+ aggregation_method=gradients_impl.AggregationMethod.
+ EXPERIMENTAL_ACCUMULATE_N)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-14., -13.], var0.eval())
+ self.assertAllClose([-6., -5.], var1.eval())
+
+ def testPrecomputedGradient(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ cost = 5 * var0 + 3 * var1
+ grad_loss = constant_op.constant([42, -42], dtype=dtype)
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(
+ cost, global_step, [var0, var1], grad_loss=grad_loss)
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([1.0 - 3 * 5 * 42.0, 2.0 - 3 * 5 * (-42.0)],
+ var0.eval())
+ self.assertAllClose([3.0 - 3 * 3 * 42.0, 4.0 - 3 * 3 * (-42.0)],
+ var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoVariables(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, trainable=False, name='a')
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, trainable=False, name='b')
+ return 5 * var0 + var1
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No.*variables'):
+ sgd_op.minimize(loss)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ return 5 * var0
+ # pylint: enable=cell-var-from-loop
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError, 'No gradients'):
+ # var1 has no gradient
+ sgd_op.minimize(loss, var_list=[var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_Minimize(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return constant_op.constant(5.0)
+
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.minimize(loss, var_list=[var0, var1])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testNoGradientsForAnyVariables_ApplyGradients(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ sgd_op = sgd.SGD(3.0)
+ with self.assertRaisesRegexp(ValueError,
+ 'No gradients provided for any variable'):
+ sgd_op.apply_gradients([(None, var0), (None, var1)])
+
+ @test_util.run_in_graph_and_eager_modes
+ def testGradientsAsVariables(self):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ def loss():
+ return 5 * var0 + 3 * var1 # pylint: disable=cell-var-from-loop
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(loss, [var0, var1])
+ # Convert gradients to tf.Variables
+ converted_grads = [
+ resource_variable_ops.ResourceVariable(array_ops.zeros([2], dtype),
+ name='c_%d_%d' % (i, j))
+ for j, gv in enumerate(grads_and_vars)
+ ]
+ convert_ops = [
+ state_ops.assign(converted_grads[j], gv[0])
+ for j, gv in enumerate(grads_and_vars)
+ ]
+
+ self.evaluate(variables.global_variables_initializer())
+ # Run convert_ops to achieve the gradietns converting
+ self.evaluate(convert_ops)
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Run 1 step of sgd through optimizer
+ converted_grads_and_vars = list(zip(converted_grads, [var0, var1]))
+ opt_op = sgd_op.apply_gradients(converted_grads_and_vars)
+ self.evaluate(opt_op)
+
+ # Validate updated params
+ self.assertAllClose([-14., -13.], self.evaluate(var0))
+ self.assertAllClose([-6., -5.], self.evaluate(var1))
+
+ @test_util.run_in_graph_and_eager_modes
+ def testComputeGradientsWithTensors(self):
+ x = ops.convert_to_tensor(1.0)
+ def f():
+ return x * x
+
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(f, [x])
+ self.assertEqual(1, len(grads_and_vars))
+ grad, x_as_var = grads_and_vars[0]
+ self.assertIs(x, x_as_var)
+ self.assertEqual(2.0, self.evaluate(grad))
+
+ with self.assertRaises(NotImplementedError):
+ sgd_op.apply_gradients(grads_and_vars)
+
+ def testTrainOp(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0])
+ var1 = variables.Variable([3.0, 4.0])
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+ self.assertTrue(opt_op in ops.get_collection(ops.GraphKeys.TRAIN_OP))
+
+ def testConstraint(self):
+ constraint_01 = lambda x: clip_ops.clip_by_value(x, -0.1, 0.)
+ constraint_0 = lambda x: clip_ops.clip_by_value(x, 0., 1.)
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0],
+ constraint=constraint_01)
+ var1 = variables.Variable([3.0, 4.0],
+ constraint=constraint_0)
+ cost = 5 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name='global_step')
+ sgd_op = sgd.SGD(3.0)
+ opt_op = sgd_op.minimize(cost, global_step, [var0, var1])
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd through optimizer
+ opt_op.run()
+ # Validate updated params
+ self.assertAllClose([-0.1, -0.1], var0.eval())
+ self.assertAllClose([0., 0.], var1.eval())
+
+ def testStopGradients(self):
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], name='var0')
+ var1 = variables.Variable([3.0, 4.0], name='var1')
+ var0_id = array_ops.identity(var0)
+ cost = 5 * var0_id + 3 * var1
+ sgd_op = sgd.SGD(3.0)
+ grads_and_vars = sgd_op.compute_gradients(cost, [var0, var1],
+ stop_gradients=[var0_id])
+ grad_dict = {var.op.name: grad for grad, var in grads_and_vars}
+ self.assertIsNone(grad_dict['var0'])
+ self.assertIsNotNone(grad_dict['var1'])
+
+ def testDoNotOverrideCreateSlots(self):
+ class ShouldNotOverrideCreateSlots(optimizer_v2.OptimizerV2):
+
+ def _create_slots(self, var_list):
+ """In OptimizerV2 _create_slots was renamed _create_vars."""
+ return var_list
+
+ with self.assertRaises(RuntimeError):
+ ShouldNotOverrideCreateSlots('name')
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop.py b/tensorflow/python/keras/optimizer_v2/rmsprop.py
new file mode 100644
index 0000000000..2748d8eff7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop.py
@@ -0,0 +1,239 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RMSprop optimizer for Tensorflow.
+
+rmsprop algorithm [tieleman2012rmsprop]
+
+A detailed description of rmsprop.
+
+- maintain a moving (discounted) average of the square of gradients
+- divide gradient by the root of this average
+
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t / sqrt(mean_square)
+delta = - mom
+
+This implementation of RMSProp uses plain momentum, not Nesterov momentum.
+
+The centered version additionally maintains a moving (discounted) average of the
+gradients, and uses that average to estimate the variance:
+
+mean_grad = rho * mean_square{t-1} + (1-rho) * gradient
+mean_square = rho * mean_square{t-1} + (1-rho) * gradient ** 2
+mom = momentum * mom{t-1} + learning_rate * g_t /
+ sqrt(mean_square - mean_grad**2)
+delta = - mom
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import array_ops
+
+from tensorflow.python.training import training_ops
+
+
+class RMSProp(optimizer_v2.OptimizerV2):
+ """RMSProp optimizer.
+
+ It is recommended to leave the parameters of this optimizer at their default
+ values (except the learning rate, which can be freely tuned).
+
+ This optimizer is usually a good choice for recurrent neural networks.
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense implementation of this algorithm, variables and their
+ corresponding accumulators (momentum, gradient moving average, square
+ gradient moving average) will be updated even if the gradient is zero
+ (i.e. accumulators will decay, momentum will be applied). The sparse
+ implementation (used when the gradient is an `IndexedSlices` object,
+ typically because of `tf.gather` or an embedding lookup in the forward pass)
+ will not update variable slices or their accumulators unless those slices
+ were used in the forward pass (nor is there an "eventual" correction to
+ account for these omitted updates). This leads to more efficient updates for
+ large embedding lookup tables (where most of the slices are not accessed in
+ a particular graph execution), but differs from the published algorithm.
+
+ Arguments:
+ learning_rate: A float hyperparameter >= 0. The learning rate.
+ rho: A float hyperparameter >= 0. Discounting factor for the
+ history/coming gradient.
+ momentum: A float hyperparameter >= 0.
+ epsilon: A float hyperparameter >= 0 . Small value to initialize the
+ average square gradient variable and avoid zero denominator.
+ centered: If True, gradients are normalized by the estimated variance of
+ the gradient; if False, by the uncentered second moment. Setting this to
+ True may help with training, but is slightly more expensive in terms of
+ computation and memory. Defaults to False.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to "RMSProp".
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ rho=0.9,
+ momentum=None,
+ epsilon=1e-10,
+ centered=False,
+ name="RMSProp"):
+ super(RMSProp, self).__init__(name)
+ # Momentum default is `None` for consistency with SGD
+ # but underlying implementation uses `momentum` hyperparameter here
+ # regardless unlike SGD. Since extneral Keras RMSProp does not have
+ # a `momentum` weight, for compatibility with external Keras h5 files,
+ # when `momentum` was set as `None` we should ignore the `momentum`
+ # variable in `get_weights` and not require it in `set_weights`.
+ if momentum is None:
+ momentum = 0.0
+ self._set_hyper("learning_rate", learning_rate)
+ self._set_hyper("rho", rho)
+ self._set_hyper("momentum", momentum)
+ self._set_hyper("epsilon", epsilon)
+
+ self._centered = centered
+
+ def _create_vars(self, var_list, state):
+ for v in var_list:
+ init_rms = state.get_hyper(
+ "epsilon", v.dtype.base_dtype) * array_ops.ones_like(v)
+ state.create_slot_with_initializer(v, init_rms, v.get_shape(),
+ v.dtype.base_dtype, "rms")
+ if self._centered:
+ state.zeros_slot(v, "mg")
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ # epsilon is now the rms initial value and is not added to the
+ # denominator anymore, hence calling the kernel op with epsilon=0.
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+ else:
+ return training_ops.apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.resource_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = state.get_slot(var, "mg")
+ return training_ops.sparse_apply_centered_rms_prop(
+ var,
+ mg,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.sparse_apply_rms_prop(
+ var,
+ rms,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ rms = state.get_slot(var, "rms")
+ mom = state.get_slot(var, "momentum")
+ if self._centered:
+ mg = self.get_slot(var, "mg")
+ return training_ops.resource_sparse_apply_centered_rms_prop(
+ var.handle,
+ mg.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
+ else:
+ return training_ops.resource_sparse_apply_rms_prop(
+ var.handle,
+ rms.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ state.get_hyper("rho", var.dtype.base_dtype),
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ 0,
+ grad,
+ indices,
+ use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/rmsprop_test.py b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
new file mode 100644
index 0000000000..2c5eccdc5b
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/rmsprop_test.py
@@ -0,0 +1,444 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for rmsprop optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import copy
+import math
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import rmsprop
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+_DATA_TYPES = [dtypes.half, dtypes.float32]
+
+_TEST_PARAM_VALUES = [
+ # learning_rate, rho, momentum, epsilon, centered, use_resource
+ [0.5, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.9, 0.0, 1.0, False, False],
+ [0.5, 0.9, 0.0, 1.0, True, True],
+ [0.5, 0.9, 0.0, 1.0, False, True],
+ [0.1, 0.9, 0.0, 1.0, True, False],
+ [0.5, 0.95, 0.0, 1.0, False, False],
+ [0.5, 0.8, 0.0, 1e-3, True, False],
+ [0.5, 0.8, 0.9, 1e-3, True, False],
+]
+
+
+class RMSPropOptimizerTest(test.TestCase, parameterized.TestCase):
+
+ def _rmsprop_update_numpy(self, var, g, mg, rms, mom, lr, rho, momentum,
+ centered):
+ rms_t = rms * rho + (1 - rho) * g * g
+ if centered:
+ mg_t = mg * rho + (1 - rho) * g
+ denom_t = rms_t - mg_t * mg_t
+ else:
+ mg_t = mg
+ denom_t = rms_t
+ mom_t = momentum * mom + lr * g / np.sqrt(denom_t, dtype=denom_t.dtype)
+ var_t = var - mom_t
+ return var_t, mg_t, rms_t, mom_t
+
+ def _sparse_rmsprop_update_numpy(self, var, gindexs, gvalues, mg, rms, mom,
+ lr, rho, momentum, centered):
+ mg_t = copy.deepcopy(mg)
+ rms_t = copy.deepcopy(rms)
+ mom_t = copy.deepcopy(mom)
+ var_t = copy.deepcopy(var)
+ for i in range(len(gindexs)):
+ gindex = gindexs[i]
+ gvalue = gvalues[i]
+ rms_t[gindex] = rms[gindex] * rho + (1 - rho) * gvalue * gvalue
+ denom_t = rms_t[gindex]
+ if centered:
+ mg_t[gindex] = mg_t[gindex] * rho + (1 - rho) * gvalue
+ denom_t -= mg_t[gindex] * mg_t[gindex]
+ mom_t[gindex] = momentum * mom[gindex] + lr * gvalue / np.sqrt(denom_t)
+ var_t[gindex] = var[gindex] - mom_t[gindex]
+ return var_t, mg_t, rms_t, mom_t
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testDense(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered,
+ use_resource) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1, 0.2], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01, 0.2], dtype=dtype.as_numpy_dtype)
+
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(var0_np)
+ var1 = resource_variable_ops.ResourceVariable(var1_np)
+ else:
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0 = constant_op.constant(grads0_np)
+ grads1 = constant_op.constant(grads1_np)
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._rmsprop_update_numpy(
+ var0_np, grads0_np, mg0_np, rms0_np, mom0_np, learning_rate, rho,
+ momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._rmsprop_update_numpy(
+ var1_np, grads1_np, mg1_np, rms1_np, mom1_np, learning_rate, rho,
+ momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(
+ var0_np, var0.eval(), half_rtol=0.01, half_atol=0.01)
+ self.assertAllCloseAccordingToType(
+ var1_np, var1.eval(), half_rtol=0.01, half_atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariable(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.0, momentum=0.0, epsilon=0.0,
+ centered=False).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0., 1.]], var0.eval(), atol=0.01)
+
+ @parameterized.parameters([dtypes.float32, dtypes.float64])
+ def testMinimizeSparseResourceVariableCentered(self, dtype):
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = rmsprop.RMSProp(
+ learning_rate=1.0, rho=0.1, momentum=0.0, epsilon=1.0,
+ centered=True).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[-7/3.0, -4/3.0]], var0.eval(), atol=0.01)
+
+ @parameterized.named_parameters(
+ *test_util.generate_combinations_with_testcase_name(
+ dtype=_DATA_TYPES, param_value=_TEST_PARAM_VALUES))
+ def testSparse(self, dtype, param_value):
+ (learning_rate, rho, momentum, epsilon, centered, _) = tuple(param_value)
+ with self.test_session(use_gpu=True):
+ # Initialize variables for numpy implementation.
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ grads0_np = np.array([0.1], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ grads1_np = np.array([0.01], dtype=dtype.as_numpy_dtype)
+
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ grads0_np_indices = np.array([0], dtype=np.int32)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(grads0_np),
+ constant_op.constant(grads0_np_indices), constant_op.constant([1]))
+ grads1_np_indices = np.array([1], dtype=np.int32)
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(grads1_np),
+ constant_op.constant(grads1_np_indices), constant_op.constant([1]))
+ opt = rmsprop.RMSProp(
+ learning_rate=learning_rate,
+ rho=rho,
+ momentum=momentum,
+ epsilon=epsilon,
+ centered=centered)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ mg0 = opt.get_slot(var0, "mg")
+ self.assertEqual(mg0 is not None, centered)
+ mg1 = opt.get_slot(var1, "mg")
+ self.assertEqual(mg1 is not None, centered)
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ mg0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mg1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ rms0_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ rms1_np = np.array([epsilon, epsilon], dtype=dtype.as_numpy_dtype)
+ mom0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ mom1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+
+ # Run 4 steps of RMSProp
+ for _ in range(4):
+ update.run()
+
+ var0_np, mg0_np, rms0_np, mom0_np = self._sparse_rmsprop_update_numpy(
+ var0_np, grads0_np_indices, grads0_np, mg0_np, rms0_np, mom0_np,
+ learning_rate, rho, momentum, centered)
+ var1_np, mg1_np, rms1_np, mom1_np = self._sparse_rmsprop_update_numpy(
+ var1_np, grads1_np_indices, grads1_np, mg1_np, rms1_np, mom1_np,
+ learning_rate, rho, momentum, centered)
+
+ # Validate updated params
+ if centered:
+ self.assertAllCloseAccordingToType(mg0_np, mg0.eval())
+ self.assertAllCloseAccordingToType(mg1_np, mg1.eval())
+ self.assertAllCloseAccordingToType(rms0_np, rms0.eval())
+ self.assertAllCloseAccordingToType(rms1_np, rms1.eval())
+ self.assertAllCloseAccordingToType(mom0_np, mom0.eval())
+ self.assertAllCloseAccordingToType(mom1_np, mom1.eval())
+ self.assertAllCloseAccordingToType(var0_np, var0.eval())
+ self.assertAllCloseAccordingToType(var1_np, var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithoutMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.0, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the rms accumulators where 1. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), var1.eval())
+
+ @parameterized.parameters(_DATA_TYPES)
+ def testWithMomentum(self, dtype):
+ with self.test_session(use_gpu=True):
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+
+ opt = rmsprop.RMSProp(
+ learning_rate=2.0, rho=0.9, momentum=0.5, epsilon=1.0)
+ update = opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ rms0 = opt.get_slot(var0, "rms")
+ self.assertIsNotNone(rms0)
+ rms1 = opt.get_slot(var1, "rms")
+ self.assertIsNotNone(rms1)
+ mom0 = opt.get_slot(var0, "momentum")
+ self.assertIsNotNone(mom0)
+ mom1 = opt.get_slot(var1, "momentum")
+ self.assertIsNotNone(mom1)
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: rms = 1, mom = 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ update.run()
+ # Check the root mean square accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901, 0.901]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001, 0.90001]), rms1.eval())
+ # Check the momentum accumulators
+ self.assertAllCloseAccordingToType(
+ np.array([(0.1 * 2.0 / math.sqrt(0.901)),
+ (0.1 * 2.0 / math.sqrt(0.901))]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.01 * 2.0 / math.sqrt(0.90001)),
+ (0.01 * 2.0 / math.sqrt(0.90001))]), mom1.eval())
+
+ # Check that the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901))
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001))
+ ]), var1.eval())
+
+ # Step 2: the root mean square accumulators contain the previous update.
+ update.run()
+ # Check the rms accumulators.
+ self.assertAllCloseAccordingToType(
+ np.array([0.901 * 0.9 + 0.001, 0.901 * 0.9 + 0.001]), rms0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([0.90001 * 0.9 + 1e-5, 0.90001 * 0.9 + 1e-5]), rms1.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)),
+ 0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))
+ ]), mom0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)),
+ 0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))
+ ]), mom1.eval())
+
+ # Check the parameters.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001))),
+ 2.0 - (0.1 * 2.0 / math.sqrt(0.901)) -
+ (0.5 * (0.1 * 2.0 / math.sqrt(0.901)) +
+ (0.1 * 2.0 / math.sqrt(0.901 * 0.9 + 0.001)))
+ ]), var0.eval())
+
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 3.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5))),
+ 4.0 - (0.01 * 2.0 / math.sqrt(0.90001)) -
+ (0.5 * (0.01 * 2.0 / math.sqrt(0.90001)) +
+ (0.01 * 2.0 / math.sqrt(0.90001 * 0.9 + 1e-5)))
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/optimizer_v2/sgd.py b/tensorflow/python/keras/optimizer_v2/sgd.py
new file mode 100644
index 0000000000..f5583691f7
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd.py
@@ -0,0 +1,170 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Momentum for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import ops
+from tensorflow.python.keras.optimizer_v2 import optimizer_v2
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.training import training_ops
+
+
+class SGD(optimizer_v2.OptimizerV2):
+ """Stochastic gradient descent optimizer.
+
+ Includes support for momentum and Nesterov momentum.
+
+ Computes (if `nesterov = False`):
+
+ ```
+ accumulation = momentum * accumulation + gradient
+ variable -= learning_rate * accumulation
+ ```
+
+ Some of the args below are hyperparameters, where a hyperparameter is
+ defined as a scalar Tensor, a regular Python value, or a callable (which
+ will be evaluated when `apply_gradients` is called) returning a scalar
+ Tensor or a Python value.
+
+ Note that in the dense version of this algorithm, `accumulation` is updated
+ and applied regardless of a gradient's value, whereas the sparse version (when
+ the gradient is an `IndexedSlices`, typically because of `tf.gather` or an
+ embedding) only updates variable slices and corresponding `accumulation` terms
+ when that part of the variable was used in the forward pass.
+
+ @compatibility(eager)
+ When eager execution is enabled, learning_rate and momentum can each be a
+ callable that takes no arguments and returns the actual value to use. This
+ can be useful for changing these values across different invocations of
+ optimizer functions.
+ @end_compatibility
+
+ Arguments:
+ learning_rate: float hyperparameter >= 0. Learning rate.
+ momentum: float hyperparameter >= 0 or None. Parameter that accelerates
+ SGD in the relevant direction and dampens oscillations.
+ nesterov: boolean. Whether to apply Nesterov momentum. See [Sutskever et
+ al., 2013](http://jmlr.org/proceedings/papers/v28/sutskever13.pdf). This
+ implementation always computes gradients at the value of the
+ variable(s) passed to the optimizer. Using Nesterov Momentum makes the
+ variable(s) track the values called `theta_t + mu*v_t` in the paper.
+ name: Optional name prefix for the operations created when applying
+ gradients. Defaults to 'SGD'.
+ """
+
+ def __init__(self,
+ learning_rate=0.001,
+ momentum=None,
+ nesterov=False,
+ name="SGD"):
+ super(SGD, self).__init__(name)
+ self._set_hyper("learning_rate", learning_rate)
+ # Only create momentum variables and use momentum ops if needed.
+ if momentum is not None:
+ self._set_hyper("momentum", momentum)
+ self._use_nesterov = nesterov
+ self._use_momentum = True
+ else:
+ self._use_momentum = False
+
+ def _create_vars(self, var_list, state):
+ if self._use_momentum:
+ for v in var_list:
+ state.zeros_slot(v, "momentum")
+
+ def _apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return training_ops.apply_gradient_descent(
+ var,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking).op
+
+ def _resource_apply_dense(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return training_ops.resource_apply_gradient_descent(
+ var.handle, lr, grad, use_locking=self._use_locking)
+
+ def _apply_sparse(self, grad, var, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.sparse_apply_momentum(
+ var,
+ mom,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov).op
+ else:
+ return super(SGD, self)._apply_sparse(grad, var, state)
+
+ def _resource_apply_sparse(self, grad, var, indices, state):
+ if self._use_momentum:
+ mom = state.get_slot(var, "momentum")
+ return training_ops.resource_sparse_apply_momentum(
+ var.handle,
+ mom.handle,
+ state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad,
+ indices,
+ state.get_hyper("momentum", var.dtype.base_dtype),
+ use_locking=self._use_locking,
+ use_nesterov=self._use_nesterov)
+ else:
+ return super(SGD, self)._resource_apply_sparse(grad, var, indices, state)
+
+ def _resource_apply_sparse_duplicate_indices(self, grad, var, indices, state):
+ if self._use_momentum:
+ return super(SGD, self)._resource_apply_sparse_duplicate_indices(
+ grad, var, indices, state)
+ else:
+ lr = state.get_hyper("learning_rate", grad.dtype.base_dtype)
+ return resource_variable_ops.resource_scatter_add(var.handle, indices,
+ -grad * lr)
+
+ def _apply_sparse_duplicate_indices(self, grad, var, state):
+ if self._use_momentum:
+ return super(SGD, self)._apply_sparse_duplicate_indices(grad, var, state)
+ else:
+ delta = ops.IndexedSlices(
+ grad.values * state.get_hyper("learning_rate", var.dtype.base_dtype),
+ grad.indices, grad.dense_shape)
+ return var.scatter_sub(delta, use_locking=self._use_locking)
diff --git a/tensorflow/python/keras/optimizer_v2/sgd_test.py b/tensorflow/python/keras/optimizer_v2/sgd_test.py
new file mode 100644
index 0000000000..eb39aac283
--- /dev/null
+++ b/tensorflow/python/keras/optimizer_v2/sgd_test.py
@@ -0,0 +1,759 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Momentum."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
+from tensorflow.python.keras.optimizer_v2 import sgd
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import resources
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class GradientDescentOptimizerTest(test.TestCase):
+
+ def testBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ optimizer = sgd.SGD(3.0)
+ sgd_op = optimizer.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertEqual(0, len(optimizer.variables()))
+
+ def testBasicResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testMinimizeResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(var0, x) + var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ resources.initialize_resources([var0, var1]).run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ pred += var1
+ loss = pred * pred
+ sgd_op = sgd.SGD(1.0).minimize(loss)
+ # TODO(apassos) calling initialize_resources on all resources here
+ # doesn't work because the sessions and graph are reused across unit
+ # tests and this would mean trying to reinitialize variables. Figure out
+ # a long-term solution for this.
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0
+ np_grad = 2 * np_pred
+ self.assertAllCloseAccordingToType(
+ [[1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0]], var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - np_grad], var1.eval())
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ lrate = constant_op.constant(3.0)
+ sgd_op = sgd.SGD(lrate).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+
+ def testGradWrtRef(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ opt = sgd.SGD(3.0)
+ values = [1.0, 3.0]
+ vars_ = [variables.Variable([v], dtype=dtype) for v in values]
+ grads_and_vars = opt.compute_gradients(vars_[0] + vars_[1], vars_)
+ variables.global_variables_initializer().run()
+ for grad, _ in grads_and_vars:
+ self.assertAllCloseAccordingToType([1.0], grad.eval())
+
+ def testWithGlobalStep(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ global_step = variables.Variable(0, trainable=False)
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]), global_step=global_step)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([1.0, 2.0], var0.eval())
+ self.assertAllCloseAccordingToType([3.0, 4.0], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params and global_step
+ self.assertAllCloseAccordingToType([1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1],
+ var0.eval())
+ self.assertAllCloseAccordingToType([3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01],
+ var1.eval())
+ self.assertAllCloseAccordingToType(1, global_step.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ sgd_op = sgd.SGD(3.0).apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0], [2.0]], var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0]], var1.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[1.0 - 3.0 * 0.1], [2.0]],
+ var0.eval())
+ self.assertAllCloseAccordingToType([[3.0], [4.0 - 3.0 * 0.01]],
+ var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
+
+
+class MomentumOptimizerTest(test.TestCase):
+
+ def _update_nesterov_momentum_numpy(self, var, accum, g, lr, momentum):
+ var = var + accum * lr * momentum
+ accum = accum * momentum + g
+ var = var - lr * accum
+ var = var - accum * lr * momentum
+ return var, accum
+
+ def doTestBasic(self, use_resource=False, use_callable_params=False):
+ for i, dtype in enumerate([dtypes.half, dtypes.float32, dtypes.float64]):
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype, name="var0_%d" % i)
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtype, name="var1_%d" % i)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ learning_rate = lambda: 2.0
+ momentum = lambda: 0.9
+ if not use_callable_params:
+ learning_rate = learning_rate()
+ momentum = momentum()
+ mom_opt = sgd.SGD(learning_rate=learning_rate, momentum=momentum)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+
+ if not context.executing_eagerly():
+ self.evaluate(variables.global_variables_initializer())
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], self.evaluate(var0))
+ self.assertAllClose([3.0, 4.0], self.evaluate(var1))
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ if not context.executing_eagerly():
+ self.assertFalse(slot0 in variables.trainable_variables())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ if not context.executing_eagerly():
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]),
+ self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]),
+ self.evaluate(var1))
+ # Step 2: the momentum accumulators contain the previous update.
+ if context.executing_eagerly():
+ mom_opt.apply_gradients(zip([grads0, grads1], [var0, var1]))
+ else:
+ self.evaluate(mom_update)
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]),
+ self.evaluate(slot0))
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ self.evaluate(slot1))
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), self.evaluate(var0))
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), self.evaluate(var1))
+
+ def testBasic(self):
+ with self.cached_session():
+ self.doTestBasic(use_resource=False)
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testResourceBasic(self):
+ self.doTestBasic(use_resource=True)
+
+ def testBasicCallableParams(self):
+ with context.eager_mode():
+ self.doTestBasic(use_resource=True, use_callable_params=True)
+
+ def testVariablesAcrossGraphs(self):
+ optimizer = sgd.SGD(0.01, 0.5)
+ with ops.Graph().as_default():
+ var0 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var0")
+ var1 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var1")
+ loss = math_ops.reduce_sum(var0 + var1)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var0")
+ self.assertStartsWith(optimizer_variables[1].name, "var1")
+ self.assertEquals(2, len(optimizer_variables))
+
+ with ops.Graph().as_default():
+ var2 = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtypes.float32, name="var2")
+ var3 = resource_variable_ops.ResourceVariable(
+ [3.0, 4.0], dtype=dtypes.float32, name="var3")
+ loss = math_ops.reduce_sum(var2 + var3)
+ optimizer.minimize(loss)
+ optimizer_variables = optimizer.variables()
+ self.assertStartsWith(optimizer_variables[0].name, "var2")
+ self.assertStartsWith(optimizer_variables[1].name, "var3")
+ self.assertEquals(2, len(optimizer_variables))
+
+ def testNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ cost = 5 * var0 * var0 + 3 * var1
+ global_step = variables.Variable(
+ array_ops.zeros([], dtypes.int64), name="global_step")
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ opt_op = mom_op.minimize(cost, global_step, [var0, var1])
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_op.run()
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ def testSparseNesterovMomentum(self):
+ for dtype in [dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ grads = []
+ for t in range(1, 5):
+ grads.append(var0_np * 10)
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ var0_np = np.array([1.0, 2.0], dtype=dtype.as_numpy_dtype)
+ var1_np = np.array([3.0, 4.0], dtype=dtype.as_numpy_dtype)
+ accum0_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ accum1_np = np.array([0.0, 0.0], dtype=dtype.as_numpy_dtype)
+ var0 = variables.Variable(var0_np)
+ var1 = variables.Variable(var1_np)
+ loss = 5 * var0 * var0 + 3 * var1
+ mom_op = sgd.SGD(learning_rate=2.0, momentum=0.9, nesterov=True)
+ x_feed = array_ops.placeholder(dtype)
+ y_feed = ops.IndexedSlices(
+ x_feed, constant_op.constant([0, 1]), constant_op.constant([2]))
+ grads_and_vars = [(y_feed, var0), (constant_op.constant(
+ [3.0, 3.0], dtype=dtype), var1)]
+ opt_update = mom_op.apply_gradients(grads_and_vars)
+ variables.global_variables_initializer().run()
+ for t in range(1, 5):
+ opt_update.run(feed_dict={x_feed: grads[t - 1]})
+ var0_np, accum0_np = self._update_nesterov_momentum_numpy(
+ var0_np, accum0_np, var0_np * 10, 2.0, 0.9)
+ var1_np, accum1_np = self._update_nesterov_momentum_numpy(var1_np,
+ accum1_np,
+ 3, 2.0, 0.9)
+ self.assertAllClose(var0_np, var0.eval())
+ self.assertAllClose(var1_np, var1.eval())
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+ # pylint: disable=cell-var-from-loop
+ def loss():
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ return pred * pred
+ # pylint: enable=cell-var-from-loop
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ # Run 1 step of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeWith2DIndiciesForEmbeddingLookup(self):
+ # This test invokes the ResourceSparseApplyMomentum operation, which
+ # did not have a registered GPU kernel as of April 2018. With graph
+ # execution, the placement algorithm notices this and automatically
+ # places the variable in CPU (host) memory. With eager execution,
+ # the variable would be placed in GPU memory if available, which
+ # would then conflict with the future invocation of the
+ # ResourceSparseApplyMomentum operation.
+ # To work around this discrepancy, for now we force the variable
+ # to be placed on CPU.
+ with ops.device("/cpu:0"):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+ def loss():
+ return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+ opt = sgd.SGD(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss)
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(sgd_op)
+ self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
+
+ def testTensorLearningRateAndMomentum(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(
+ learning_rate=constant_op.constant(2.0),
+ momentum=constant_op.constant(0.9))
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ self.assertFalse(slot0 in variables.trainable_variables())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ self.assertFalse(slot1 in variables.trainable_variables())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+ def _dbParamsMom01(self):
+ """Return dist-belief momentum values.
+
+ Return values been generated from the dist-belief momentum unittest,
+ running with a learning rate of 0.1 and a momentum of 0.1.
+
+ These values record how a parameter vector of size 10, initialized with 0.0,
+ gets updated with 10 consecutive momentum steps. It uses random gradients.
+
+ Returns:
+ db_grad: The gradients to apply
+ db_out: The parameters after the momentum update.
+ """
+ db_grad = [[]] * 10
+ db_out = [[]] * 10
+ # pylint: disable=line-too-long
+ db_grad[0] = [
+ 0.00096264342, 0.17914793, 0.93945462, 0.41396621, 0.53037018,
+ 0.93197989, 0.78648776, 0.50036013, 0.55345792, 0.96722615
+ ]
+ db_out[0] = [
+ -9.6264346e-05, -0.017914793, -0.093945466, -0.041396622, -0.053037018,
+ -0.093197994, -0.078648776, -0.050036013, -0.055345792, -0.096722618
+ ]
+ db_grad[1] = [
+ 0.17075552, 0.88821375, 0.20873757, 0.25236958, 0.57578111, 0.15312378,
+ 0.5513742, 0.94687688, 0.16012503, 0.22159521
+ ]
+ db_out[1] = [
+ -0.017181443, -0.10852765, -0.12421377, -0.070773244, -0.11591884,
+ -0.11783017, -0.14165108, -0.14972731, -0.076892875, -0.1285544
+ ]
+ db_grad[2] = [
+ 0.35077485, 0.47304362, 0.44412705, 0.44368884, 0.078527533, 0.81223965,
+ 0.31168157, 0.43203235, 0.16792089, 0.24644311
+ ]
+ db_out[2] = [
+ -0.053967446, -0.1648933, -0.1716533, -0.1180798, -0.13005978,
+ -0.20151734, -0.17911947, -0.20289968, -0.095839672, -0.15638189
+ ]
+ db_grad[3] = [
+ 0.9694621, 0.75035888, 0.28171822, 0.83813518, 0.53807181, 0.3728098,
+ 0.81454384, 0.03848977, 0.89759839, 0.93665648
+ ]
+ db_out[3] = [
+ -0.15459226, -0.24556576, -0.20456907, -0.20662397, -0.18528105,
+ -0.24716705, -0.2643207, -0.21206589, -0.18749419, -0.2528303
+ ]
+ db_grad[4] = [
+ 0.38578293, 0.8536852, 0.88722926, 0.66276771, 0.13678469, 0.94036359,
+ 0.69107032, 0.81897682, 0.5433259, 0.67860287
+ ]
+ db_out[4] = [
+ -0.20323303, -0.33900154, -0.29658359, -0.28175515, -0.20448165,
+ -0.34576839, -0.34194785, -0.29488021, -0.25099224, -0.33033544
+ ]
+ db_grad[5] = [
+ 0.27885768, 0.76100707, 0.24625534, 0.81354135, 0.18959245, 0.48038563,
+ 0.84163809, 0.41172323, 0.83259648, 0.44941229
+ ]
+ db_out[5] = [
+ -0.23598288, -0.42444581, -0.33041057, -0.3706224, -0.22536094,
+ -0.40366709, -0.43387437, -0.34433398, -0.34060168, -0.38302717
+ ]
+ db_grad[6] = [
+ 0.27233034, 0.056316052, 0.5039115, 0.24105175, 0.35697976, 0.75913221,
+ 0.73577434, 0.16014607, 0.57500273, 0.071136251
+ ]
+ db_out[6] = [
+ -0.26649091, -0.43862185, -0.38418442, -0.40361428, -0.26314685,
+ -0.48537019, -0.51664448, -0.36529395, -0.40706289, -0.39540997
+ ]
+ db_grad[7] = [
+ 0.58697265, 0.2494842, 0.08106143, 0.39954534, 0.15892942, 0.12683646,
+ 0.74053431, 0.16033, 0.66625422, 0.73515922
+ ]
+ db_out[7] = [
+ -0.32823896, -0.46498787, -0.39766794, -0.446868, -0.28281838,
+ -0.50622416, -0.59897494, -0.38342294, -0.48033443, -0.47016418
+ ]
+ db_grad[8] = [
+ 0.8215279, 0.41994119, 0.95172721, 0.68000203, 0.79439718, 0.43384039,
+ 0.55561525, 0.22567581, 0.93331909, 0.29438227
+ ]
+ db_out[8] = [
+ -0.41656655, -0.50961858, -0.49418902, -0.51919359, -0.36422527,
+ -0.55169362, -0.6627695, -0.40780342, -0.58099347, -0.50707781
+ ]
+ db_grad[9] = [
+ 0.68297005, 0.67758518, 0.1748755, 0.13266537, 0.70697063, 0.055731893,
+ 0.68593478, 0.50580865, 0.12602448, 0.093537711
+ ]
+ db_out[9] = [
+ -0.49369633, -0.58184016, -0.52132869, -0.5396927, -0.44306302,
+ -0.56181377, -0.73774242, -0.46082234, -0.60366184, -0.52012295
+ ]
+ # pylint: enable=line-too-long
+ return db_grad, db_out
+
+ def testLikeDistBeliefMom01(self):
+ with self.cached_session():
+ db_grad, db_out = self._dbParamsMom01()
+ num_samples = len(db_grad)
+ var0 = variables.Variable([0.0] * num_samples)
+ grads0 = constant_op.constant([0.0] * num_samples)
+ mom_opt = sgd.SGD(learning_rate=0.1, momentum=0.1)
+ mom_update = mom_opt.apply_gradients(zip([grads0], [var0]))
+ variables.global_variables_initializer().run()
+ for i in xrange(num_samples):
+ mom_update.run(feed_dict={grads0: db_grad[i]})
+ self.assertAllClose(np.array(db_out[i]), var0.eval())
+
+ def testSparse(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable(array_ops.zeros([4, 2], dtype=dtype))
+ var1 = variables.Variable(constant_op.constant(1.0, dtype, [4, 2]))
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.1, .1]], dtype=dtype),
+ constant_op.constant([1]),
+ constant_op.constant([4, 2]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant(
+ [[.01, .01], [.01, .01]], dtype=dtype),
+ constant_op.constant([2, 3]),
+ constant_op.constant([4, 2]))
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ # Check we have slots
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([0, 0], var0.eval()[0])
+ self.assertAllClose([0, 0], var0.eval()[1])
+ self.assertAllClose([1, 1], var1.eval()[2])
+
+ # Step 1: the momentum accumulators are 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(np.array([.1, .1]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([.01, .01]), slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([-(0.1 * 2.0), -(0.1 * 2.0)]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.01 * 2.0), 1.0 - (0.01 * 2.0)]), var1.eval()[2])
+ # Step 2: the momentum accumulators contain the previous update.
+ mom_update.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllClose(np.array([0, 0]), slot0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]),
+ slot1.eval()[2])
+ # Check that the parameters have been updated.
+ self.assertAllClose(np.array([0, 0]), var0.eval()[0])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ -(0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0), -(0.1 * 2.0) - (
+ (0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval()[1])
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 0.98 - ((0.9 * 0.01 + 0.01) * 2.0), 0.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval()[2])
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.cached_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ mom_opt = sgd.SGD(learning_rate=2.0, momentum=0.9)
+ mom_update1 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ mom_update2 = mom_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+
+ self.assertEqual(["momentum"], mom_opt.get_slot_names())
+ slot0 = mom_opt.get_slot(var0, "momentum")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = mom_opt.get_slot(var1, "momentum")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Step 1: the momentum accumulators where 0. So we should see a normal
+ # update: v -= grad * learning_rate
+ mom_update1.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(np.array([0.1, 0.1]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([0.01, 0.01]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([1.0 - (0.1 * 2.0), 2.0 - (0.1 * 2.0)]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([3.0 - (0.01 * 2.0), 4.0 - (0.01 * 2.0)]), var1.eval())
+ # Step 2: the second momentum accumulators contain the previous update.
+ mom_update2.run()
+ # Check that the momentum accumulators have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.1 + 0.1), (0.9 * 0.1 + 0.1)]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([(0.9 * 0.01 + 0.01), (0.9 * 0.01 + 0.01)]), slot1.eval())
+ # Check that the parameters have been updated.
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 1.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0),
+ 2.0 - (0.1 * 2.0) - ((0.9 * 0.1 + 0.1) * 2.0)
+ ]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([
+ 2.98 - ((0.9 * 0.01 + 0.01) * 2.0), 3.98 - (
+ (0.9 * 0.01 + 0.01) * 2.0)
+ ]), var1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py
index 362cbc1dc9..4abaadfcd3 100644
--- a/tensorflow/python/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/preprocessing/image_test.py
@@ -94,43 +94,6 @@ class TestImage(test.TestCase):
self.assertEqual(x.shape[1:], images.shape[1:])
break
- def test_image_data_generator_with_validation_split(self):
- if PIL is None:
- return # Skip test if PIL is not available.
-
- for test_images in _generate_test_images():
- img_list = []
- for im in test_images:
- img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...])
-
- images = np.vstack(img_list)
- generator = keras.preprocessing.image.ImageDataGenerator(
- validation_split=0.5)
- seq = generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='validation')
- _, y = seq[0]
- self.assertEqual(list(y), [0, 1, 2])
- seq = generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='training')
- _, y2 = seq[0]
- self.assertEqual(list(y2), [4, 5, 6])
-
- with self.assertRaises(ValueError):
- generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='foo')
-
def test_image_data_generator_with_split_value_error(self):
with self.assertRaises(ValueError):
keras.preprocessing.image.ImageDataGenerator(validation_split=5)
diff --git a/tensorflow/python/keras/testing_utils.py b/tensorflow/python/keras/testing_utils.py
index 501b50ba5f..2fae094a1e 100644
--- a/tensorflow/python/keras/testing_utils.py
+++ b/tensorflow/python/keras/testing_utils.py
@@ -166,8 +166,9 @@ def layer_test(layer_cls, kwargs=None, input_shape=None, input_dtype=None,
if expected_dim is not None:
if expected_dim != actual_dim:
raise AssertionError(
- 'When testing layer %s, for input %s, found output_shape='
- '%s but expected to find %s.\nFull kwargs: %s' %
+ 'When testing layer %s **after deserialization**, '
+ 'for input %s, found output_shape='
+ '%s but expected to find inferred shape %s.\nFull kwargs: %s' %
(layer_cls.__name__,
x,
actual_output_shape,
diff --git a/tensorflow/python/keras/utils/conv_utils.py b/tensorflow/python/keras/utils/conv_utils.py
index 8ebca1418d..f486e631e5 100644
--- a/tensorflow/python/keras/utils/conv_utils.py
+++ b/tensorflow/python/keras/utils/conv_utils.py
@@ -137,26 +137,49 @@ def conv_input_length(output_length, filter_size, padding, stride):
return (output_length - 1) * stride - 2 * pad + filter_size
-def deconv_output_length(input_length, filter_size, padding, stride):
+def deconv_output_length(input_length, filter_size, padding,
+ output_padding=None, stride=0, dilation=1):
"""Determines output length of a transposed convolution given input length.
Arguments:
- input_length: integer.
- filter_size: integer.
- padding: one of "same", "valid", "full".
- stride: integer.
+ input_length: Integer.
+ filter_size: Integer.
+ padding: one of `"same"`, `"valid"`, `"full"`.
+ output_padding: Integer, amount of padding along the output dimension.
+ Can be set to `None` in which case the output length is inferred.
+ stride: Integer.
+ dilation: Integer.
Returns:
The output length (integer).
"""
+ assert padding in {'same', 'valid', 'full'}
if input_length is None:
return None
- input_length *= stride
- if padding == 'valid':
- input_length += max(filter_size - stride, 0)
- elif padding == 'full':
- input_length -= (stride + filter_size - 2)
- return input_length
+
+ # Get the dilated kernel size
+ filter_size = filter_size + (filter_size - 1) * (dilation - 1)
+
+ # Infer length if output padding is None, else compute the exact length
+ if output_padding is None:
+ if padding == 'valid':
+ length = input_length * stride + max(filter_size - stride, 0)
+ elif padding == 'full':
+ length = input_length * stride - (stride + filter_size - 2)
+ elif padding == 'same':
+ length = input_length * stride
+
+ else:
+ if padding == 'same':
+ pad = filter_size // 2
+ elif padding == 'valid':
+ pad = 0
+ elif padding == 'full':
+ pad = filter_size - 1
+
+ length = ((input_length - 1) * stride + filter_size - 2 * pad +
+ output_padding)
+ return length
def normalize_data_format(value):
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils.py b/tensorflow/python/keras/utils/multi_gpu_utils.py
index e1c49bc852..04b2ea8fe3 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils.py
@@ -244,9 +244,24 @@ def multi_gpu_model(model, gpus, cpu_merge=True, cpu_relocation=False):
for o in range(len(outputs)):
all_outputs[o].append(outputs[o])
+ # Deduplicate output names to handle Siamese networks.
+ occurrences = {}
+ for n in model.output_names:
+ if n not in occurrences:
+ occurrences[n] = 1
+ else:
+ occurrences[n] += 1
+ conflict_counter = {n: 0 for n, count in occurrences.items() if count > 1}
+ output_names = []
+ for n in model.output_names:
+ if n in conflict_counter:
+ conflict_counter[n] += 1
+ n += '_%d' % conflict_counter[n]
+ output_names.append(n)
+
# Merge outputs under expected scope.
with ops.device('/cpu:0' if cpu_merge else '/gpu:%d' % target_gpu_ids[0]):
merged = []
- for name, outputs in zip(model.output_names, all_outputs):
+ for name, outputs in zip(output_names, all_outputs):
merged.append(concatenate(outputs, axis=0, name=name))
return Model(model.inputs, merged)
diff --git a/tensorflow/python/keras/utils/multi_gpu_utils_test.py b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
index 3d0351a11f..1780ab6587 100644
--- a/tensorflow/python/keras/utils/multi_gpu_utils_test.py
+++ b/tensorflow/python/keras/utils/multi_gpu_utils_test.py
@@ -198,5 +198,31 @@ class TestMultiGPUModel(test.TestCase):
parallel_model.compile(loss='mean_squared_error', optimizer='adam')
parallel_model.train_on_batch(x, y)
+ def test_multi_gpu_with_siamese_network(self):
+ gpus = 2
+
+ if not check_if_compatible_devices(gpus=gpus):
+ return
+
+ with self.cached_session():
+ input_shape = (3,)
+ nested_model = keras.models.Sequential([
+ keras.layers.Dense(32, input_shape=input_shape),
+ keras.layers.Dense(1)
+ ], name='nested')
+
+ input1 = keras.Input(input_shape)
+ input2 = keras.Input(input_shape)
+ score1 = nested_model(input1)
+ score2 = nested_model(input2)
+ score_sum = keras.layers.Add(name='add')([score1, score2])
+
+ siamese = keras.models.Model(inputs=[input1, input2],
+ outputs=[score_sum, score1, score2],
+ name='siamese')
+ parallel_siamese = keras.utils.multi_gpu_model(siamese, gpus)
+ self.assertEqual(parallel_siamese.output_names,
+ ['add', 'nested_1', 'nested_2'])
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/utils/np_utils.py b/tensorflow/python/keras/utils/np_utils.py
index c24e87308b..3763999bff 100644
--- a/tensorflow/python/keras/utils/np_utils.py
+++ b/tensorflow/python/keras/utils/np_utils.py
@@ -22,7 +22,7 @@ from tensorflow.python.util.tf_export import tf_export
@tf_export('keras.utils.to_categorical')
-def to_categorical(y, num_classes=None):
+def to_categorical(y, num_classes=None, dtype='float32'):
"""Converts a class vector (integers) to binary class matrix.
E.g. for use with categorical_crossentropy.
@@ -31,6 +31,7 @@ def to_categorical(y, num_classes=None):
y: class vector to be converted into a matrix
(integers from 0 to num_classes).
num_classes: total number of classes.
+ dtype: The data type expected by the input. Default: `'float32'`.
Returns:
A binary matrix representation of the input. The classes axis is placed
@@ -44,7 +45,7 @@ def to_categorical(y, num_classes=None):
if not num_classes:
num_classes = np.max(y) + 1
n = y.shape[0]
- categorical = np.zeros((n, num_classes), dtype=np.float32)
+ categorical = np.zeros((n, num_classes), dtype=dtype)
categorical[np.arange(n), y] = 1
output_shape = input_shape + (num_classes,)
categorical = np.reshape(categorical, output_shape)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 5183e4d30c..4e8639dfc8 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -76,6 +76,7 @@ tf_py_test(
name = "batch_gather_op_test",
srcs = ["batch_gather_op_test.py"],
additional_deps = [
+ "@absl_py//absl/testing:parameterized",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -1097,6 +1098,18 @@ tf_py_test(
],
)
+tf_py_test(
+ name = "unicode_script_op_test",
+ size = "small",
+ srcs = ["unicode_script_op_test.py"],
+ additional_deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:string_ops",
+ ],
+)
+
cuda_py_test(
name = "topk_op_test",
size = "small",
@@ -1468,7 +1481,7 @@ cuda_py_test(
name = "control_flow_ops_py_test",
# TODO(b/70473603): change this back to "small" once the C API is
# permanently enabled
- size = "medium",
+ size = "large",
srcs = ["control_flow_ops_py_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1500,6 +1513,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python:while_v2",
],
)
@@ -2346,7 +2360,7 @@ cuda_py_test(
cuda_py_test(
name = "transpose_op_test",
- size = "large",
+ size = "medium",
srcs = ["transpose_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2354,10 +2368,11 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
],
- shard_count = 2,
+ shard_count = 10,
tags = [
"no_gpu",
"no_oss",
+ "optonly", # times out
],
)
@@ -2476,6 +2491,7 @@ cuda_py_test(
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
],
+ shard_count = 2,
tags = [
"optonly", # flaky timeouts unless optimized
],
@@ -2496,7 +2512,7 @@ cuda_py_test(
cuda_py_test(
name = "conv_ops_test",
- size = "large",
+ size = "medium",
srcs = ["conv_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2515,6 +2531,9 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
+ tags = [
+ "optonly", # times out
+ ],
)
cuda_py_test(
@@ -2574,7 +2593,7 @@ cuda_py_test(
cuda_py_test(
name = "fft_ops_test",
- size = "large",
+ size = "medium",
srcs = ["fft_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2584,7 +2603,8 @@ cuda_py_test(
"//tensorflow/python:spectral_ops",
"//tensorflow/python:spectral_ops_test_util",
],
- shard_count = 3,
+ shard_count = 4,
+ tags = ["optonly"],
)
cuda_py_test(
@@ -2649,7 +2669,7 @@ cuda_py_test(
cuda_py_test(
name = "scatter_ops_test",
- size = "large", # NOTE: This is not run by default.
+ size = "medium", # NOTE: This is not run by default.
srcs = ["scatter_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2658,11 +2678,13 @@ cuda_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
],
+ shard_count = 2,
+ tags = ["optonly"],
)
cuda_py_test(
name = "slice_op_test",
- size = "large",
+ size = "medium",
srcs = ["slice_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2978,6 +3000,10 @@ cuda_py_test(
"//tensorflow/python:math_ops",
],
shard_count = 20,
+ tags = [
+ "no_oss", # b/117185141.
+ "nomsan", # TODO(b/117236102): Re-enable in msan build.
+ ],
)
cuda_py_test(
@@ -2992,7 +3018,11 @@ cuda_py_test(
"//tensorflow/python:linalg_ops",
],
shard_count = 20,
- tags = ["no_windows_gpu"],
+ # TODO(b/117236102): Re-enable in msan build.
+ tags = [
+ "no_windows_gpu",
+ "nomsan",
+ ],
)
cuda_py_test(
@@ -3225,7 +3255,7 @@ tf_py_test(
tags = ["no_pip"],
)
-tf_py_test(
+cuda_py_test(
name = "cond_v2_test",
size = "medium",
srcs = ["cond_v2_test.py"],
@@ -3242,11 +3272,9 @@ tf_py_test(
"//tensorflow/python:training",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/111656070)
)
-# TODO(b/116053459): Replace with cuda_py_test.
-tf_py_test(
+cuda_py_test(
name = "while_v2_test",
size = "medium",
srcs = ["while_v2_test.py"],
@@ -3266,5 +3294,4 @@ tf_py_test(
"//tensorflow/python:while_v2",
],
grpc_enabled = True,
- tags = ["no_gpu"], # TODO(b/116053459)
)
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index 2fe85839d0..dcc594789e 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -615,6 +615,14 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
_ = checker[:, 0]
_ = checker[:, :, 0]
+ def testBothNewAxisAndShrink(self):
+ with self.test_session(use_gpu=True):
+ ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
+ self.assertAllEqual(
+ ones[array_ops.newaxis, :, 0].eval(
+ feed_dict={ones: [[1, 1], [1, 1]]}),
+ [[1, 1]])
+
def testTensorIndexing(self):
with self.test_session(use_gpu=True):
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
@@ -1001,14 +1009,14 @@ class SliceAssignTest(test_util.TensorFlowTestCase):
errors.FailedPreconditionError,
"Attempting to use uninitialized value Variable"):
with self.cached_session() as sess:
- v = variables.Variable([1, 2])
+ v = variables.VariableV1([1, 2])
sess.run(v[:].assign([1, 2]))
def testTypeError(self):
init_val = constant_op.constant([1, 2], dtype=dtypes.int32)
too_small_val = constant_op.constant([3, 4], dtype=dtypes.int8)
too_large_val = constant_op.constant([3, 4], dtype=dtypes.int64)
- v = variables.Variable(init_val)
+ v = variables.VariableV1(init_val)
with self.assertRaises(TypeError):
v[:].assign(too_small_val)
with self.assertRaises(TypeError):
diff --git a/tensorflow/python/kernel_tests/batch_gather_op_test.py b/tensorflow/python/kernel_tests/batch_gather_op_test.py
index 7dd347989a..84e93b8136 100644
--- a/tensorflow/python/kernel_tests/batch_gather_op_test.py
+++ b/tensorflow/python/kernel_tests/batch_gather_op_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
import numpy as np
from tensorflow.python.framework import constant_op
@@ -29,7 +30,7 @@ _TEST_TYPES = (dtypes.int64, dtypes.float32,
dtypes.complex64, dtypes.complex128)
-class GatherTest(test.TestCase):
+class GatherTest(test.TestCase, parameterized.TestCase):
def _buildParams(self, data, dtype):
data = data.astype(dtype.as_numpy_dtype)
@@ -39,14 +40,15 @@ class GatherTest(test.TestCase):
return data + 10j * data
return data
- def testSimpleGather(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def testSimpleGather(self, indices_dtype):
data = np.array([0, 1, 2, 3, 7, 5, 8, 9, 10, 11, 15, 13])
indices = [3, 4]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([3, 7])
np_val = self._buildParams(expected_result, dtype)
@@ -54,14 +56,15 @@ class GatherTest(test.TestCase):
self.assertAllEqual(np_val, gather_val)
self.assertEqual(np_val.shape, gather_t.get_shape())
- def test2DArray(self):
+ @parameterized.parameters(dtypes.int32, dtypes.int64)
+ def test2DArray(self, indices_dtype):
data = np.array([[0, 1, 2, 3, 7, 5], [8, 9, 10, 11, 15, 13]])
indices = [[3], [4]]
with self.test_session(use_gpu=True):
for dtype in _TEST_TYPES:
params_np = self._buildParams(data, dtype)
params = constant_op.constant(params_np)
- indices_tf = constant_op.constant(indices)
+ indices_tf = constant_op.constant(indices, dtype=indices_dtype)
gather_t = array_ops.batch_gather(params, indices_tf)
expected_result = np.array([[3], [15]])
np_val = self._buildParams(expected_result, dtype)
diff --git a/tensorflow/python/kernel_tests/bincount_op_test.py b/tensorflow/python/kernel_tests/bincount_op_test.py
index 8a58b3f97e..8177cdd454 100644
--- a/tensorflow/python/kernel_tests/bincount_op_test.py
+++ b/tensorflow/python/kernel_tests/bincount_op_test.py
@@ -22,6 +22,8 @@ import numpy as np
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import googletest
@@ -97,6 +99,22 @@ class BincountTest(test_util.TensorFlowTestCase):
with self.assertRaises(errors.InvalidArgumentError):
math_ops.bincount([1, 2, 3, -1, 6, 8]).eval()
+ def test_shape_function(self):
+ # size must be scalar.
+ with self.assertRaisesRegexp(
+ ValueError, "Shape must be rank 0 but is rank 1 for 'Bincount'"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], [1], [])
+ # size must be positive.
+ with self.assertRaisesRegexp(ValueError, "must be non-negative"):
+ gen_math_ops.bincount([1, 2, 3, -1, 6, 8], -5, [])
+ # if size is a constant then the shape is known.
+ v1 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], 5, [])
+ self.assertAllEqual(v1.get_shape().as_list(), [5])
+ # if size is a placeholder then the shape is unknown.
+ s = array_ops.placeholder(dtype=dtypes.int32)
+ v2 = gen_math_ops.bincount([1, 2, 3, -1, 6, 8], s, [])
+ self.assertAllEqual(v2.get_shape().as_list(), [None])
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/kernel_tests/cond_v2_test.py b/tensorflow/python/kernel_tests/cond_v2_test.py
index 377c041675..ec875aae59 100644
--- a/tensorflow/python/kernel_tests/cond_v2_test.py
+++ b/tensorflow/python/kernel_tests/cond_v2_test.py
@@ -172,7 +172,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -198,7 +198,7 @@ class CondV2Test(test.TestCase):
self._testCond(true_fn, false_fn, [y])
def testDoubleNestedDefunInCond(self):
- self.skipTest("b/110550782")
+ self.skipTest("b/117284369")
x = constant_op.constant(1.0, name="x")
y = constant_op.constant(2.0, name="y")
@@ -468,7 +468,6 @@ class CondV2Test(test.TestCase):
}), [5., 0.])
def testBuildCondAndGradientInsideDefun(self):
- self.skipTest("b/110550782")
def build_graph():
pred_outer = array_ops.placeholder(dtypes.bool, name="pred_outer")
@@ -502,29 +501,29 @@ class CondV2Test(test.TestCase):
return grads, pred_outer, pred_inner
- with ops.Graph().as_default():
+ with ops.Graph().as_default(), self.session(
+ graph=ops.get_default_graph()) as sess:
grads, pred_outer, pred_inner = build_graph()
- with self.session(graph=ops.get_default_graph()) as sess:
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: True
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: True,
- pred_inner: False
- }), [0., 0.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: True
- }), [4., 2.])
- self.assertSequenceEqual(
- sess.run(grads, {
- pred_outer: False,
- pred_inner: False
- }), [5., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: True
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: True,
+ pred_inner: False
+ }), [0., 0.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: True
+ }), [4., 2.])
+ self.assertSequenceEqual(
+ sess.run(grads, {
+ pred_outer: False,
+ pred_inner: False
+ }), [5., 0.])
def testSecondDerivative(self):
with self.cached_session() as sess:
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index fc4d2a3809..baea5c0f6d 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -32,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -63,6 +63,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,12 +126,12 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
v = control_flow_ops._Identity(v)
op = state_ops.assign(v, 9)
@@ -142,7 +143,7 @@ class ControlFlowTest(test.TestCase):
def testRefEnter(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
enter_v = control_flow_ops._Enter(v, "foo_1", is_constant=True)
nine = constant_op.constant(9)
@@ -155,7 +156,7 @@ class ControlFlowTest(test.TestCase):
def testRefSwitch(self):
with self.cached_session():
- v = variables.Variable(7)
+ v = variables.VariableV1(7)
p = constant_op.constant(True)
v1 = control_flow_ops._SwitchRefOrTensor(v._ref(), p) # pylint: disable=protected-access
@@ -332,10 +333,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -351,6 +350,13 @@ class ControlFlowTest(test.TestCase):
grad = gradients_impl.gradients(y, [v])
self.assertAllEqual([None], grad)
+ def testCondOutputShape(self):
+ x = constant_op.constant(1.0)
+ b = control_flow_ops.cond(
+ constant_op.constant(True), lambda: math_ops.square(x),
+ lambda: math_ops.subtract(x, 1.))
+ self.assertEqual(b.shape, tensor_shape.scalar())
+
def testFetchable(self):
with self.cached_session() as sess:
x = array_ops.placeholder(dtypes.float32)
@@ -366,6 +372,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +390,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +406,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -435,10 +438,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -510,10 +511,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -587,10 +586,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -629,10 +626,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -646,10 +642,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -666,11 +661,7 @@ class ControlFlowTest(test.TestCase):
sess.run(r)
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
- graph = ops.Graph()
- with graph.as_default():
+ with self.cached_session():
x = constant_op.constant(10.0, name="x")
pred = math_ops.less(1, 2)
fn1 = lambda: array_ops.identity(x)
@@ -678,8 +669,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
grad = gradients_impl.gradients(r, [x])[0]
- with self.cached_session():
- self.assertAllEqual(1.0, grad.eval())
+ self.assertAllEqual(1.0, grad.eval())
def testCondGrad_2(self):
with self.cached_session():
@@ -694,10 +684,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -729,10 +718,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -756,6 +743,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -764,6 +752,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -779,6 +768,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -794,9 +784,10 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 100)
@@ -824,18 +815,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -861,6 +856,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -904,10 +900,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -939,6 +933,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1072,6 +1068,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1139,6 +1136,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1147,6 +1145,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1169,7 +1168,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1180,6 +1179,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1211,6 +1211,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1265,6 +1266,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1297,6 +1299,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1350,6 +1353,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1363,10 +1367,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1380,10 +1382,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1405,9 +1405,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1432,8 +1431,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1445,8 +1442,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1458,9 +1453,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1477,18 +1469,18 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1504,9 +1496,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [i])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1515,9 +1506,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1532,6 +1522,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1554,6 +1545,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1580,6 +1572,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1601,7 +1594,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1629,7 +1622,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1659,7 +1652,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1689,6 +1682,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1707,6 +1701,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/117119329 (stack)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1775,6 +1770,7 @@ class ControlFlowTest(test.TestCase):
with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1861,8 +1857,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1885,10 +1879,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1902,8 +1898,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1919,6 +1913,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1936,6 +1931,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1952,9 +1948,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1972,6 +1967,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -1999,6 +1995,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2057,6 +2054,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2067,6 +2065,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2180,10 +2179,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2207,6 +2208,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2230,6 +2232,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2268,13 +2272,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2315,9 +2318,10 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
- x = variables.Variable(0.)._ref() # pylint: disable=protected-access
+ x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
i = constant_op.constant(0)
c = lambda i, x: math_ops.less(i, 10)
@@ -2329,7 +2333,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, body, [i, x], parallel_iterations=5)
- grad_ys = [variables.Variable(73)._ref()] # pylint: disable=protected-access
+ grad_ys = [variables.VariableV1(73)._ref()] # pylint: disable=protected-access
grad = gradients_impl.gradients([r[1]], [x], grad_ys=grad_ys)
variables.global_variables_initializer().run()
@@ -2343,6 +2347,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2364,6 +2369,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2386,6 +2392,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2405,6 +2412,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2509,6 +2518,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2538,6 +2548,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2555,6 +2566,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2572,6 +2584,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2595,6 +2608,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2653,10 +2667,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2708,10 +2721,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2746,10 +2758,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -2779,7 +2789,7 @@ class ControlFlowTest(test.TestCase):
def testWithOpsDependencies(self):
with self.cached_session() as sess:
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c = constant_op.constant(10)
# Fetching v directly will result in an uninitialized error
@@ -2802,7 +2812,7 @@ class ControlFlowTest(test.TestCase):
def testWithTensorDependencies(self):
with self.cached_session():
- v = variables.Variable(0.0)
+ v = variables.VariableV1(0.0)
c1 = constant_op.constant(10)
c2 = constant_op.constant(20)
@@ -2828,7 +2838,7 @@ class ControlFlowTest(test.TestCase):
def testWithIndexedSlicesDependencies(self):
with self.cached_session():
- v = variables.Variable(
+ v = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(np.float32))
v_at_1 = ops.IndexedSlices(v, constant_op.constant([1]))
gather_v_at_1 = array_ops.gather(v_at_1.values, v_at_1.indices)
@@ -2851,18 +2861,18 @@ class ControlFlowTest(test.TestCase):
with ops.Graph().as_default():
# device set on tensor => same device on dep.
with ops.device("/job:ps"):
- vd = variables.Variable([0.0])
+ vd = variables.VariableV1([0.0])
with_vd_dep = control_flow_ops.with_dependencies([vd.initializer], vd)
self.assertTrue("/job:ps" in with_vd_dep.device)
# No device set on tensor => no device on dep.
- vnod = variables.Variable([0.0])
+ vnod = variables.VariableV1([0.0])
with_vnod_dep = control_flow_ops.with_dependencies([vnod.initializer],
vnod)
self.assertDeviceEqual(None, with_vnod_dep.device)
# device set on tensor, default device on graph => default device on dep.
- vdef = variables.Variable([0.0], name="vdef")
+ vdef = variables.VariableV1([0.0], name="vdef")
with ops.device("/job:worker/device:GPU:1"):
with_vdef_dep = control_flow_ops.with_dependencies([vdef.initializer],
vdef)
@@ -2872,8 +2882,8 @@ class ControlFlowTest(test.TestCase):
def testGroup(self):
with self.cached_session() as sess:
- v1 = variables.Variable([0.0])
- v2 = variables.Variable([1.0])
+ v1 = variables.VariableV1([0.0])
+ v2 = variables.VariableV1([1.0])
# Group init1 and init2 and run.
init = control_flow_ops.group(v1.initializer, v2.initializer)
@@ -2955,29 +2965,29 @@ class ControlFlowTest(test.TestCase):
p1 = array_ops.placeholder(dtypes.float32)
p2 = array_ops.placeholder(dtypes.float32)
p3 = array_ops.placeholder(dtypes.float32)
- v1 = variables.Variable(p1, validate_shape=False)
- v2 = variables.Variable(p2, validate_shape=False)
- v3 = variables.Variable(p3, validate_shape=False)
+ v1 = variables.VariableV1(p1, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
+ v3 = variables.VariableV1(p3, validate_shape=False)
self.assertIs(None, v1.get_shape().ndims)
s = control_flow_ops.ref_select(index, [v1, v2, v3])
self.assertIs(None, s.get_shape().ndims)
# All inputs known but different.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[2], [1]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[2], [1]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertIs(None, s.get_shape().ndims)
# All inputs known and same.
- v1 = variables.Variable([[1, 2]])
- v2 = variables.Variable([[1, 2]])
+ v1 = variables.VariableV1([[1, 2]])
+ v2 = variables.VariableV1([[1, 2]])
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual([1, 2], s.get_shape())
# Possibly the same but not guaranteed.
- v1 = variables.Variable([[1., 2.]])
+ v1 = variables.VariableV1([[1., 2.]])
p2 = array_ops.placeholder(dtypes.float32, shape=[None, 2])
- v2 = variables.Variable(p2, validate_shape=False)
+ v2 = variables.VariableV1(p2, validate_shape=False)
s = control_flow_ops.ref_select(index, [v1, v2])
self.assertEqual(None, s.get_shape())
@@ -3031,9 +3041,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3160,11 +3172,11 @@ class TupleTest(test.TestCase):
def testTensors(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable([1.0])
+ v1 = variables.VariableV1([1.0])
add1 = math_ops.add(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
2.0)
- v2 = variables.Variable([10.0])
+ v2 = variables.VariableV1([10.0])
add2 = math_ops.add(
control_flow_ops.with_dependencies([v2.initializer], v2._ref()), # pylint: disable=protected-access
20.0)
@@ -3190,14 +3202,14 @@ class TupleTest(test.TestCase):
def testIndexedSlices(self):
for v1_first in [True, False]:
with self.cached_session():
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
np.array([[0.0, 1.0], [10.0, 11.0], [20.0, 21.0]]).astype(
np.float32))
v1_at_1 = ops.IndexedSlices(
control_flow_ops.with_dependencies([v1.initializer], v1._ref()), # pylint: disable=protected-access
constant_op.constant([1]))
- v2 = variables.Variable(
+ v2 = variables.VariableV1(
np.array([[0.1, 1.1], [10.1, 11.1], [20.1, 21.1]]).astype(
np.float32))
v2_at_1 = ops.IndexedSlices(
@@ -3229,7 +3241,7 @@ class TupleTest(test.TestCase):
def testAcceptTensorsAsControlInputs(self):
with self.cached_session():
- var = variables.Variable(0)
+ var = variables.VariableV1(0)
assign = state_ops.assign(var, 1)
t, = control_flow_ops.tuple(
[constant_op.constant(0)], control_inputs=[assign])
@@ -3393,7 +3405,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):
@@ -3406,6 +3418,25 @@ class EagerTest(test.TestCase):
self.assertAllEqual(r.numpy(), 10)
self.assertFalse(isinstance(r, list))
+ # TODO(b/117279927): Re-enable once msan failure is fixed.
+ def DISABLED_testCondInDefun(self):
+ with context.eager_mode():
+
+ @eager_function.defun
+ def foo(pred):
+ # TODO(b/111124878): this only needs to output one element.
+ fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
+ fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
+ return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
+
+ r = foo(True)
+ self.assertAllEqual(r[0].numpy(), 10)
+ self.assertNotIsInstance(r, list)
+
+ r = foo(False)
+ self.assertAllEqual(r[0].numpy(), 20)
+ self.assertFalse(isinstance(r, list))
+
def testWhileLoop(self):
with context.eager_mode():
tensor = constant_op.constant([1, 2, 3, 4, 5])
diff --git a/tensorflow/python/kernel_tests/dense_update_ops_test.py b/tensorflow/python/kernel_tests/dense_update_ops_test.py
index 06c3271850..120e10314f 100644
--- a/tensorflow/python/kernel_tests/dense_update_ops_test.py
+++ b/tensorflow/python/kernel_tests/dense_update_ops_test.py
@@ -87,7 +87,7 @@ class AssignOpTest(test.TestCase):
def testAssignNonStrictShapeChecking(self):
with self.cached_session():
data = array_ops.fill([1024, 1024], 0)
- p = variables.Variable([1])
+ p = variables.VariableV1([1])
a = state_ops.assign(p, data, validate_shape=False)
a.op.run()
self.assertAllEqual(p.eval(), data.eval())
@@ -100,14 +100,14 @@ class AssignOpTest(test.TestCase):
def testInitRequiredAssignAdd(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_add(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
def testInitRequiredAssignSub(self):
with self.cached_session():
- p = variables.Variable(array_ops.fill([1024, 1024], 1), dtypes.int32)
+ p = variables.VariableV1(array_ops.fill([1024, 1024], 1), dtypes.int32)
a = state_ops.assign_sub(p, array_ops.fill([1024, 1024], 0))
with self.assertRaisesOpError("use uninitialized"):
a.op.run()
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 6d1ead20be..737a73f97a 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -131,8 +131,8 @@ class DepthwiseConv2DTest(test.TestCase):
with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
- dtypes.float32: 1e-8,
- dtypes.float64: 1e-13,
+ dtypes.float32: 1e-5,
+ dtypes.float64: 1e-12,
}[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
diff --git a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
index 26d013bccb..37b35ba51a 100644
--- a/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
+++ b/tensorflow/python/kernel_tests/distributions/bernoulli_test.py
@@ -118,7 +118,9 @@ class BernoulliTest(test.TestCase):
self.assertEqual(dist.probs.dtype, dist.stddev().dtype)
self.assertEqual(dist.probs.dtype, dist.entropy().dtype)
self.assertEqual(dist.probs.dtype, dist.prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.prob(0.5).dtype)
self.assertEqual(dist.probs.dtype, dist.log_prob(0).dtype)
+ self.assertEqual(dist.probs.dtype, dist.log_prob(0.5).dtype)
dist64 = make_bernoulli([], dtypes.int64)
self.assertEqual(dist64.dtype, dtypes.int64)
@@ -181,6 +183,16 @@ class BernoulliTest(test.TestCase):
return
self._testPmf(logits=special.logit(p))
+ @test_util.run_in_graph_and_eager_modes
+ def testPmfWithFloatArgReturnsXEntropy(self):
+ p = [[0.2], [0.4], [0.3], [0.6]]
+ samps = [0, 0.1, 0.8]
+ self.assertAllClose(
+ np.float32(samps) * np.log(np.float32(p)) +
+ (1 - np.float32(samps)) * np.log(1 - np.float32(p)),
+ self.evaluate(
+ bernoulli.Bernoulli(probs=p, validate_args=False).log_prob(samps)))
+
def testBroadcasting(self):
with self.cached_session():
p = array_ops.placeholder(dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index d580a415dd..42e81bd658 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -167,6 +167,11 @@ class BetaTest(test.TestCase):
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
self.assertEqual((2, 2), pdf.get_shape())
+ def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
+ b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]]
+ pdf = self.evaluate(beta_lib.Beta(1., b).prob(0.))
+ self.assertAllEqual(np.ones_like(pdf, dtype=np.bool), np.isfinite(pdf))
+
def testBetaMean(self):
a = [1., 2, 3]
b = [2., 4, 1.2]
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index cace5b3ba2..0f96382453 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -83,6 +83,23 @@ class DirichletTest(test.TestCase):
with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
self.evaluate(dist.prob([.1, .2, .8]))
+ def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
+ # Test concentration = 1. for each dimension.
+ concentration = 3 * np.ones((10, 10)).astype(np.float32)
+ concentration[range(10), range(10)] = 1.
+ x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
+ x[range(10), range(10)] = 0.
+ dist = dirichlet_lib.Dirichlet(concentration)
+ log_prob = self.evaluate(dist.log_prob(x))
+ self.assertAllEqual(
+ np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob))
+
+ # Test when concentration[k] = 1., and x is zero at various dimensions.
+ dist = dirichlet_lib.Dirichlet(10 * [1.])
+ log_prob = self.evaluate(dist.log_prob(x))
+ self.assertAllEqual(
+ np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob))
+
def testPdfZeroBatches(self):
alpha = [1., 2]
x = [.5, .5]
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 27d1291912..1600387585 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -65,6 +65,13 @@ class ExponentialTest(test.TestCase):
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ def testExponentialLogPDFBoundary(self):
+ # Check that Log PDF is finite at 0.
+ rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
+ exponential = exponential_lib.Exponential(rate=rate)
+ log_pdf = exponential.log_prob(0.)
+ self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
+
def testExponentialCDF(self):
batch_size = 6
lam = constant_op.constant([2.0] * batch_size)
@@ -81,6 +88,22 @@ class ExponentialTest(test.TestCase):
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ def testExponentialLogSurvival(self):
+ batch_size = 7
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)
+
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ log_survival = exponential.log_survival_function(x)
+ self.assertEqual(log_survival.get_shape(), (7,))
+
+ if not stats:
+ return
+ expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(log_survival), expected_log_survival)
+
def testExponentialMean(self):
lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v)
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 4eff40b029..4c5b9c3ea3 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -77,6 +77,14 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ def testGammaLogPDFBoundary(self):
+ # When concentration = 1, we have an exponential distribution. Check that at
+ # 0 we have finite log prob.
+ rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
+ gamma = gamma_lib.Gamma(concentration=1., rate=rate)
+ log_pdf = gamma.log_prob(0.)
+ self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
+
def testGammaLogPDFMultidimensional(self):
batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
diff --git a/tensorflow/python/kernel_tests/distributions/laplace_test.py b/tensorflow/python/kernel_tests/distributions/laplace_test.py
index 630c2cb424..2610ba23b8 100644
--- a/tensorflow/python/kernel_tests/distributions/laplace_test.py
+++ b/tensorflow/python/kernel_tests/distributions/laplace_test.py
@@ -275,8 +275,8 @@ class LaplaceTest(test.TestCase):
self.assertAllClose(
sample_values.var(axis=0),
stats.laplace.var(loc_bc, scale=scale_bc),
- rtol=0.10,
- atol=0.)
+ rtol=0.105,
+ atol=0.0)
fails = 0
trials = 0
for ai, a in enumerate(np.reshape(loc_v, [-1])):
diff --git a/tensorflow/python/kernel_tests/distributions/normal_test.py b/tensorflow/python/kernel_tests/distributions/normal_test.py
index de73a40b23..6625a88843 100644
--- a/tensorflow/python/kernel_tests/distributions/normal_test.py
+++ b/tensorflow/python/kernel_tests/distributions/normal_test.py
@@ -78,6 +78,14 @@ class NormalTest(test.TestCase):
self.assertEqual(expected, sigma_shape)
@test_util.run_in_graph_and_eager_modes
+ def testSampleLikeArgsGetDistDType(self):
+ dist = normal_lib.Normal(0., 1.)
+ self.assertEqual(dtypes.float32, dist.dtype)
+ for method in ("log_prob", "prob", "log_cdf", "cdf",
+ "log_survival_function", "survival_function", "quantile"):
+ self.assertEqual(dtypes.float32, getattr(dist, method)(1).dtype)
+
+ @test_util.run_in_graph_and_eager_modes
def testParamShapes(self):
sample_shape = [10, 3, 4]
self._testParamShapes(sample_shape, sample_shape)
diff --git a/tensorflow/python/kernel_tests/identity_op_py_test.py b/tensorflow/python/kernel_tests/identity_op_py_test.py
index 37f9f716f8..88ea10c22a 100644
--- a/tensorflow/python/kernel_tests/identity_op_py_test.py
+++ b/tensorflow/python/kernel_tests/identity_op_py_test.py
@@ -61,7 +61,7 @@ class IdentityOpTest(test.TestCase):
def testRefIdentityShape(self):
with self.cached_session():
shape = [2, 3]
- tensor = variables.Variable(
+ tensor = variables.VariableV1(
constant_op.constant(
[[1, 2, 3], [6, 5, 4]], dtype=dtypes.int32))
self.assertEquals(shape, tensor.get_shape())
diff --git a/tensorflow/python/kernel_tests/list_ops_test.py b/tensorflow/python/kernel_tests/list_ops_test.py
index 0f5607712b..ae413edaec 100644
--- a/tensorflow/python/kernel_tests/list_ops_test.py
+++ b/tensorflow/python/kernel_tests/list_ops_test.py
@@ -170,6 +170,32 @@ class ListOpsTest(test_util.TensorFlowTestCase):
list_ops.tensor_list_pop_back(
l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+ @test_util.run_in_graph_and_eager_modes
+ def testCPUGPUCopyNested(self):
+ if not context.num_gpus():
+ return
+ t = constant_op.constant([1.0, 2.0])
+ child_l = list_ops.tensor_list_from_tensor(t, element_shape=scalar_shape())
+ l = list_ops.empty_tensor_list(
+ element_shape=constant_op.constant([], dtype=dtypes.int32),
+ element_dtype=dtypes.variant)
+ l = list_ops.tensor_list_push_back(l, child_l)
+ with context.device("gpu:0"):
+ l_gpu = array_ops.identity(l)
+ _, child_l_gpu = list_ops.tensor_list_pop_back(
+ l_gpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_gpu, element_dtype=dtypes.float32)[1]), 2.0)
+ l_cpu = array_ops.identity(l_gpu)
+ _, child_l_cpu = list_ops.tensor_list_pop_back(
+ l_cpu, element_dtype=dtypes.variant)
+ self.assertAllEqual(
+ self.evaluate(
+ list_ops.tensor_list_pop_back(
+ child_l_cpu, element_dtype=dtypes.float32)[1]), 2.0)
+
def testGraphStack(self):
with self.cached_session():
tl = list_ops.empty_tensor_list(
diff --git a/tensorflow/python/kernel_tests/logging_ops_test.py b/tensorflow/python/kernel_tests/logging_ops_test.py
index 4beddd00bb..2f19ecc0e6 100644
--- a/tensorflow/python/kernel_tests/logging_ops_test.py
+++ b/tensorflow/python/kernel_tests/logging_ops_test.py
@@ -306,6 +306,19 @@ class PrintV2Test(test.TestCase):
logging_ops.print_v2(tensor)
self.assertTrue((expected + "\n") in printed.contents())
+ def testPrintsOrderedInDefun(self):
+ with context.eager_mode():
+
+ @function.defun
+ def prints():
+ logging_ops.print_v2("A")
+ logging_ops.print_v2("B")
+ logging_ops.print_v2("C")
+
+ with self.captureWritesToStream(sys.stderr) as printed:
+ prints()
+ self.assertTrue(("A\nB\nC\n") in printed.contents())
+
@test_util.run_in_graph_and_eager_modes()
def testPrintInDefunWithoutExplicitEvalOfPrint(self):
@function.defun
diff --git a/tensorflow/python/kernel_tests/relu_op_test.py b/tensorflow/python/kernel_tests/relu_op_test.py
index b26e944af8..672d6556f5 100644
--- a/tensorflow/python/kernel_tests/relu_op_test.py
+++ b/tensorflow/python/kernel_tests/relu_op_test.py
@@ -293,8 +293,9 @@ class LeakyReluTest(test.TestCase):
np.array([[-0.09, 0.7, -0.05, 0.3, -0.01],
[0.1, -0.03, 0.5, -0.07, 0.9]]),
self._npLeakyRelu(
- np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7, 0.9]
- ]), alpha=0.1))
+ np.array([[-0.9, 0.7, -0.5, 0.3, -0.1], [0.1, -0.3, 0.5, -0.7,
+ 0.9]]),
+ alpha=0.1))
def _testLeakyRelu(self, np_features, alpha, use_gpu=False):
np_leaky_relu = self._npLeakyRelu(np_features, alpha)
@@ -308,11 +309,13 @@ class LeakyReluTest(test.TestCase):
for t in [np.int32, np.int64, np.float16, np.float32, np.float64]:
self._testLeakyRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
- alpha=0.2, use_gpu=False)
+ alpha=0.2,
+ use_gpu=False)
if t in [np.float16, np.float32, np.float64]:
self._testLeakyRelu(
np.array([[-9, 7, -5, 3, -1], [1, -3, 5, -7, 9]]).astype(t),
- alpha=0.1, use_gpu=True)
+ alpha=0.1,
+ use_gpu=True)
# The gradient test for Leaky ReLU is a bit tricky as the derivative is not
# well defined at around zero and we want to avoid that in terms of input
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f90545f84c..a9fd93e9f8 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -142,7 +142,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
v = resource_variable_ops.ResourceVariable(1.0)
ops.reset_default_graph()
v.assign(2.0) # Note: this fails if we run convert_to_tensor on not the
- # variable graph.
+ # variable graph.
def testFetchHandle(self):
with self.cached_session():
@@ -290,7 +290,7 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
self.assertEqual(self.evaluate(read), [[2]])
def testUseResource(self):
- v = variables.Variable(1.0, use_resource=True)
+ v = variables.VariableV1(1.0, use_resource=True)
self.assertTrue(isinstance(v, resource_variable_ops.ResourceVariable))
def testEagerNoUseResource(self):
@@ -908,6 +908,13 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(Exception, r"shape.*2.*3"):
state_ops.scatter_update(v, [0, 1], [0, 1, 2])
+ @test_util.run_in_graph_and_eager_modes
+ def testAssignIncompatibleShape(self):
+ v = resource_variable_ops.ResourceVariable([0, 1, 2, 3])
+ self.evaluate(v.initializer)
+ with self.assertRaisesRegexp(Exception, r"hapes must be equal"):
+ self.assertAllEqual(self.evaluate(v.assign_add(1)), [1, 2, 3, 4])
+
class _MixedPrecisionVariableTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 05ad9f6336..2f6963f6b8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -535,6 +535,45 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_out, k_out)
self.assertAllClose(tf_state, k_state)
+ def testSimpleRNNCellAndBasicRNNCellComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
+ # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
+ # zipped [kernel, recurrent_kernel] in SimpleRNNCell.
+ keras_weights = fix_weights_generator.get_weights()
+ kernel, recurrent_kernel, bias = keras_weights
+ tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ k_out, k_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(keras_weights)
+ [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = rnn_cell_impl.BasicRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(tf_weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index 86e063cb36..4b92309e4d 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -136,7 +136,7 @@ class StatefulScatterNdTest(test.TestCase):
new = ref.copy()
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref_var = variables.Variable(ref)
+ ref_var = variables.VariableV1(ref)
ref_var.initializer.run()
tf_scatter(ref_var, indices, updates).eval()
@@ -258,7 +258,7 @@ class StatefulScatterNdTest(test.TestCase):
params = np.array([1, 2, 3, 4, 5, 6]).astype(np.float32)
updates = np.array([-3, -4, -5]).astype(np.float32)
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/scatter_ops_test.py b/tensorflow/python/kernel_tests/scatter_ops_test.py
index 1a0fa744ae..527b7daf10 100644
--- a/tensorflow/python/kernel_tests/scatter_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_ops_test.py
@@ -178,7 +178,7 @@ class ScatterTest(test.TestCase):
np_scatter = _TF_OPS_TO_NUMPY[tf_scatter]
np_scatter(new, indices, updates)
# Scatter via tensorflow
- ref = variables.Variable(old)
+ ref = variables.VariableV1(old)
ref.initializer.run()
tf_scatter(ref, indices, updates).eval()
self.assertAllClose(ref.eval(), new)
@@ -294,7 +294,7 @@ class ScatterTest(test.TestCase):
updates = np.array([-3, -4, -5]).astype(np.float32)
if not test.is_gpu_available():
with self.test_session(use_gpu=False):
- ref = variables.Variable(params)
+ ref = variables.VariableV1(params)
ref.initializer.run()
# Indices all in range, no problem.
diff --git a/tensorflow/python/kernel_tests/softplus_op_test.py b/tensorflow/python/kernel_tests/softplus_op_test.py
index e8dc272637..636ed4747e 100644
--- a/tensorflow/python/kernel_tests/softplus_op_test.py
+++ b/tensorflow/python/kernel_tests/softplus_op_test.py
@@ -126,7 +126,7 @@ class SoftplusTest(test.TestCase):
with self.assertRaisesRegexp(
TypeError,
"'features' has DataType int32 not in list of allowed values"):
- nn_ops.softplus(constant_op.constant(7)).eval()
+ nn_ops.softplus(constant_op.constant(42)).eval()
if __name__ == "__main__":
diff --git a/tensorflow/python/kernel_tests/substr_op_test.py b/tensorflow/python/kernel_tests/substr_op_test.py
index cd3fe14883..37aa624b07 100644
--- a/tensorflow/python/kernel_tests/substr_op_test.py
+++ b/tensorflow/python/kernel_tests/substr_op_test.py
@@ -28,270 +28,448 @@ from tensorflow.python.platform import test
class SubstrOpTest(test.TestCase, parameterized.TestCase):
- def _testScalarString(self, dtype):
- test_string = b"Hello"
- position = np.array(1, dtype)
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testScalarString(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"He\xc3\xc3\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_value = {
+ "BYTE": b"ell",
+ "UTF8_CHAR": u"e\xc3\xc3".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position.
- test_string = b"Hello"
- position = np.array(-4, dtype)
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testScalarString_EdgeCases(self, dtype, unit):
+ # Empty string
+ test_string = {
+ "BYTE": b"",
+ "UTF8_CHAR": u"".encode("utf-8"),
+ }[unit]
+ expected_value = b""
+ position = np.array(0, dtype)
length = np.array(3, dtype)
- expected_value = b"ell"
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Position is equal to the length of string.
- test_string = b""
+ # Full string
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
position = np.array(0, dtype)
- length = np.array(2, dtype)
- expected_value = b""
-
- substr_op = string_ops.substr(test_string, position, length)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position magnitude is equal to the length of string.
- test_string = b"yo"
- position = np.array(-2, dtype)
- length = np.array(1, dtype)
- expected_value = b"y"
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Full string (Negative)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-5, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- def _testVectorStrings(self, dtype):
- test_string = [b"Hello", b"World"]
- position = np.array(1, dtype)
- length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ self.assertAllEqual(substr, test_string)
+
+ # Length is larger in magnitude than a negative position
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ expected_string = {
+ "BYTE": b"ello",
+ "UTF8_CHAR": u"\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(-4, dtype)
+ length = np.array(5, dtype)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
- self.assertAllEqual(substr, expected_value)
-
- # Negative position.
- test_string = [b"Hello", b"World"]
- position = np.array(-4, dtype)
+ self.assertAllEqual(substr, expected_string)
+
+ @parameterized.parameters(
+ (np.int32, 1, "BYTE"),
+ (np.int64, 1, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 1, "UTF8_CHAR"),
+ (np.int64, 1, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testVectorStrings(self, dtype, pos, unit):
+ test_string = {
+ "BYTE": [b"Hello", b"World"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"H\xc3llo",
+ u"W\U0001f604rld"]],
+ }[unit]
+ expected_value = {
+ "BYTE": [b"ell", b"orl"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"\xc3ll", u"\U0001f604rl"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- expected_value = [b"ell", b"orl"]
-
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testMatrixStrings(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMatrixStrings(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]]],
+ }[unit]
position = np.array(1, dtype)
length = np.array(4, dtype)
- expected_value = [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
- [b"ixte", b"even", b"ight"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"leve", b"welv"], [b"hirt", b"ourt", b"ifte"],
+ [b"ixte", b"even", b"ight"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u053c\u025bv\u025b",
+ u"w\u0c1dlv"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"\U0001f604rld",
+ u"\xfcd\xea"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- # Negative position
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array(-2, dtype)
+ position = np.array(-3, dtype)
length = np.array(2, dtype)
- expected_value = [[b"en", b"en", b"ve"], [b"en", b"en", b"en"],
- [b"en", b"en", b"en"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"te", b"ve", b"lv"], [b"ee", b"ee", b"ee"],
+ [b"ee", b"ee", b"ee"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227",
+ u"v\u025b", u"lv"]],
+ [x.encode("utf-8") for x in [u"\xc3\xc3", u"rl",
+ u"\xfcd"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testElementWisePosLen(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testElementWisePosLen(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"He\xc3\xc3o",
+ u"W\U0001f604rld",
+ u"d\xfcd\xea"]],
+ [x.encode("utf-8") for x in [u"sixt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, -4, 3], [1, 2, -4], [-5, 2, 3]], dtype)
length = np.array([[2, 2, 4], [4, 3, 2], [5, 5, 5]], dtype)
- expected_value = [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
- [b"xteen", b"vente", b"hteen"]]
-
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"en", b"ev", b"lve"], [b"hirt", b"urt", b"te"],
+ [b"xteen", b"vente", b"hteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227n",
+ u"\u025bv",
+ u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"e\xc3\xc3o",
+ u"rld",
+ u"d\xfc"]],
+ [x.encode("utf-8") for x in [u"xt\xea\xean",
+ u"\U00010299ente",
+ u"h\x86een"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBroadcast(self, dtype, unit):
# Broadcast pos/len onto input string
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"],
- [b"nineteen", b"twenty", b"twentyone"]]
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"],
+ [b"nineteen", b"twenty", b"twentyone"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]],
+ [x.encode("utf-8") for x in [u"nineteen",
+ u"twenty",
+ u"twentyone"]]],
+ }[unit]
position = np.array([1, -4, 3], dtype)
length = np.array([1, 2, 3], dtype)
- expected_value = [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
- [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"e", b"ev", b"lve"], [b"h", b"te", b"tee"],
+ [b"i", b"te", b"hte"], [b"i", b"en", b"nty"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d227",
+ u"\u025bv", u"lv\u025b"]],
+ [x.encode("utf-8") for x in [u"h", u"t\xea", u"tee"]],
+ [x.encode("utf-8") for x in [u"\xcd", u"te", u"h\x86e"]],
+ [x.encode("utf-8") for x in [u"i", u"en", u"nty"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Broadcast input string onto pos/len
- test_string = [b"thirteen", b"fourteen", b"fifteen"]
+ test_string = {
+ "BYTE": [b"thirteen", b"fourteen", b"fifteen"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ }[unit]
position = np.array([[1, -2, 3], [-3, 2, 1], [5, 5, -5]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- expected_value = [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
- [b"ee", b"ee", b"ft"]]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [[b"hir", b"en", b"t"], [b"e", b"ur", b"ift"],
+ [b"ee", b"ee", b"ft"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"h\xcdr", u"\xean", u"t"]],
+ [x.encode("utf-8") for x in [u"\xea", u"ur",
+ u"\xcd\ua09ct"]],
+ [x.encode("utf-8") for x in [u"\xea\xea", u"\xea\xea",
+ u"\ua09ct"]]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
# Test 1D broadcast
- test_string = b"thirteen"
- position = np.array([1, -5, 7], dtype)
+ test_string = {
+ "BYTE": b"thirteen",
+ "UTF8_CHAR": u"th\xcdrt\xea\xean".encode("utf-8"),
+ }[unit]
+ position = np.array([1, -4, 7], dtype)
length = np.array([3, 2, 1], dtype)
- expected_value = [b"hir", b"rt", b"n"]
- substr_op = string_ops.substr(test_string, position, length)
+ expected_value = {
+ "BYTE": [b"hir", b"te", b"n"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"h\xcdr", u"t\xea", u"n"]],
+ }[unit]
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
substr = substr_op.eval()
self.assertAllEqual(substr, expected_value)
- def _testBadBroadcast(self, dtype):
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testBadBroadcast(self, dtype, unit):
test_string = [[b"ten", b"eleven", b"twelve"],
[b"thirteen", b"fourteen", b"fifteen"],
[b"sixteen", b"seventeen", b"eighteen"]]
position = np.array([1, 2, -3, 4], dtype)
length = np.array([1, 2, 3, 4], dtype)
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- def _testOutOfRangeError(self, dtype):
+ string_ops.substr(test_string, position, length, unit=unit)
+
+ @parameterized.parameters(
+ (np.int32, 6, "BYTE"),
+ (np.int64, 6, "BYTE"),
+ (np.int32, -6, "BYTE"),
+ (np.int64, -6, "BYTE"),
+ (np.int32, 6, "UTF8_CHAR"),
+ (np.int64, 6, "UTF8_CHAR"),
+ (np.int32, -6, "UTF8_CHAR"),
+ (np.int64, -6, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Scalar(self, dtype, pos, unit):
# Scalar/Scalar
- test_string = b"Hello"
- position = np.array(7, dtype)
- length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Scalar/Scalar (with negative)
- test_string = b"Hello"
- position = np.array(-7, dtype)
+ test_string = {
+ "BYTE": b"Hello",
+ "UTF8_CHAR": u"H\xc3ll\U0001f604".encode("utf-8"),
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(3, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, 4, "BYTE"),
+ (np.int64, 4, "BYTE"),
+ (np.int32, -4, "BYTE"),
+ (np.int64, -4, "BYTE"),
+ (np.int32, 4, "UTF8_CHAR"),
+ (np.int64, 4, "UTF8_CHAR"),
+ (np.int32, -4, "UTF8_CHAR"),
+ (np.int64, -4, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_VectorScalar(self, dtype, pos, unit):
# Vector/Scalar
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(4, dtype)
- length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
- with self.cached_session():
- with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- # Vector/Scalar (with negative)
- test_string = [b"good", b"good", b"bad", b"good"]
- position = np.array(-4, dtype)
+ test_string = {
+ "BYTE": [b"good", b"good", b"bad", b"good"],
+ "UTF8_CHAR": [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"b\xc3d",
+ u"g\xc3\xc3d"]],
+ }[unit]
+ position = np.array(pos, dtype)
length = np.array(1, dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_MatrixMatrix(self, dtype, unit):
# Matrix/Matrix
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
+ [b"good", b"good", b"good"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]]],
+ }[unit]
position = np.array([[1, 2, 3], [1, 2, 4], [1, 2, 3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Matrix/Matrix (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"],
- [b"good", b"good", b"good"]]
position = np.array([[1, 2, -3], [1, 2, -4], [1, 2, -3]], dtype)
length = np.array([[3, 2, 1], [1, 2, 3], [2, 2, 2]], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testOutOfRangeError_Broadcast(self, dtype, unit):
# Broadcast
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
+ test_string = {
+ "BYTE": [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"g\xc3\xc3d"]],
+ [x.encode("utf-8") for x in [u"g\xc3\xc3d", u"g\xc3\xc3d",
+ u"b\xc3d"]]],
+ }[unit]
position = np.array([1, 2, 4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
+ substr_op.eval()
# Broadcast (with negative)
- test_string = [[b"good", b"good", b"good"], [b"good", b"good", b"bad"]]
position = np.array([-1, -2, -4], dtype)
length = np.array([1, 2, 3], dtype)
- substr_op = string_ops.substr(test_string, position, length)
+ substr_op = string_ops.substr(test_string, position, length, unit=unit)
with self.cached_session():
with self.assertRaises(errors_impl.InvalidArgumentError):
- substr = substr_op.eval()
-
- def _testMismatchPosLenShapes(self, dtype):
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
+ substr_op.eval()
+
+ @parameterized.parameters(
+ (np.int32, "BYTE"),
+ (np.int64, "BYTE"),
+ (np.int32, "UTF8_CHAR"),
+ (np.int64, "UTF8_CHAR"),
+ )
+ def testMismatchPosLenShapes(self, dtype, unit):
+ test_string = {
+ "BYTE": [[b"ten", b"eleven", b"twelve"],
+ [b"thirteen", b"fourteen", b"fifteen"],
+ [b"sixteen", b"seventeen", b"eighteen"]],
+ "UTF8_CHAR": [[x.encode("utf-8") for x in [u"\U0001d229\U0001d227n",
+ u"\xc6\u053c\u025bv\u025bn",
+ u"tw\u0c1dlv\u025b"]],
+ [x.encode("utf-8") for x in [u"th\xcdrt\xea\xean",
+ u"f\U0001f604urt\xea\xean",
+ u"f\xcd\ua09ctee\ua0e4"]],
+ [x.encode("utf-8") for x in [u"s\xcdxt\xea\xean",
+ u"se\U00010299enteen",
+ u"ei\U0001e920h\x86een"]]],
+ }[unit]
position = np.array([[1, 2, 3]], dtype)
length = np.array([2, 3, 4], dtype)
# Should fail: position/length have different rank
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
+ string_ops.substr(test_string, position, length)
position = np.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]], dtype)
length = np.array([[2, 3, 4]], dtype)
# Should fail: position/length have different dimensionality
with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- # Negative position.
- test_string = [[b"ten", b"eleven", b"twelve"],
- [b"thirteen", b"fourteen", b"fifteen"],
- [b"sixteen", b"seventeen", b"eighteen"]]
- position = np.array([[-1, -2, -3]], dtype)
- length = np.array([1, 2, 3], dtype)
- # Should fail: position/length have different rank
- with self.assertRaises(ValueError):
- substr_op = string_ops.substr(test_string, position, length)
-
- @parameterized.parameters(np.int32, np.int64)
- def testAll(self, dtype):
- self._testScalarString(dtype)
- self._testVectorStrings(dtype)
- self._testMatrixStrings(dtype)
- self._testElementWisePosLen(dtype)
- self._testBroadcast(dtype)
- self._testBadBroadcast(dtype)
- self._testOutOfRangeError(dtype)
- self._testMismatchPosLenShapes(dtype)
+ string_ops.substr(test_string, position, length)
def testWrongDtype(self):
with self.cached_session():
@@ -300,6 +478,11 @@ class SubstrOpTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(TypeError):
string_ops.substr(b"test", 3, 1.0)
+ def testInvalidUnit(self):
+ with self.cached_session():
+ with self.assertRaises(ValueError):
+ string_ops.substr(b"test", 3, 1, unit="UTF8")
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/transpose_op_test.py b/tensorflow/python/kernel_tests/transpose_op_test.py
index f42800226e..a825052dd2 100644
--- a/tensorflow/python/kernel_tests/transpose_op_test.py
+++ b/tensorflow/python/kernel_tests/transpose_op_test.py
@@ -39,7 +39,12 @@ class TransposeTest(test.TestCase):
return ret
def _compareCpu(self, x, p, conjugate=False):
- np_ans = self._np_transpose(x, p)
+ if p is None:
+ rank = x.ndim
+ perm = (rank - 1) - np.arange(rank)
+ else:
+ perm = p
+ np_ans = self._np_transpose(x, perm)
if conjugate:
np_ans = np.conj(np_ans)
with self.test_session(use_gpu=False):
@@ -65,7 +70,12 @@ class TransposeTest(test.TestCase):
return tf_ans, jacob_t
def _compareGpu(self, x, p, conjugate=False):
- np_ans = self._np_transpose(x, p)
+ if p is None:
+ rank = x.ndim
+ perm = (rank - 1) - np.arange(rank)
+ else:
+ perm = p
+ np_ans = self._np_transpose(x, perm)
if conjugate:
np_ans = np.conj(np_ans)
with self.test_session(use_gpu=True):
@@ -102,6 +112,11 @@ class TransposeTest(test.TestCase):
self._compareCpu(x, p, conjugate=c)
if use_gpu:
self._compareGpu(x, p, conjugate=c)
+ # Test with an empty permutation
+ for c in cs:
+ self._compareCpu(x, None, conjugate=c)
+ if use_gpu:
+ self._compareGpu(x, None, conjugate=c)
def _compare_cpu_gpu(self, x):
n = np.ndim(x)
@@ -449,6 +464,10 @@ class TransposeTest(test.TestCase):
self.assertEqual(
tensor_shape.TensorShape(None),
array_ops.transpose(array_ops.placeholder(dtypes.int32)).get_shape())
+ self.assertEqual(
+ tensor_shape.TensorShape(None),
+ array_ops.transpose(array_ops.placeholder(dtypes.int32),
+ [0]).get_shape())
def testNullTensor(self):
with self.cached_session():
@@ -456,6 +475,12 @@ class TransposeTest(test.TestCase):
xt = array_ops.transpose(x, [0, 2, 1]).eval()
self.assertAllEqual(xt.shape, (1, 0, 4))
+ def testScalar(self):
+ with self.cached_session():
+ x = constant_op.constant(42, dtype=dtypes.float32, shape=[])
+ xt = array_ops.transpose(x).eval()
+ self.assertAllEqual(xt, x)
+
def _testError(self, x, p, err):
with self.cached_session():
with self.assertRaisesOpError(err):
diff --git a/tensorflow/python/kernel_tests/unicode_script_op_test.py b/tensorflow/python/kernel_tests/unicode_script_op_test.py
new file mode 100644
index 0000000000..927e5459ed
--- /dev/null
+++ b/tensorflow/python/kernel_tests/unicode_script_op_test.py
@@ -0,0 +1,57 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#===============================================================================
+"""Functional tests for UnicodeScript op."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import string_ops
+from tensorflow.python.platform import test
+
+
+class UnicodeScriptOpTest(test.TestCase):
+
+ def testValidScripts(self):
+ inputs = [
+ ord("a"),
+ 0x0411, # CYRILLIC CAPITAL LETTER BE
+ 0x82b8, # CJK UNIFIED IDEOGRAPH-82B8
+ ord(",")
+ ]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(
+ outputs,
+ [
+ 25, # USCRIPT_LATIN (LATN)
+ 8, # USCRIPT_CYRILLIC (CYRL)
+ 17, # USCRIPT_HAN (HANI)
+ 0 # USCRIPT_COMMON (ZYYY)
+ ])
+
+ def testInvalidScript(self):
+ inputs = [-100, 0xffffff]
+ with self.cached_session():
+ input_vector = constant_op.constant(inputs, dtypes.int32)
+ outputs = string_ops.unicode_script(input_vector).eval()
+ self.assertAllEqual(outputs, [-1, -1])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/kernel_tests/variable_scope_test.py b/tensorflow/python/kernel_tests/variable_scope_test.py
index 401e1ae102..33f464fb90 100644
--- a/tensorflow/python/kernel_tests/variable_scope_test.py
+++ b/tensorflow/python/kernel_tests/variable_scope_test.py
@@ -394,10 +394,10 @@ class VariableScopeTest(test.TestCase):
old = variable_scope._DEFAULT_USE_RESOURCE
try:
variable_scope.enable_resource_variables()
- self.assertTrue(isinstance(variables_lib.Variable(1.0),
+ self.assertTrue(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
variable_scope.disable_resource_variables()
- self.assertFalse(isinstance(variables_lib.Variable(1.0),
+ self.assertFalse(isinstance(variables_lib.VariableV1(1.0),
resource_variable_ops.ResourceVariable))
finally:
variable_scope._DEFAULT_USE_RESOURCE = old
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 2e7975667c..c2b86089f4 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -43,14 +43,14 @@ class VariablesTestCase(test.TestCase):
def testInitialization(self):
with self.cached_session():
- var0 = variables.Variable(0.0)
+ var0 = variables.VariableV1(0.0)
self.assertEqual("Variable:0", var0.name)
self.assertEqual("Variable", var0._shared_name)
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.get_shape())
self.assertEqual([], var0.shape)
- var1 = variables.Variable(1.1)
+ var1 = variables.VariableV1(1.1)
self.assertEqual("Variable_1:0", var1.name)
self.assertEqual("Variable_1", var1._shared_name)
self.assertEqual([], var1.get_shape())
@@ -143,7 +143,7 @@ class VariablesTestCase(test.TestCase):
def testZeroSizeStringAssign(self):
with self.cached_session() as sess:
- array = variables.Variable(
+ array = variables.VariableV1(
initial_value=array_ops.zeros((0,), dtype=dtypes.string),
name="foo",
trainable=False,
@@ -192,7 +192,7 @@ class VariablesTestCase(test.TestCase):
# d get the control dep.
d = constant_op.constant(2.0)
# variables do not.
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
self.assertEqual([c.op], d.op.control_inputs)
self.assertEqual([], var_x.initializer.control_inputs)
self.assertEqual([], var_x.value().op.control_inputs)
@@ -280,10 +280,10 @@ class VariablesTestCase(test.TestCase):
def testCollections(self):
with self.cached_session():
- var_x = variables.Variable(2.0)
- var_y = variables.Variable(2.0, trainable=False)
- var_z = variables.Variable(2.0, trainable=True)
- var_t = variables.Variable(
+ var_x = variables.VariableV1(2.0)
+ var_y = variables.VariableV1(2.0, trainable=False)
+ var_z = variables.VariableV1(2.0, trainable=True)
+ var_t = variables.VariableV1(
2.0,
trainable=True,
collections=[
@@ -296,9 +296,9 @@ class VariablesTestCase(test.TestCase):
def testCollectionsWithScope(self):
with self.cached_session():
with ops.name_scope("scope_1"):
- var_x = variables.Variable(2.0)
+ var_x = variables.VariableV1(2.0)
with ops.name_scope("scope_2"):
- var_y = variables.Variable(2.0)
+ var_y = variables.VariableV1(2.0)
self.assertEqual([var_x, var_y], variables.global_variables())
self.assertEqual([var_x], variables.global_variables("scope_1"))
@@ -399,7 +399,7 @@ class VariablesTestCase(test.TestCase):
def testColocation(self):
with ops.device("/job:ps"):
- var = variables.Variable(0, name="v")
+ var = variables.VariableV1(0, name="v")
with ops.device("/job:worker/task:7"):
assign_op = var.assign(1)
self.assertDeviceEqual("/job:ps", assign_op.device)
@@ -522,7 +522,7 @@ class VariablesTestCase(test.TestCase):
self.assertAllClose(np.ones((5, 5), np.float32), var.eval())
def testRepr(self):
- var = variables.Variable(np.zeros((5, 5), np.float32), name="noop")
+ var = variables.VariableV1(np.zeros((5, 5), np.float32), name="noop")
self.assertEqual(
"<tf.Variable 'noop:0' shape=(5, 5) dtype=float32_ref>",
repr(var))
@@ -556,8 +556,8 @@ class IsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2], name="v")
- w = variables.Variable([3, 4], name="w")
+ v = variables.VariableV1([1, 2], name="v")
+ w = variables.VariableV1([3, 4], name="w")
uninited = variables.report_uninitialized_variables()
self.assertAllEqual(np.array([b"v", b"w"]), sess.run(uninited))
sess.run(w.initializer)
@@ -593,8 +593,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariables(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
_ = v, w
inited = variables.assert_variables_initialized()
with self.assertRaisesOpError("Attempting to use uninitialized value"):
@@ -604,8 +604,8 @@ class ObsoleteIsInitializedTest(test.TestCase):
def testVariableList(self):
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([1, 2])
- w = variables.Variable([3, 4])
+ v = variables.VariableV1([1, 2])
+ w = variables.VariableV1([3, 4])
inited = variables.assert_variables_initialized([v])
with self.assertRaisesOpError("Attempting to use uninitialized value"):
inited.op.run()
@@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
+
class VariableContainerTest(test.TestCase):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..e06e9aba4a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -269,6 +269,13 @@ def dropout(inputs,
class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
+ Arguments:
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+
Examples:
```
@@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer):
@tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, name=None, data_format='channels_last'):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
Returns:
Reshaped tensor.
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
# now `y` has shape `(None, None)`
```
"""
- layer = Flatten(name=name)
+ layer = Flatten(name=name, data_format=data_format)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..0343bfa8bd 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():
@@ -474,6 +476,40 @@ class FlattenTest(test.TestCase):
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
+ def testDataFormat5d(self):
+ np_input_channels_last = np.arange(
+ 120, dtype='float32').reshape([1, 5, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 4, 1, 2, 3])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
+ def testDataFormat4d(self):
+ np_input_channels_last = np.arange(
+ 24, dtype='float32').reshape([1, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 3, 1, 2])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index cce71a2bab..9ab683d96a 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -22,10 +22,12 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export("python_io.TFRecordCompressionType")
+@tf_export("io.TFRecordCompressionType", "python_io.TFRecordCompressionType")
+@deprecation.deprecated_endpoints("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
"""The type of compression for the record."""
NONE = 0
@@ -33,7 +35,8 @@ class TFRecordCompressionType(object):
GZIP = 2
-@tf_export("python_io.TFRecordOptions")
+@tf_export("io.TFRecordOptions", "python_io.TFRecordOptions")
+@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
compression_type_map = {
@@ -143,7 +146,8 @@ class TFRecordOptions(object):
return options
-@tf_export("python_io.tf_record_iterator")
+@tf_export("io.tf_record_iterator", "python_io.tf_record_iterator")
+@deprecation.deprecated_endpoints("python_io.tf_record_iterator")
def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
@@ -175,7 +179,8 @@ def tf_record_iterator(path, options=None):
reader.Close()
-@tf_export("python_io.TFRecordWriter")
+@tf_export("io.TFRecordWriter", "python_io.TFRecordWriter")
+@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(object):
"""A class to write records to a TFRecords file.
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index a7f57e94e3..e3e4d5f910 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1204,7 +1204,8 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
return _apply_mask_1d(tensor, mask, axis)
-@tf_export("sparse_mask")
+@tf_export("sparse.mask", "sparse_mask")
+@deprecation.deprecated_endpoints("sparse_mask")
def sparse_mask(a, mask_indices, name=None):
"""Masks elements of `IndexedSlices`.
@@ -1226,7 +1227,7 @@ def sparse_mask(a, mask_indices, name=None):
# `b` will be the subset of `a` slices at its second and third indices, so
# we want to mask its first and last indices (which are at absolute
# indices 12, 45)
- b = tf.sparse_mask(a, [12, 45])
+ b = tf.sparse.mask(a, [12, 45])
b.indices # [26, 37]
tf.shape(b.values) # [2, 10]
@@ -1382,7 +1383,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
[10, 11, 12]]])
# Take the transpose of the matrices in dimension-0
- # (this common operation has a shorthand `matrix_transpose`)
+ # (this common operation has a shorthand `linalg.transpose`)
tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
# [2, 5],
# [3, 6]],
@@ -1406,8 +1407,13 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
gen_array_ops.conjugate_transpose
if (conjugate and a.dtype.is_complex) else gen_array_ops.transpose)
if perm is None:
- rank = gen_array_ops.rank(a)
- perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ a = ops.convert_to_tensor(a, name="a")
+ if not a.get_shape().ndims:
+ rank = gen_array_ops.rank(a)
+ perm = (rank - 1) - gen_math_ops._range(0, rank, 1)
+ else:
+ rank = a.get_shape().ndims
+ perm = (rank - 1) - np.arange(rank)
ret = transpose_fn(a, perm, name=name)
# NOTE(mrry): Setting the shape explicitly because
# reverse is not handled by the shape function.
@@ -1421,7 +1427,8 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
# pylint: disable=invalid-name
-@tf_export("matrix_transpose", "linalg.transpose")
+@tf_export("linalg.transpose", "matrix_transpose")
+@deprecation.deprecated_endpoints("matrix_transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`.
@@ -1429,19 +1436,19 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
```python
x = tf.constant([[1, 2, 3], [4, 5, 6]])
- tf.matrix_transpose(x) # [[1, 4],
+ tf.linalg.transpose(x) # [[1, 4],
# [2, 5],
# [3, 6]]
x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
[4 + 4j, 5 + 5j, 6 + 6j]])
- tf.matrix_transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
+ tf.linalg.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
# [2 - 2j, 5 - 5j],
# [3 - 3j, 6 - 6j]]
# Matrix with two batch dimensions.
# x.shape is [1, 2, 3, 4]
- # tf.matrix_transpose(x) is shape [1, 2, 4, 3]
+ # tf.linalg.transpose(x) is shape [1, 2, 4, 3]
```
Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
@@ -1452,14 +1459,14 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
tf.matmul(matrix, b, transpose_b=True)
# Inefficient!
- tf.matmul(matrix, tf.matrix_transpose(b))
+ tf.matmul(matrix, tf.linalg.transpose(b))
```
@compatibility(numpy)
In `numpy` transposes are memory-efficient constant time operations as they
simply return a new view of the same data with adjusted `strides`.
- TensorFlow does not support strides, `matrix_transposes` return a new tensor
+ TensorFlow does not support strides, `linalg.transposes` return a new tensor
with the items permuted.
@end_compatibility
@@ -1467,7 +1474,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
a: A `Tensor` with `rank >= 2`.
name: A name for the operation (optional).
conjugate: Optional bool. Setting it to `True` is mathematically equivalent
- to tf.conj(tf.matrix_transpose(input)).
+ to tf.conj(tf.linalg.transpose(input)).
Returns:
A transposed batch matrix `Tensor`.
@@ -1756,7 +1763,8 @@ def _normalize_sparse_shape(shape, name):
return (ops.convert_to_tensor(shape, dtype=dtypes.int64, name=name), rank)
-@tf_export("sparse_placeholder")
+@tf_export("sparse.placeholder", "sparse_placeholder")
+@deprecation.deprecated_endpoints("sparse_placeholder")
def sparse_placeholder(dtype, shape=None, name=None):
"""Inserts a placeholder for a sparse tensor that will be always fed.
@@ -1767,8 +1775,8 @@ def sparse_placeholder(dtype, shape=None, name=None):
For example:
```python
- x = tf.sparse_placeholder(tf.float32)
- y = tf.sparse_reduce_sum(x)
+ x = tf.sparse.placeholder(tf.float32)
+ y = tf.sparse.reduce_sum(x)
with tf.Session() as sess:
print(sess.run(y)) # ERROR: will fail because x was not fed.
@@ -2250,7 +2258,8 @@ def required_space_to_batch_paddings(input_shape,
return result_paddings, result_crops
-@tf_export("space_to_batch")
+@tf_export("nn.space_to_batch", "space_to_batch")
+@deprecation.deprecated_endpoints("space_to_batch")
def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin
result = space_to_batch_nd(
input,
@@ -2264,7 +2273,8 @@ def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=r
space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
-@tf_export("space_to_depth")
+@tf_export("nn.space_to_depth", "space_to_depth")
+@deprecation.deprecated_endpoints("space_to_depth")
def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@@ -2272,7 +2282,8 @@ def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint:
space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
-@tf_export("depth_to_space")
+@tf_export("nn.depth_to_space", "depth_to_space")
+@deprecation.deprecated_endpoints("depth_to_space")
def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@@ -2710,16 +2721,22 @@ def batch_gather(params, indices, name=None):
params = ops.convert_to_tensor(params, name="params")
indices_shape = shape(indices)
params_shape = shape(params)
+
ndims = indices.shape.ndims
if ndims is None:
raise ValueError("batch_gather does not allow indices with unknown "
"shape.")
batch_indices = indices
- accum_dim_value = 1
+ indices_dtype = indices.dtype.base_dtype
+ accum_dim_value = ones((), dtype=indices_dtype)
+ # Use correct type for offset index computation
+ casted_params_shape = gen_math_ops.cast(params_shape, indices_dtype)
for dim in range(ndims-1, 0, -1):
- dim_value = params_shape[dim-1]
- accum_dim_value *= params_shape[dim]
- dim_indices = gen_math_ops._range(0, dim_value, 1)
+ dim_value = casted_params_shape[dim-1]
+ accum_dim_value *= casted_params_shape[dim]
+ start = zeros((), dtype=indices_dtype)
+ step = ones((), dtype=indices_dtype)
+ dim_indices = gen_math_ops._range(start, dim_value, step)
dim_indices *= accum_dim_value
dim_shape = stack([1] * (dim - 1) + [dim_value] + [1] * (ndims - dim),
axis=0)
@@ -2747,7 +2764,8 @@ def batch_gather(params, indices, name=None):
@tf_export("quantize_v2")
@deprecation.deprecated(
"2017-10-25",
- "`tf.quantize_v2` is deprecated, please use `tf.quantize` instead.")
+ "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
+ "instead.") # pylint: disable=missing-docstring
def quantize_v2(input, # pylint: disable=redefined-builtin
min_range,
max_range,
@@ -2769,7 +2787,8 @@ quantize_v2.__doc__ = """Please use `tf.quantize` instead."""
# We want to expose tf.quantize instead of tf.quantize_v2; we can deprecate
# tf.quantize_v2 in next version of TensorFlow.
-@tf_export("quantize")
+@tf_export("quantization.quantize", "quantize")
+@deprecation.deprecated_endpoints("quantize")
def quantize(input, # pylint: disable=redefined-builtin
min_range,
max_range,
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
index 9ea1ea9c92..98dde995c9 100644
--- a/tensorflow/python/ops/candidate_sampling_ops.py
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -23,10 +23,12 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_candidate_sampling_ops
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export('nn.uniform_candidate_sampler')
+@tf_export('random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler')
+@deprecation.deprecated_endpoints('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a uniform base distribution.
@@ -82,7 +84,9 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
seed2=seed2, name=name)
-@tf_export('nn.log_uniform_candidate_sampler')
+@tf_export('random.log_uniform_candidate_sampler',
+ 'nn.log_uniform_candidate_sampler')
+@deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a log-uniform (Zipfian) base distribution.
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index c3cf6e61f2..d607f1d9fb 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
NUMERIC_TYPES = frozenset(
@@ -91,7 +92,8 @@ def _shape_and_dtype_str(tensor):
return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
-@tf_export('assert_proper_iterable')
+@tf_export('debugging.assert_proper_iterable', 'assert_proper_iterable')
+@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
"""Static assert that values is a "proper" iterable.
@@ -119,7 +121,8 @@ def assert_proper_iterable(values):
'Expected argument "values" to be iterable. Found: %s' % type(values))
-@tf_export('assert_negative')
+@tf_export('debugging.assert_negative', 'assert_negative')
+@deprecation.deprecated_endpoints('assert_negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < 0` holds element-wise.
@@ -160,7 +163,8 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less(x, zero, data=data, summarize=summarize)
-@tf_export('assert_positive')
+@tf_export('debugging.assert_positive', 'assert_positive')
+@deprecation.deprecated_endpoints('assert_positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > 0` holds element-wise.
@@ -200,7 +204,8 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less(zero, x, data=data, summarize=summarize)
-@tf_export('assert_non_negative')
+@tf_export('debugging.assert_non_negative', 'assert_non_negative')
+@deprecation.deprecated_endpoints('assert_non_negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x >= 0` holds element-wise.
@@ -242,7 +247,8 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(zero, x, data=data, summarize=summarize)
-@tf_export('assert_non_positive')
+@tf_export('debugging.assert_non_positive', 'assert_non_positive')
+@deprecation.deprecated_endpoints('assert_non_positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= 0` holds element-wise.
@@ -284,7 +290,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(x, zero, data=data, summarize=summarize)
-@tf_export('assert_equal')
+@tf_export('debugging.assert_equal', 'assert_equal')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x == y` holds element-wise.
@@ -384,7 +390,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_none_equal')
+@tf_export('debugging.assert_none_equal', 'assert_none_equal')
+@deprecation.deprecated_endpoints('assert_none_equal')
def assert_none_equal(
x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements.
@@ -435,7 +442,8 @@ def assert_none_equal(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_near')
+@tf_export('debugging.assert_near', 'assert_near')
+@deprecation.deprecated_endpoints('assert_near')
def assert_near(
x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
name=None):
@@ -513,7 +521,7 @@ def assert_near(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_less')
+@tf_export('debugging.assert_less', 'assert_less')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < y` holds element-wise.
@@ -561,7 +569,8 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_less_equal')
+@tf_export('debugging.assert_less_equal', 'assert_less_equal')
+@deprecation.deprecated_endpoints('assert_less_equal')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= y` holds element-wise.
@@ -609,7 +618,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_greater')
+@tf_export('debugging.assert_greater', 'assert_greater')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > y` holds element-wise.
@@ -657,7 +666,8 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_greater_equal')
+@tf_export('debugging.assert_greater_equal', 'assert_greater_equal')
+@deprecation.deprecated_endpoints('assert_greater_equal')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
name=None):
"""Assert the condition `x >= y` holds element-wise.
@@ -755,7 +765,7 @@ def _assert_rank_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_rank')
+@tf_export('debugging.assert_rank', 'assert_rank')
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank`.
@@ -817,7 +827,8 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
return assert_op
-@tf_export('assert_rank_at_least')
+@tf_export('debugging.assert_rank_at_least', 'assert_rank_at_least')
+@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank` or higher.
@@ -948,7 +959,8 @@ def _assert_ranks_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_rank_in')
+@tf_export('debugging.assert_rank_in', 'assert_rank_in')
+@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
x, ranks, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank in `ranks`.
@@ -1010,7 +1022,8 @@ def assert_rank_in(
return assert_op
-@tf_export('assert_integer')
+@tf_export('debugging.assert_integer', 'assert_integer')
+@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype.
@@ -1048,7 +1061,8 @@ def assert_integer(x, message=None, name=None):
return control_flow_ops.no_op('statically_determined_was_integer')
-@tf_export('assert_type')
+@tf_export('debugging.assert_type', 'assert_type')
+@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
"""Statically asserts that the given `Tensor` is of the specified type.
@@ -1095,12 +1109,14 @@ def _get_diff_for_monotonic_comparison(x):
return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
-@tf_export('is_numeric_tensor')
+@tf_export('debugging.is_numeric_tensor', 'is_numeric_tensor')
+@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
-@tf_export('is_non_decreasing')
+@tf_export('debugging.is_non_decreasing', 'is_non_decreasing')
+@deprecation.deprecated_endpoints('is_non_decreasing')
def is_non_decreasing(x, name=None):
"""Returns `True` if `x` is non-decreasing.
@@ -1127,7 +1143,8 @@ def is_non_decreasing(x, name=None):
return math_ops.reduce_all(math_ops.less_equal(zero, diff))
-@tf_export('is_strictly_increasing')
+@tf_export('debugging.is_strictly_increasing', 'is_strictly_increasing')
+@deprecation.deprecated_endpoints('is_strictly_increasing')
def is_strictly_increasing(x, name=None):
"""Returns `True` if `x` is strictly increasing.
@@ -1202,7 +1219,8 @@ def _assert_same_base_type(items, expected_type=None):
return expected_type
-@tf_export('assert_same_float_dtype')
+@tf_export('debugging.assert_same_float_dtype', 'assert_same_float_dtype')
+@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
"""Validate and return float type based on `tensors` and `dtype`.
@@ -1231,7 +1249,8 @@ def assert_same_float_dtype(tensors=None, dtype=None):
return dtype
-@tf_export('assert_scalar')
+@tf_export('debugging.assert_scalar', 'assert_scalar')
+@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None):
with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
tensor = ops.convert_to_tensor(tensor, name=name_scope)
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 29468431b3..45516068f4 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -76,8 +77,8 @@ def clip_by_value(t, clip_value_min, clip_value_max,
return t_max
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
- # return gen_math_ops.clip_by_value(
- # t, clip_value_min, clip_value_max, name=name)
+ # return gen_math_ops.clip_by_value(
+ # t, clip_value_min, clip_value_max, name=name)
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
@@ -159,7 +160,8 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
return tclip
-@tf_export("global_norm")
+@tf_export("linalg.global_norm", "global_norm")
+@deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None):
"""Computes the global norm of multiple tensors.
diff --git a/tensorflow/python/ops/cond_v2_impl.py b/tensorflow/python/ops/cond_v2_impl.py
index f8b1ddb140..c9aa4d4889 100644
--- a/tensorflow/python/ops/cond_v2_impl.py
+++ b/tensorflow/python/ops/cond_v2_impl.py
@@ -96,9 +96,12 @@ def cond_v2(pred, true_fn, false_fn, name="cond"):
# Create the If op.
tensors = gen_functional_ops._if( # pylint: disable=protected-access
- pred, cond_inputs, [t.dtype for t in true_graph.outputs],
+ pred,
+ cond_inputs, [t.dtype for t in true_graph.outputs],
_create_new_tf_function(true_graph),
_create_new_tf_function(false_graph),
+ output_shapes=_get_output_shapes(true_graph.outputs,
+ false_graph.outputs),
name=scope)
# Set the flag to enable lowering on the `if` op if necessary
@@ -175,9 +178,12 @@ def _IfGrad(op, *grads): # pylint: disable=invalid-name
# Create the gradient If op.
tensors = gen_functional_ops._if(
- op.inputs[0], grad_inputs, [t.dtype for t in true_grad_graph.outputs],
+ op.inputs[0],
+ grad_inputs, [t.dtype for t in true_grad_graph.outputs],
_create_new_tf_function(true_grad_graph),
- _create_new_tf_function(false_grad_graph))
+ _create_new_tf_function(false_grad_graph),
+ output_shapes=_get_output_shapes(true_grad_graph.outputs,
+ false_grad_graph.outputs))
# The predicate has no gradient.
return [None] + tensors[:num_grad_outputs]
@@ -276,9 +282,10 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
as is.
2. Tensors in the forward pass graph. These tensors may not be "live"
when the gradient is being computed. We replace such references by their
- corresponding tensor in the least common ancestor graph of `grad_graph` and
- `cond_graph`. Since we export intermediate tensors for all branch
- functions, this is always possible.
+ corresponding tensor in `cond_graph.outer_graph`. In the case of nested
+ control flow or functions, the gradient logic handling
+ `grad_graph.outer_graph` will make sure the tensor from
+ `cond_graph.outer_graph` is also correctly captured.
Args:
cond_graph: function.FuncGraph. The forward-pass function.
@@ -290,24 +297,23 @@ def _resolve_grad_inputs(cond_graph, grad_graph):
new_inputs = []
for t in grad_graph.external_captures:
+ # `t` must either be in `grad_graph.outer_graph` or in the forward
+ # `cond_graph`.
if t.graph != grad_graph.outer_graph:
- # `t` is a tensor in `cond_graph` or one of its ancestors. We bubble this
- # tensor to the least common ancestor of the `cond_graph` and
- # `grad_graph` so that it is "in-scope" for `grad_graph`.
- # TODO(srbs): `_is_ancestor` calls may be expensive. Compute the least
- # common ancestor once and re-use.
- assert _is_ancestor(cond_graph, t.graph)
- while not _is_ancestor(grad_graph, t.graph):
- assert isinstance(t.graph, _function.FuncGraph)
- if t in t.graph.internal_captures:
- # TODO(srbs): Consider building a map of internal_captures ->
- # external_captures instead of searching for `t` twice.
- t = t.graph.external_captures[t.graph.internal_captures.index(t)]
- else:
- # Note: All intermediate tensors are output by the If op.
- # TODO(srbs): .index() calls may be expensive. Optimize.
- t = t.graph._if.outputs[t.graph.outputs.index(t)]
- assert _is_ancestor(grad_graph, t.graph)
+ assert t.graph == cond_graph
+ # `internal_captures` are not treated as intermediates and hence not added
+ # to If op outputs. So we get the outer tensor corresponding to those
+ # from the list of `external_captures`.
+ try:
+ t = t.graph._if.outputs[t.graph.outputs.index(t)]
+ except ValueError:
+ index = t.graph.internal_captures.index(t)
+ t = t.graph.external_captures[index]
+
+ # Note: We rely on the capturing logic of the gradient If op graph to
+ # correctly capture the tensors in `cond_graph.outer_graph`. Both cond_v2
+ # and while_v2 handle this while building their gradient functions.
+ assert t.graph == cond_graph.outer_graph
new_inputs.append(t)
return new_inputs
@@ -480,9 +486,9 @@ def _check_same_outputs(true_graph, false_graph):
" false_fn: %s" % (true_output_types, false_output_types))
-def _is_ancestor(graph, maybe_ancestor):
- if maybe_ancestor == graph:
- return True
- if isinstance(graph, _function.FuncGraph):
- return _is_ancestor(graph.outer_graph, maybe_ancestor)
- return False
+def _get_output_shapes(true_graph_outputs, false_graph_outputs):
+ output_shapes = [
+ t_out.shape.most_specific_compatible_shape(f_out.shape)
+ for t_out, f_out in zip(true_graph_outputs, false_graph_outputs)
+ ]
+ return output_shapes
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index c09154129f..8259142456 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -89,7 +90,8 @@ def remove_squeezable_dimensions(
return labels, predictions
-@tf_export('confusion_matrix')
+@tf_export('train.confusion_matrix', 'confusion_matrix')
+@deprecation.deprecated_endpoints('confusion_matrix')
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
name=None, weights=None):
"""Computes the confusion matrix from predictions and labels.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..5bc217d355 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -97,7 +106,7 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
-@tf_export("Assert")
+@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
@@ -1324,6 +1333,9 @@ class ControlFlowState(object):
"""
if util.IsLoopSwitch(op):
return None
+ if op.graph._building_function: # pylint: disable=protected-access
+ # The optimization here is tricky to apply to functions
+ return array_ops.zeros_like(op.outputs[index])
dead_branch = util.IsSwitch(op)
forward_ctxt = _GetWhileContext(op)
grad_state = self._map.get(forward_ctxt)
@@ -3211,6 +3223,14 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(
+ cond, body, loop_vars, shape_invariants=shape_invariants, name=name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/control_flow_ops_benchmark.py b/tensorflow/python/ops/control_flow_ops_benchmark.py
new file mode 100644
index 0000000000..9ba5ff2c0f
--- /dev/null
+++ b/tensorflow/python/ops/control_flow_ops_benchmark.py
@@ -0,0 +1,122 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Benchmark for control flow ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import time
+
+from tensorflow.python.client import session
+from tensorflow.python.eager import context
+from tensorflow.python.eager import function
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.platform import test
+
+
+class CondWithManyIntermediatesBenchmark(test.Benchmark):
+ """Checks the runtime performance of outputting all intermediates."""
+
+ NUM_INTERMEDIATES = 1000
+ NUM_ITERS = 500
+ NUM_WARM_UP_ITERS = 50
+
+ def _create_cond(self, x):
+
+ def branch_fn():
+ # Use a random value so the adds can't be constant folded.
+ return x + sum(random_ops.random_normal([])
+ for _ in range(self.NUM_INTERMEDIATES))
+
+ # Use a dynamic predicate to make sure the cond isn't constant folded.
+ return control_flow_ops.cond(math_ops.not_equal(x, -1),
+ branch_fn, lambda: 0.0)
+
+ def _benchmark_defun(self):
+ """Benchmarks cond in a defun."""
+
+ @function.defun
+ def cond_fn(x):
+ return self._create_cond(x)
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def _benchmark_graph(self):
+ """Benchmarks cond in legacy graph mode."""
+ with context.graph_mode():
+ with ops.Graph().as_default():
+ x = array_ops.placeholder(dtypes.float32)
+ cond_val = self._create_cond(x)
+
+ with session.Session() as sess:
+ cond_fn = sess.make_callable(cond_val, [x])
+
+ # Warm up
+ for _ in range(self.NUM_WARM_UP_ITERS):
+ cond_fn(0.0)
+
+ start_time = time.time()
+
+ for _ in range(self.NUM_ITERS):
+ cond_fn(0.0)
+
+ self.report_benchmark(
+ wall_time=time.time() - start_time,
+ iters=self.NUM_ITERS)
+
+ def benchmark_cond_v1_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_defun(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_defun()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v1_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = False
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+ def benchmark_cond_v2_graph(self):
+ old_val = control_flow_ops.ENABLE_COND_V2
+ control_flow_ops.ENABLE_COND_V2 = True
+ self._benchmark_graph()
+ control_flow_ops.ENABLE_COND_V2 = old_val
+
+if __name__ == "__main__":
+ ops.enable_eager_execution()
+ test.main()
diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py
index 28111c2730..f40488afbe 100644
--- a/tensorflow/python/ops/conv2d_benchmark.py
+++ b/tensorflow/python/ops/conv2d_benchmark.py
@@ -63,9 +63,9 @@ def build_graph(device, dtype, data_format, input_shape, filter_shape, strides,
An array of tensors to run()
"""
with ops.device("/%s:0" % device):
- inp = variables.Variable(
+ inp = variables.VariableV1(
random_ops.truncated_normal(input_shape, dtype=dtype))
- filt = variables.Variable(
+ filt = variables.VariableV1(
random_ops.truncated_normal(filter_shape, dtype=dtype))
outputs = []
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 69c0fcbbee..97b6f3bd9c 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -112,7 +113,8 @@ def _shape_common(s1, s2):
# pylint: disable=protected-access
-@tf_export("QueueBase")
+@tf_export("io.QueueBase", "QueueBase")
+@deprecation.deprecated_endpoints("QueueBase")
class QueueBase(object):
"""Base class for queue implementations.
@@ -604,7 +606,8 @@ def _shared_name(shared_name):
return shared_name
-@tf_export("RandomShuffleQueue")
+@tf_export("io.RandomShuffleQueue", "RandomShuffleQueue")
+@deprecation.deprecated_endpoints("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
"""A queue implementation that dequeues elements in a random order.
@@ -746,7 +749,8 @@ class FIFOQueue(QueueBase):
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
-@tf_export("PaddingFIFOQueue")
+@tf_export("io.PaddingFIFOQueue", "PaddingFIFOQueue")
+@deprecation.deprecated_endpoints("PaddingFIFOQueue")
class PaddingFIFOQueue(QueueBase):
"""A FIFOQueue that supports batching variable-sized tensors by padding.
@@ -820,7 +824,8 @@ class PaddingFIFOQueue(QueueBase):
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
-@tf_export("PriorityQueue")
+@tf_export("io.PriorityQueue", "PriorityQueue")
+@deprecation.deprecated_endpoints("PriorityQueue")
class PriorityQueue(QueueBase):
"""A queue implementation that dequeues elements in prioritized order.
@@ -1300,7 +1305,9 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
return out
-@tf_export("SparseConditionalAccumulator")
+@tf_export("sparse.SparseConditionalAccumulator",
+ "SparseConditionalAccumulator")
+@deprecation.deprecated_endpoints("SparseConditionalAccumulator")
class SparseConditionalAccumulator(ConditionalAccumulatorBase):
"""A conditional accumulator for aggregating sparse gradients.
diff --git a/tensorflow/python/ops/distributions/BUILD b/tensorflow/python/ops/distributions/BUILD
index e7ad028376..59ba9aee59 100644
--- a/tensorflow/python/ops/distributions/BUILD
+++ b/tensorflow/python/ops/distributions/BUILD
@@ -12,6 +12,13 @@ py_library(
["*.py"],
exclude = ["util.py"],
),
+ deprecation = ("TensorFlow Distributions has migrated to " +
+ "TensorFlow Probability " +
+ "(https://github.com/tensorflow/probability). " +
+ "Deprecated copies remaining in tf.distributions " +
+ "will not receive new features, and will be removed by " +
+ "early 2019. You should update all usage of " +
+ "`tf.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
":util",
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index 84d9d40a35..baecc321d3 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -39,6 +40,14 @@ class Bernoulli(distribution.Distribution):
`1` outcome (vs a `0` outcome).
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
logits=None,
probs=None,
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 2ba1ea6744..51c4f6eb3d 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -150,6 +151,14 @@ class Beta(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration1=None,
concentration0=None,
@@ -267,8 +276,8 @@ class Beta(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return ((self.concentration1 - 1.) * math_ops.log(x)
- + (self.concentration0 - 1.) * math_ops.log1p(-x))
+ return (math_ops.xlogy(self.concentration1 - 1., x) +
+ (self.concentration0 - 1.) * math_ops.log1p(-x))
def _log_normalization(self):
return (math_ops.lgamma(self.concentration1)
@@ -341,6 +350,11 @@ class Beta(distribution.Distribution):
class BetaWithSoftplusConcentration(Beta):
"""Beta with softplus transform of `concentration1` and `concentration0`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Beta(tf.nn.softplus(concentration1), "
+ "tf.nn.softplus(concentration2))` instead.",
+ warn_once=True)
def __init__(self,
concentration1,
concentration0,
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index fbbacf2521..26a3da2fb6 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -149,6 +150,14 @@ class Categorical(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(
self,
logits=None,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 415249a958..675c30b383 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -156,6 +157,14 @@ class Dirichlet(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
validate_args=False,
@@ -236,7 +245,7 @@ class Dirichlet(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)
+ return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)
def _log_normalization(self):
return special_math_ops.lbeta(self.concentration)
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index 5350c82847..2e3151a5ab 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -163,6 +164,14 @@ class DirichletMultinomial(distribution.Distribution):
# TODO(b/27419586) Change docstring for dtype of concentration once int
# allowed.
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
concentration,
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 76d980679e..4741370cd8 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -25,6 +25,7 @@ import types
import numpy as np
import six
+from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
@@ -33,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -127,6 +129,18 @@ def _update_docstring(old_str, append_str):
return old_str + "\n\n" + append_str
+def _convert_to_tensor(value, name=None, preferred_dtype=None):
+ """Converts to tensor avoiding an eager bug that loses float precision."""
+ # TODO(b/116672045): Remove this function.
+ if (context.executing_eagerly() and preferred_dtype is not None and
+ (preferred_dtype.is_integer or preferred_dtype.is_bool)):
+ v = ops.convert_to_tensor(value, name=name)
+ if v.dtype.is_floating:
+ return v
+ return ops.convert_to_tensor(
+ value, name=name, preferred_dtype=preferred_dtype)
+
+
class _DistributionMeta(abc.ABCMeta):
def __new__(mcs, classname, baseclasses, attrs):
@@ -216,6 +230,14 @@ class ReparameterizationType(object):
gradients / surrogate loss instead.
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, rep_type):
self._rep_type = rep_type
@@ -392,6 +414,14 @@ class Distribution(_BaseDistribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
dtype,
reparameterization_type,
@@ -741,7 +771,8 @@ class Distribution(_BaseDistribution):
def _call_log_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -769,7 +800,8 @@ class Distribution(_BaseDistribution):
def _call_prob(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._prob(value, **kwargs)
except NotImplementedError as original_exception:
@@ -797,7 +829,8 @@ class Distribution(_BaseDistribution):
def _call_log_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -835,7 +868,8 @@ class Distribution(_BaseDistribution):
def _call_cdf(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._cdf(value, **kwargs)
except NotImplementedError as original_exception:
@@ -870,7 +904,8 @@ class Distribution(_BaseDistribution):
def _call_log_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._log_survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -909,7 +944,8 @@ class Distribution(_BaseDistribution):
def _call_survival_function(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
try:
return self._survival_function(value, **kwargs)
except NotImplementedError as original_exception:
@@ -963,7 +999,8 @@ class Distribution(_BaseDistribution):
def _call_quantile(self, value, name, **kwargs):
with self._name_scope(name, values=[value]):
- value = ops.convert_to_tensor(value, name="value")
+ value = _convert_to_tensor(
+ value, name="value", preferred_dtype=self.dtype)
return self._quantile(value, **kwargs)
def quantile(self, value, name="quantile"):
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index 4325a14449..6a52af8c33 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -70,6 +71,14 @@ class Exponential(gamma.Gamma):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
rate,
validate_args=False,
@@ -114,6 +123,9 @@ class Exponential(gamma.Gamma):
def rate(self):
return self._rate
+ def _log_survival_function(self, value):
+ return self._log_prob(value) - math_ops.log(self._rate)
+
def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
# Uniform variates must be sampled from the open-interval `(0, 1)` rather
@@ -135,6 +147,10 @@ class Exponential(gamma.Gamma):
class ExponentialWithSoftplusRate(Exponential):
"""Exponential with softplus transform on `rate`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Exponential(tf.nn.softplus(rate)).",
+ warn_once=True)
def __init__(self,
rate,
validate_args=False,
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 3293cda874..4a2db208d4 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -121,6 +122,14 @@ class Gamma(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
@@ -225,7 +234,7 @@ class Gamma(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return (self.concentration - 1.) * math_ops.log(x) - self.rate * x
+ return math_ops.xlogy(self.concentration - 1., x) - self.rate * x
def _log_normalization(self):
return (math_ops.lgamma(self.concentration)
@@ -279,6 +288,11 @@ class Gamma(distribution.Distribution):
class GammaWithSoftplusConcentrationRate(Gamma):
"""`Gamma` with softplus of `concentration` and `rate`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Gamma(tf.nn.softplus(concentration), "
+ "tf.nn.softplus(rate))` instead.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
diff --git a/tensorflow/python/ops/distributions/identity_bijector.py b/tensorflow/python/ops/distributions/identity_bijector.py
index 8628e68f96..eded96f5bc 100644
--- a/tensorflow/python/ops/distributions/identity_bijector.py
+++ b/tensorflow/python/ops/distributions/identity_bijector.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -43,6 +44,14 @@ class Identity(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="identity"):
super(Identity, self).__init__(
forward_min_event_ndims=0,
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index fdeb97bf64..12743fa23d 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -51,6 +52,14 @@ def _registered_kl(type_a, type_b):
return kl_fn
+@deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
@tf_export("distributions.kl_divergence")
def kl_divergence(distribution_a, distribution_b,
allow_nan_stats=True, name=None):
@@ -112,6 +121,14 @@ def kl_divergence(distribution_a, distribution_b,
return array_ops.identity(kl_t, name="checked_kl")
+@deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def cross_entropy(ref, other,
allow_nan_stats=True, name=None):
"""Computes the (Shannon) cross entropy.
@@ -155,6 +172,14 @@ class RegisterKL(object):
# Return KL(norm_a || norm_b)
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, dist_cls_a, dist_cls_b):
"""Initialize the KL registrar.
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index be17cf2527..4f6a8f587d 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -71,6 +72,14 @@ class Laplace(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
@@ -211,6 +220,11 @@ class Laplace(distribution.Distribution):
class LaplaceWithSoftplusScale(Laplace):
"""Laplace with softplus applied to `scale`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Laplace(loc, tf.nn.softplus(scale)) "
+ "instead.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index d0943e8eee..8397353cd5 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -148,6 +149,14 @@ class Multinomial(distribution.Distribution):
```
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
logits=None,
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 2feaf806c0..9f511709b9 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -106,6 +107,14 @@ class Normal(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
@@ -240,6 +249,11 @@ class Normal(distribution.Distribution):
class NormalWithSoftplusScale(Normal):
"""Normal with softplus applied to `scale`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Normal(loc, tf.nn.softplus(scale)) "
+ "instead.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py
index 31b7a36fd3..ccc667cae3 100644
--- a/tensorflow/python/ops/distributions/special_math.py
+++ b/tensorflow/python/ops/distributions/special_math.py
@@ -12,6 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+
+# Functions "ndtr" and "ndtri" are derived from calculations made in:
+# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
+# In the following email exchange, the author gives his consent to redistribute
+# derived works under an Apache 2.0 license.
+#
+# From: Stephen Moshier <steve@moshier.net>
+# Date: Sat, Jun 9, 2018 at 2:36 PM
+# Subject: Re: Licensing cephes under Apache (BSD-like) license.
+# To: rif <rif@google.com>
+#
+#
+#
+# Hello Rif,
+#
+# Yes, Google may distribute Cephes files under the Apache 2 license.
+#
+# If clarification is needed, I do not favor BSD over other free licenses.
+# I would agree that Apache 2 seems to cover the concern you mentioned
+# about sublicensees.
+#
+# Best wishes for good luck with your projects!
+# Steve Moshier
+#
+#
+#
+# On Thu, 31 May 2018, rif wrote:
+#
+# > Hello Steve.
+# > My name is Rif. I work on machine learning software at Google.
+# >
+# > Your cephes software continues to be incredibly useful and widely used. I
+# > was wondering whether it would be permissible for us to use the Cephes code
+# > under the Apache 2.0 license, which is extremely similar in permissions to
+# > the BSD license (Wikipedia comparisons). This would be quite helpful to us
+# > in terms of avoiding multiple licenses on software.
+# >
+# > I'm sorry to bother you with this (I can imagine you're sick of hearing
+# > about this by now), but I want to be absolutely clear we're on the level and
+# > not misusing your important software. In former conversation with Eugene
+# > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD,
+# > the formal way that has been handled is simply to add a statement to the
+# > effect that you are incorporating the Cephes software by permission of the
+# > author." I wanted to confirm that (a) we could use the Apache license, (b)
+# > that we don't need to (and probably you don't want to) keep getting
+# > contacted about individual uses, because your intent is generally to allow
+# > this software to be reused under "BSD-like" license, and (c) you're OK
+# > letting incorporators decide whether a license is sufficiently BSD-like?
+# >
+# > Best,
+# >
+# > rif
+# >
+# >
+# >
+
"""Special Math Ops."""
from __future__ import absolute_import
@@ -135,7 +191,7 @@ def _ndtri(p):
# Constants used in piece-wise rational approximations. Taken from the cephes
# library:
- # https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
+ # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
p0 = list(reversed([-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
@@ -305,7 +361,8 @@ def log_ndtr(x, series_order=3, name="log_ndtr"):
else:
raise TypeError("x.dtype=%s is not supported." % x.dtype)
- # The basic idea here was ported from py/scipy/special/cephes/ndtr.c.
+ # The basic idea here was ported from:
+ # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
# We copy the main idea, with a few changes
# * For x >> 1, and X ~ Normal(0, 1),
# Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e8d214bbe0..b69e61925c 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -140,6 +141,14 @@ class StudentT(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
df,
loc,
@@ -361,6 +370,11 @@ class StudentT(distribution.Distribution):
class StudentTWithAbsDfSoftplusScale(StudentT):
"""StudentT with `df = floor(abs(df))` and `scale = softplus(scale)`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.StudentT(tf.floor(tf.abs(df)), loc, "
+ "tf.nn.softplus(scale)) instead.",
+ warn_once=True)
def __init__(self,
df,
loc,
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index e80bf9ee42..1becfc1877 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
"TransformedDistribution",
@@ -227,6 +228,14 @@ class TransformedDistribution(distribution_lib.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
distribution,
bijector=None,
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index e66c4a37e7..b6b24187cc 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -76,6 +77,14 @@ class Uniform(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
low=0.,
high=1.,
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 60d73a1693..6263041b8d 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -550,11 +550,9 @@ def safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
- if not isinstance(embedding_weights[0],
- resource_variable_ops.ResourceVariable):
- embedding_weights = [
- ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
- ]
+ embedding_weights = [
+ ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
+ ]
with ops.name_scope(name, 'embedding_lookup',
embedding_weights + [sparse_ids,
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 1dc666e78b..794465b10e 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -25,4 +25,5 @@ from tensorflow.python.ops.custom_gradient import custom_gradient
from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
+from tensorflow.python.ops.gradients_impl import UnconnectedGradients
# pylint: enable=unused-import
diff --git a/tensorflow/python/ops/gradients_impl.py b/tensorflow/python/ops/gradients_impl.py
index 056015d6b6..aac95037dc 100644
--- a/tensorflow/python/ops/gradients_impl.py
+++ b/tensorflow/python/ops/gradients_impl.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections
import contextlib
+import enum # pylint: disable=g-bad-import-order
import sys
import warnings
@@ -537,6 +538,26 @@ def _Consumers(t, func_graphs):
return consumers
+@tf_export("UnconnectedGradients")
+class UnconnectedGradients(enum.Enum):
+ """Controls how gradient computation behaves when y does not depend on x.
+
+ The gradient of y with respect to x can be zero in two different ways: there
+ could be no differentiable path in the graph connecting x to y (and so we can
+ statically prove that the gradient is zero) or it could be that runtime values
+ of tensors in a particular execution lead to a gradient of zero (say, if a
+ relu unit happens to not be activated). To allow you to distinguish between
+ these two cases you can choose what value gets returned for the gradient when
+ there is no path in the graph from x to y:
+
+ * `NONE`: Indicates that [None] will be returned if there is no path from x
+ to y
+ * `ZERO`: Indicates that a zero tensor will be returned in the shape of x.
+ """
+ NONE = "none"
+ ZERO = "zero"
+
+
@tf_export("gradients")
def gradients(ys,
xs,
@@ -545,7 +566,8 @@ def gradients(ys,
colocate_gradients_with_ops=False,
gate_gradients=False,
aggregation_method=None,
- stop_gradients=None):
+ stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE):
"""Constructs symbolic derivatives of sum of `ys` w.r.t. x in `xs`.
`ys` and `xs` are each a `Tensor` or a list of tensors. `grad_ys`
@@ -596,6 +618,23 @@ def gradients(ys,
All integer tensors are considered constant with respect to all `xs`, as if
they were included in `stop_gradients`.
+ `unconnected_gradients` determines the value returned for each x in xs if it
+ is unconnected in the graph to ys. By default this is None to safeguard
+ against errors. MAthematically these gradients are zero which can be requested
+ using the `'zero'` option. `tf.UnconnectedGradients` provides the
+ following options and behaviors:
+
+ ```python
+ a = tf.ones([1, 2])
+ b = tf.ones([3, 1])
+ g1 = tf.gradients([b], [a], unnconnected_gradients='none')
+ sess.run(g1) # [None]
+
+ g2 = tf.gradients([b], [a], unconnected_gradients='zero')
+ sess.run(g2) # [array([[0., 0.]], dtype=float32)]
+ ```
+
+
Args:
ys: A `Tensor` or list of tensors to be differentiated.
xs: A `Tensor` or list of tensors to be used for differentiation.
@@ -611,6 +650,10 @@ def gradients(ys,
Accepted values are constants defined in the class `AggregationMethod`.
stop_gradients: Optional. A `Tensor` or list of tensors not to differentiate
through.
+ unconnected_gradients: Optional. Specifies the gradient value returned when
+ the given input tensors are unconnected. Accepted values are constants
+ defined in the class `tf.UnconnectedGradients` and the default value is
+ `none`.
Returns:
A list of `sum(dy/dx)` for each x in `xs`.
@@ -627,7 +670,8 @@ def gradients(ys,
# mutating new ops.
with ops.get_default_graph()._mutation_lock(): # pylint: disable=protected-access
return _GradientsHelper(ys, xs, grad_ys, name, colocate_gradients_with_ops,
- gate_gradients, aggregation_method, stop_gradients)
+ gate_gradients, aggregation_method, stop_gradients,
+ unconnected_gradients)
def _GradientsHelper(ys,
@@ -638,6 +682,7 @@ def _GradientsHelper(ys,
gate_gradients=False,
aggregation_method=None,
stop_gradients=None,
+ unconnected_gradients=UnconnectedGradients.NONE,
src_graph=None):
"""Implementation of gradients()."""
if context.executing_eagerly():
@@ -645,6 +690,11 @@ def _GradientsHelper(ys,
"is enabled. Use tf.GradientTape instead.")
if src_graph is None:
src_graph = ops.get_default_graph()
+ try:
+ unconnected_gradients = UnconnectedGradients(unconnected_gradients)
+ except ValueError:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
# If src_graph is a _FuncGraph (i.e. a function body), gather it and all
# ancestor graphs. This is necessary for correctly handling captured values.
@@ -856,7 +906,7 @@ def _GradientsHelper(ys,
if loop_state:
loop_state.PostProcessing()
- return [_GetGrad(grads, x) for x in xs]
+ return [_GetGrad(grads, x, unconnected_gradients) for x in xs]
def _HasAnyNotNoneGrads(grads, op):
@@ -924,12 +974,19 @@ def _SetGrad(grads, t, grad):
op_grads[t.value_index] = grad
-def _GetGrad(grads, t):
+def _GetGrad(grads, t, unconnected_gradients):
"""Gets gradient for tensor "t"."""
op = t.op
op_grads = grads.get(op)
if not op_grads:
- return None
+ if unconnected_gradients == UnconnectedGradients.ZERO:
+ return array_ops.zeros_like(t)
+ elif unconnected_gradients == UnconnectedGradients.NONE:
+ return None
+ else:
+ raise ValueError(
+ "Unknown value for unconnected_gradients: %r" % unconnected_gradients)
+
t_grad = op_grads[t.value_index]
assert not isinstance(
t_grad, list), ("gradients list should have been aggregated by now.")
diff --git a/tensorflow/python/ops/gradients_test.py b/tensorflow/python/ops/gradients_test.py
index 4f6e5dc473..c93e2493ee 100644
--- a/tensorflow/python/ops/gradients_test.py
+++ b/tensorflow/python/ops/gradients_test.py
@@ -273,7 +273,7 @@ class GradientsTest(test_util.TensorFlowTestCase):
def testVariableRefGradient(self):
with ops.Graph().as_default():
init = constant_op.constant(100.0)
- var = variables.Variable(init)
+ var = variables.VariableV1(init)
gradient = gradients.gradients(var._ref(), var)
self.assertIsNotNone(gradient)
@@ -350,6 +350,40 @@ class GradientsTest(test_util.TensorFlowTestCase):
for a, b in zip(npgrad1, npgrad2):
np.testing.assert_allclose(a, b)
+ def testUnconnectedGradientsNoneUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="none")
+ self.assertIsNone(grad[0])
+
+ def testUnconnectedGradientsZerosUnconnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0, shape=[2, 2])
+ y = constant(3.0, shape=[3, 1])
+ grads = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertAllEqual([[0.0, 0.0], [0.0, 0.0]], sess.run(grads)[0])
+
+ def testUnconnectedGradientsZeroConnectedGradients(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = x * 3.0
+ grad = gradients.gradients(
+ [y], [x], unconnected_gradients="zero")
+ with self.cached_session() as sess:
+ self.assertEquals(3.0, sess.run(grad)[0])
+
+ def testUnknownUnconnectedGradientsValueGiven(self):
+ with ops.Graph().as_default():
+ x = constant(1.0)
+ y = constant(1.0)
+ with self.assertRaisesRegexp(
+ ValueError, "Unknown value for unconnected_gradients: 'nonsense'"):
+ gradients.gradients([y], [x], unconnected_gradients="nonsense")
+
class FunctionGradientsTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index fff3d9b930..65bb77b474 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_arg_values
from tensorflow.python.util.tf_export import tf_export
@@ -341,6 +342,7 @@ class TruncatedNormal(Initializer):
@tf_export("initializers.uniform_unit_scaling",
"uniform_unit_scaling_initializer")
+@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer")
class UniformUnitScaling(Initializer):
"""Initializer that generates tensors without scaling variance.
@@ -401,6 +403,7 @@ class UniformUnitScaling(Initializer):
@tf_export("keras.initializers.VarianceScaling",
"initializers.variance_scaling", "variance_scaling_initializer")
+@deprecation.deprecated_endpoints("variance_scaling_initializer")
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
@@ -494,6 +497,7 @@ class VarianceScaling(Initializer):
@tf_export("keras.initializers.Orthogonal", "initializers.orthogonal",
"orthogonal_initializer", "keras.initializers.orthogonal")
+@deprecation.deprecated_endpoints("orthogonal_initializer")
class Orthogonal(Initializer):
"""Initializer that generates an orthogonal matrix.
@@ -1149,6 +1153,7 @@ class GlorotUniform(VarianceScaling):
@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal",
"initializers.glorot_normal")
+@deprecation.deprecated_endpoints("glorot_normal_initializer")
class GlorotNormal(VarianceScaling):
"""The Glorot normal initializer, also called Xavier normal initializer.
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index f4a93560be..bf4354fa73 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -80,6 +80,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
@tf_export('cholesky_solve', 'linalg.cholesky_solve')
+@deprecation.deprecated_endpoints('cholesky_solve')
def cholesky_solve(chol, rhs, name=None):
"""Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
@@ -167,7 +168,8 @@ def eye(num_rows,
name=name)
-@tf_export('matrix_solve_ls', 'linalg.lstsq')
+@tf_export('linalg.lstsq', 'matrix_solve_ls')
+@deprecation.deprecated_endpoints('matrix_solve_ls')
def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
r"""Solves one or more linear least-squares problems.
@@ -220,7 +222,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
squares sense.
Raises:
- NotImplementedError: matrix_solve_ls is currently disabled for complex128
+ NotImplementedError: linalg.lstsq is currently disabled for complex128
and l2_regularizer != 0 due to poor accuracy.
"""
@@ -303,7 +305,8 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
matrix, rhs, l2_regularizer, fast=fast, name=name)
-@tf_export('self_adjoint_eig', 'linalg.eigh')
+@tf_export('linalg.eigh', 'self_adjoint_eig')
+@deprecation.deprecated_endpoints('self_adjoint_eig')
def self_adjoint_eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of self-adjoint matrices.
@@ -325,12 +328,13 @@ def self_adjoint_eig(tensor, name=None):
return e, v
-@tf_export('self_adjoint_eigvals', 'linalg.eigvalsh')
+@tf_export('linalg.eigvalsh', 'self_adjoint_eigvals')
+@deprecation.deprecated_endpoints('self_adjoint_eigvals')
def self_adjoint_eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more self-adjoint matrices.
Note: If your program backpropagates through this function, you should replace
- it with a call to tf.self_adjoint_eig (possibly ignoring the second output) to
+ it with a call to tf.linalg.eigvalsh (possibly ignoring the second output) to
avoid computing the eigen decomposition twice. This is because the
eigenvectors are used to compute the gradient w.r.t. the eigenvalues. See
_SelfAdjointEigV2Grad in linalg_grad.py.
@@ -348,6 +352,7 @@ def self_adjoint_eigvals(tensor, name=None):
@tf_export('svd', 'linalg.svd')
+@deprecation.deprecated_endpoints('svd')
def svd(tensor, full_matrices=False, compute_uv=True, name=None):
r"""Computes the singular value decompositions of one or more matrices.
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 5443699ddd..cffaa983d4 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -59,7 +59,7 @@ def initialize_all_tables(name="init_all_tables"):
return tables_initializer(name)
-@tf_export("tables_initializer")
+@tf_export("initializers.tables_initializer", "tables_initializer")
def tables_initializer(name="init_all_tables"):
"""Returns an Op that initializes all tables of the default graph.
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
index 6633565a64..d9d0728287 100644
--- a/tensorflow/python/ops/manip_ops.py
+++ b/tensorflow/python/ops/manip_ops.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
-@tf_export('manip.roll')
+@tf_export('roll', 'manip.roll')
+@deprecation.deprecated_endpoints('manip.roll')
def roll(input, shift, axis): # pylint: disable=redefined-builtin
return _gen_manip_ops.roll(input, shift, axis)
diff --git a/tensorflow/python/ops/math_grad.py b/tensorflow/python/ops/math_grad.py
index 8e11c4bce1..35278d9680 100644
--- a/tensorflow/python/ops/math_grad.py
+++ b/tensorflow/python/ops/math_grad.py
@@ -516,6 +516,40 @@ def _Log1pGrad(op, grad):
return grad * math_ops.reciprocal(1 + x)
+@ops.RegisterGradient("Xlogy")
+def _XLogyGrad(op, grad):
+ """Returns gradient of xlogy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xlogy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(x, y)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
+@ops.RegisterGradient("Xdivy")
+def _XDivyGrad(op, grad):
+ """Returns gradient of xdivy(x, y) with respect to x and y."""
+ x = op.inputs[0]
+ y = op.inputs[1]
+ sx = array_ops.shape(x)
+ sy = array_ops.shape(y)
+ rx, ry = gen_array_ops.broadcast_gradient_args(sx, sy)
+ with ops.control_dependencies([grad]):
+ not_zero_x = math_ops.cast(
+ math_ops.not_equal(x, math_ops.cast(0., dtype=x.dtype)), dtype=x.dtype)
+ partial_x = gen_math_ops.xdivy(not_zero_x, y)
+ partial_y = gen_math_ops.xdivy(math_ops.negative(x), y**2)
+ return (array_ops.reshape(math_ops.reduce_sum(partial_x * grad, rx), sx),
+ array_ops.reshape(math_ops.reduce_sum(partial_y * grad, ry), sy))
+
+
@ops.RegisterGradient("Sinh")
def _SinhGrad(op, grad):
"""Returns grad * cosh(x)."""
diff --git a/tensorflow/python/ops/math_grad_test.py b/tensorflow/python/ops/math_grad_test.py
index 7110e0958c..9cfb050942 100644
--- a/tensorflow/python/ops/math_grad_test.py
+++ b/tensorflow/python/ops/math_grad_test.py
@@ -256,5 +256,93 @@ class DivNoNanGradientTest(test.TestCase):
self.assertAllClose(dy.eval(), np.zeros(y.shape.as_list()))
+class XlogyTest(test.TestCase):
+
+ def _xlogy_gradients(self, x, y):
+ xlogy_xgrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), x)[0])
+ xlogy_ygrad = self.evaluate(gradients.gradients(math_ops.xlogy(x, y), y)[0])
+ return xlogy_xgrad, xlogy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ xlogy_expected_xgrad = self.evaluate(math_ops.log(y))
+ xlogy_expected_ygrad = self.evaluate(x / y)
+ self.assertAllClose(xlogy_expected_xgrad, xlogy_xgrad)
+ self.assertAllClose(xlogy_expected_ygrad, xlogy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ self.assertAllClose(-np.inf, xlogy_xgrad)
+ self.assertAllClose(np.inf, xlogy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xlogy_xgrad, xlogy_ygrad = self._xlogy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xlogy_xgrad)
+ self.assertAllClose(zero, xlogy_ygrad)
+
+
+class XdivyTest(test.TestCase):
+
+ def _xdivy_gradients(self, x, y):
+ xdivy_xgrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), x)[0])
+ xdivy_ygrad = self.evaluate(gradients.gradients(math_ops.xdivy(x, y), y)[0])
+ return xdivy_xgrad, xdivy_ygrad
+
+ def testNonZeroValuesGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ xdivy_expected_xgrad = self.evaluate(1 / y)
+ xdivy_expected_ygrad = self.evaluate(-x / y**2)
+ self.assertAllClose(xdivy_expected_xgrad, xdivy_xgrad)
+ self.assertAllClose(xdivy_expected_ygrad, xdivy_ygrad)
+
+ def testZeroXGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(3.1, dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+ def testZeroYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0.1, dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ self.assertAllClose(np.inf, xdivy_xgrad)
+ self.assertAllClose(-np.inf, xdivy_ygrad)
+
+ def testZeroXYGrad(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(0., dtype=dtype)
+ y = constant_op.constant(0., dtype=dtype)
+ xdivy_xgrad, xdivy_ygrad = self._xdivy_gradients(x, y)
+ zero = self.evaluate(x)
+ self.assertAllClose(zero, xdivy_xgrad)
+ self.assertAllClose(zero, xdivy_ygrad)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f57abf6704..83b8b5a3a4 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -70,7 +70,7 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin
-@tf_export("argmax")
+@tf_export("math.argmax", "argmax")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -88,7 +88,7 @@ def argmax(input,
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
-@tf_export("argmin")
+@tf_export("math.argmin", "argmin")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -111,7 +111,7 @@ def argmin(input,
# pylint: disable=anomalous-backslash-in-string,protected-access
# pylint: disable=g-docstring-has-escape
-@tf_export("abs")
+@tf_export("math.abs", "abs")
def abs(x, name=None): # pylint: disable=redefined-builtin
r"""Computes the absolute value of a tensor.
@@ -186,7 +186,7 @@ class DivideDelegateWithName(object):
return _div_python2(self.x, y, self.name)
-@tf_export("divide")
+@tf_export("math.divide", "divide")
def divide(x, y, name=None):
"""Computes Python style division of `x` by `y`."""
@@ -198,7 +198,7 @@ def divide(x, y, name=None):
return x / y
-@tf_export("multiply")
+@tf_export("math.multiply", "multiply")
def multiply(x, y, name=None):
return gen_math_ops.mul(x, y, name)
@@ -218,7 +218,7 @@ _mul.__doc__ = (
gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
-@tf_export("subtract")
+@tf_export("math.subtract", "subtract")
def subtract(x, y, name=None):
return gen_math_ops.sub(x, y, name)
@@ -239,7 +239,7 @@ _sub.__doc__ = (
# pylint: disable=g-docstring-has-escape
-@tf_export("negative")
+@tf_export("math.negative", "negative")
def negative(x, name=None):
"""Computes numerical negative value element-wise.
@@ -288,7 +288,7 @@ def _neg(x, name=None):
# pylint: enable=g-docstring-has-escape
-@tf_export("sign")
+@tf_export("math.sign", "sign")
def sign(x, name=None):
"""Returns an element-wise indication of the sign of a number.
@@ -319,7 +319,7 @@ def sign(x, name=None):
return gen_math_ops.sign(x, name=name)
-@tf_export("square")
+@tf_export("math.square", "square")
def square(x, name=None):
r"""Computes square of x element-wise.
@@ -342,7 +342,7 @@ def square(x, name=None):
return gen_math_ops.square(x, name=name)
-@tf_export("sqrt")
+@tf_export("math.sqrt", "sqrt")
def sqrt(x, name=None):
r"""Computes square root of x element-wise.
@@ -365,7 +365,8 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
-@tf_export("erf")
+@tf_export("math.erf", "erf")
+@deprecation.deprecated_endpoints("erf")
def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
@@ -386,7 +387,7 @@ def erf(x, name=None):
return gen_math_ops.erf(x, name=name)
-@tf_export("scalar_mul")
+@tf_export("math.scalar_mul", "scalar_mul")
def scalar_mul(scalar, x):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@@ -416,7 +417,7 @@ def scalar_mul(scalar, x):
raise ValueError("Only scalar multiply works, got shape %s" % shape)
-@tf_export("pow")
+@tf_export("math.pow", "pow")
def pow(x, y, name=None): # pylint: disable=redefined-builtin
r"""Computes the power of one value to another.
@@ -444,7 +445,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,redefined-outer-name
-@tf_export("complex")
+@tf_export("dtypes.complex", "complex")
def complex(real, imag, name=None):
r"""Converts two real numbers to a complex number.
@@ -486,7 +487,8 @@ def complex(real, imag, name=None):
return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
-@tf_export("real")
+@tf_export("math.real", "real")
+@deprecation.deprecated_endpoints("real")
def real(input, name=None):
r"""Returns the real part of a complex (or real) tensor.
@@ -517,7 +519,8 @@ def real(input, name=None):
return input
-@tf_export("imag")
+@tf_export("math.imag", "imag")
+@deprecation.deprecated_endpoints("imag")
def imag(input, name=None):
r"""Returns the imaginary part of a complex (or real) tensor.
@@ -547,7 +550,8 @@ def imag(input, name=None):
return array_ops.zeros_like(input)
-@tf_export("angle")
+@tf_export("math.angle", "angle")
+@deprecation.deprecated_endpoints("angle")
def angle(input, name=None):
r"""Returns the element-wise argument of a complex (or real) tensor.
@@ -586,7 +590,7 @@ def angle(input, name=None):
# pylint: enable=redefined-outer-name,redefined-builtin
-@tf_export("round")
+@tf_export("math.round", "round")
def round(x, name=None): # pylint: disable=redefined-builtin
"""Rounds the values of a tensor to the nearest integer, element-wise.
@@ -613,7 +617,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin
return gen_math_ops.round(x, name=name)
-@tf_export("cast")
+@tf_export("dtypes.cast", "cast")
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
@@ -676,7 +680,7 @@ def cast(x, dtype, name=None):
return x
-@tf_export("saturate_cast")
+@tf_export("dtypes.saturate_cast", "saturate_cast")
def saturate_cast(value, dtype, name=None):
"""Performs a safe saturating cast of `value` to `dtype`.
@@ -995,7 +999,7 @@ def _div_python2(x, y, name=None):
return gen_math_ops.floor_div(x, y, name=name)
-@tf_export("truediv")
+@tf_export("math.truediv", "truediv")
def truediv(x, y, name=None):
"""Divides x / y elementwise (using Python 3 division operator semantics).
@@ -1006,7 +1010,7 @@ def truediv(x, y, name=None):
arguments are cast to floating types first. This op is generated by normal
`x / y` division in Python 3 and in Python 2.7 with
`from __future__ import division`. If you want integer division that rounds
- down, use `x // y` or `tf.floordiv`.
+ down, use `x // y` or `tf.math.floordiv`.
`x` and `y` must have the same numeric type. If the inputs are floating
point, the output will have the same type. If the inputs are integral, the
@@ -1078,7 +1082,8 @@ mod = gen_math_ops.floor_mod
# TODO(aselle): Deprecate this once all internal functionality uses
# tf.truncatediv
-@tf_export("floordiv")
+@tf_export("math.floordiv", "floordiv")
+@deprecation.deprecated_endpoints("floordiv")
def floordiv(x, y, name=None):
"""Divides `x / y` elementwise, rounding toward the most negative integer.
@@ -1151,7 +1156,8 @@ _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")
-@tf_export("logical_xor")
+@tf_export("math.logical_xor", "logical_xor")
+@deprecation.deprecated_endpoints("logical_xor")
def logical_xor(x, y, name="LogicalXor"):
"""x ^ y = (x | y) & ~(x & y)."""
# TODO(alemi) Make this a cwise op if people end up relying on it.
@@ -1277,7 +1283,7 @@ def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
return output
-@tf_export("reduce_sum")
+@tf_export("math.reduce_sum", "reduce_sum")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_sum(input_tensor,
@@ -1339,7 +1345,7 @@ def reduce_sum(input_tensor,
name=name))
-@tf_export("count_nonzero")
+@tf_export("math.count_nonzero", "count_nonzero")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def count_nonzero(input_tensor,
@@ -1417,7 +1423,7 @@ def count_nonzero(input_tensor,
dtype=dtype)
-@tf_export("reduce_mean")
+@tf_export("math.reduce_mean", "reduce_mean")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_mean(input_tensor,
@@ -1489,7 +1495,7 @@ def reduce_mean(input_tensor,
name=name))
-@tf_export("reduce_prod")
+@tf_export("math.reduce_prod", "reduce_prod")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_prod(input_tensor,
@@ -1539,7 +1545,7 @@ def reduce_prod(input_tensor,
name=name))
-@tf_export("reduce_min")
+@tf_export("math.reduce_min", "reduce_min")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_min(input_tensor,
@@ -1588,7 +1594,7 @@ def reduce_min(input_tensor,
name=name))
-@tf_export("reduce_max")
+@tf_export("math.reduce_max", "reduce_max")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_max(input_tensor,
@@ -1637,7 +1643,7 @@ def reduce_max(input_tensor,
name=name))
-@tf_export("reduce_all")
+@tf_export("math.reduce_all", "reduce_all")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_all(input_tensor,
@@ -1695,7 +1701,7 @@ def reduce_all(input_tensor,
name=name))
-@tf_export("reduce_any")
+@tf_export("math.reduce_any", "reduce_any")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_any(input_tensor,
@@ -1753,7 +1759,7 @@ def reduce_any(input_tensor,
name=name))
-@tf_export("reduce_logsumexp")
+@tf_export("math.reduce_logsumexp", "reduce_logsumexp")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
@@ -1827,7 +1833,8 @@ def reduce_logsumexp(input_tensor,
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
-@tf_export("trace", "linalg.trace")
+@tf_export("linalg.trace", "trace")
+@deprecation.deprecated_endpoints("trace")
def trace(x, name=None):
"""Compute the trace of a tensor `x`.
@@ -1841,12 +1848,12 @@ def trace(x, name=None):
```python
x = tf.constant([[1, 2], [3, 4]])
- tf.trace(x) # 5
+ tf.linalg.trace(x) # 5
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
- tf.trace(x) # 15
+ tf.linalg.trace(x) # 15
x = tf.constant([[[1, 2, 3],
[4, 5, 6],
@@ -1854,7 +1861,7 @@ def trace(x, name=None):
[[-1, -2, -3],
[-4, -5, -6],
[-7, -8, -9]]])
- tf.trace(x) # [15, -15]
+ tf.linalg.trace(x) # [15, -15]
```
Args:
@@ -1869,7 +1876,7 @@ def trace(x, name=None):
return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
-@tf_export("matmul")
+@tf_export("linalg.matmul", "matmul")
def matmul(a,
b,
transpose_a=False,
@@ -2131,7 +2138,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
return casted_outputs
-@tf_export("add_n")
+@tf_export("math.add_n", "add_n")
def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
@@ -2166,14 +2173,15 @@ def add_n(inputs, name=None):
return gen_math_ops.add_n(inputs, name=name)
-@tf_export("accumulate_n")
+@tf_export("math.accumulate_n", "accumulate_n")
+@deprecation.deprecated_endpoints("accumulate_n")
def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
"""Returns the element-wise sum of a list of tensors.
Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
otherwise, these are inferred.
- `tf.accumulate_n` performs the same operation as `tf.add_n`, but does not
+ `tf.math.accumulate_n` performs the same operation as `tf.add_n`, but does not
wait for all of its inputs to be ready before beginning to sum. This can
save memory if inputs are ready at different times, since minimum temporary
storage is proportional to the output size rather than the inputs size.
@@ -2185,10 +2193,10 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
```python
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 0], [0, 6]])
- tf.accumulate_n([a, b, a]) # [[7, 4], [6, 14]]
+ tf.math.accumulate_n([a, b, a]) # [[7, 4], [6, 14]]
# Explicitly pass shape and type
- tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
+ tf.math.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
# [[7, 4],
# [6, 14]]
```
@@ -2252,7 +2260,7 @@ def _accumulate_n_grad(op, grad):
return [grad] * len(op.inputs)
-@tf_export("nn.sigmoid", "sigmoid")
+@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
def sigmoid(x, name=None):
"""Computes sigmoid of `x` element-wise.
@@ -2275,7 +2283,8 @@ def sigmoid(x, name=None):
return gen_math_ops.sigmoid(x, name=name)
-@tf_export("log_sigmoid")
+@tf_export("math.log_sigmoid", "log_sigmoid")
+@deprecation.deprecated_endpoints("log_sigmoid")
def log_sigmoid(x, name=None):
"""Computes log sigmoid of `x` element-wise.
@@ -2294,7 +2303,7 @@ def log_sigmoid(x, name=None):
return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name)
-@tf_export("nn.tanh", "tanh")
+@tf_export("math.tanh", "nn.tanh", "tanh")
def tanh(x, name=None):
"""Computes hyperbolic tangent of `x` element-wise.
@@ -2315,7 +2324,8 @@ def tanh(x, name=None):
return gen_math_ops.tanh(x, name=name)
-@tf_export("bincount")
+@tf_export("math.bincount", "bincount")
+@deprecation.deprecated_endpoints("bincount")
def bincount(arr,
weights=None,
minlength=None,
@@ -2362,7 +2372,7 @@ def bincount(arr,
return gen_math_ops.bincount(arr, output_size, weights)
-@tf_export("cumsum")
+@tf_export("math.cumsum", "cumsum")
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative sum of the tensor `x` along `axis`.
@@ -2414,7 +2424,8 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
-@tf_export("cumprod")
+@tf_export("math.cumprod", "cumprod")
+@deprecation.deprecated_endpoints("cumprod")
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative product of the tensor `x` along `axis`.
@@ -2422,7 +2433,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
first element of the input is identical to the first element of the output:
```python
- tf.cumprod([a, b, c]) # [a, a * b, a * b * c]
+ tf.math.cumprod([a, b, c]) # [a, a * b, a * b * c]
```
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
@@ -2430,21 +2441,21 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
instead:
```python
- tf.cumprod([a, b, c], exclusive=True) # [1, a, a * b]
+ tf.math.cumprod([a, b, c], exclusive=True) # [1, a, a * b]
```
By setting the `reverse` kwarg to `True`, the cumprod is performed in the
opposite direction:
```python
- tf.cumprod([a, b, c], reverse=True) # [a * b * c, b * c, c]
+ tf.math.cumprod([a, b, c], reverse=True) # [a * b * c, b * c, c]
```
This is more efficient than using separate `tf.reverse` ops.
The `reverse` and `exclusive` kwargs can also be combined:
```python
- tf.cumprod([a, b, c], exclusive=True, reverse=True) # [b * c, c, 1]
+ tf.math.cumprod([a, b, c], exclusive=True, reverse=True) # [b * c, c, 1]
```
Args:
@@ -2466,7 +2477,8 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
-@tf_export("conj")
+@tf_export("math.conj", "conj")
+@deprecation.deprecated_endpoints("conj")
def conj(x, name=None):
r"""Returns the complex conjugate of a complex number.
@@ -2480,7 +2492,7 @@ def conj(x, name=None):
For example:
# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
- tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+ tf.math.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
If `x` is real, it is returned unchanged.
@@ -2566,7 +2578,8 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
return gen_math_ops.maximum(N, 1)
-@tf_export("unsorted_segment_mean")
+@tf_export("math.unsorted_segment_mean", "unsorted_segment_mean")
+@deprecation.deprecated_endpoints("unsorted_segment_mean")
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
r"""Computes the mean along segments of a tensor.
@@ -2608,7 +2621,8 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
return summed / N
-@tf_export("unsorted_segment_sqrt_n")
+@tf_export("math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n")
+@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
@@ -2653,7 +2667,8 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
return summed / gen_math_ops.sqrt(N)
-@tf_export("sparse_segment_sum")
+@tf_export("sparse.segment_sum", "sparse_segment_sum")
+@deprecation.deprecated_endpoints("sparse_segment_sum")
def sparse_segment_sum(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor.
@@ -2674,16 +2689,16 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
# Select two rows, one segment.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
# => [[0 0 0 0]]
# Select two rows, two segment.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
# => [[ 1 2 3 4]
# [-1 -2 -3 -4]]
# With missing segment ids.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
num_segments=4)
# => [[ 1 2 3 4]
# [ 0 0 0 0]
@@ -2691,7 +2706,7 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
# [ 0 0 0 0]]
# Select all rows, two segments.
- tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
# => [[0 0 0 0]
# [5 6 7 8]]
@@ -2726,7 +2741,8 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
data=data, indices=indices, segment_ids=segment_ids, name=name)
-@tf_export("sparse_segment_mean")
+@tf_export("sparse.segment_mean", "sparse_segment_mean")
+@deprecation.deprecated_endpoints("sparse_segment_mean")
def sparse_segment_mean(data,
indices,
segment_ids,
@@ -2771,7 +2787,8 @@ def sparse_segment_mean(data,
data=data, indices=indices, segment_ids=segment_ids, name=name)
-@tf_export("sparse_segment_sqrt_n")
+@tf_export("sparse.segment_sqrt_n", "sparse_segment_sqrt_n")
+@deprecation.deprecated_endpoints("sparse_segment_sqrt_n")
def sparse_segment_sqrt_n(data,
indices,
segment_ids,
diff --git a/tensorflow/python/ops/math_ops_test.py b/tensorflow/python/ops/math_ops_test.py
index 1b01d1d37f..f051850d92 100644
--- a/tensorflow/python/ops/math_ops_test.py
+++ b/tensorflow/python/ops/math_ops_test.py
@@ -21,6 +21,7 @@ import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -488,5 +489,75 @@ class DivNoNanTest(test_util.TensorFlowTestCase):
self.assertAllEqual(tf_result, np_result)
+class XlogyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy = self.evaluate(math_ops.xlogy(x, y))
+ xtimeslogy = self.evaluate(x * math_ops.log(y))
+ self.assertAllClose(xlogy, xtimeslogy)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xlogy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXlogyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xlogy_tf_np = self.evaluate(math_ops.xlogy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ xtimes_logy = self.evaluate(math_ops.log(y[1]))
+ self.assertAllClose(zeros_np, xlogy_tf_np[0])
+ self.assertAllClose(xtimes_logy, xlogy_tf_np[1])
+
+
+class XdivyTest(test_util.TensorFlowTestCase):
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyNoZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.1, 0.2, 3.5], [-2., -5., 30.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [3.1, 4., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy = self.evaluate(math_ops.xdivy(x, y))
+ x_over_y = self.evaluate(x / y)
+ self.assertAllClose(xdivy, x_over_y)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZero(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant(np.zeros((2, 3)), dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y))
+ self.assertAllClose(xdivy_tf_np, zeros_np)
+
+ @test_util.run_in_graph_and_eager_modes
+ def testXdivyWithZeroBroadcast(self):
+ for dtype in [dtypes.float16, dtypes.float32, dtypes.float64]:
+ x = constant_op.constant([[0.], [1.]], dtype=dtype)
+ y = constant_op.constant([[0.1, 0.2, 3.5], [0., 1., 2.]], dtype=dtype)
+ with self.cached_session(use_gpu=True):
+ xdivy_tf_np = self.evaluate(math_ops.xdivy(x, y))
+ zeros_np = self.evaluate(array_ops.zeros_like(y[0]))
+ x_over_y = self.evaluate(1 / y[1])
+ self.assertAllClose(zeros_np, xdivy_tf_np[0])
+ self.assertAllClose(x_over_y, xdivy_tf_np[1])
+
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/ops/matmul_benchmark.py b/tensorflow/python/ops/matmul_benchmark.py
index 6e5fe74290..138149e63d 100644
--- a/tensorflow/python/ops/matmul_benchmark.py
+++ b/tensorflow/python/ops/matmul_benchmark.py
@@ -49,13 +49,13 @@ def build_graph(device, n, m, k, transpose_a, transpose_b, dtype):
"""
with ops.device('%s' % device):
if not transpose_a:
- x = variables.Variable(random_ops.random_uniform([n, m], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([n, m], dtype=dtype))
else:
- x = variables.Variable(random_ops.random_uniform([m, n], dtype=dtype))
+ x = variables.VariableV1(random_ops.random_uniform([m, n], dtype=dtype))
if not transpose_b:
- y = variables.Variable(random_ops.random_uniform([m, k], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([m, k], dtype=dtype))
else:
- y = variables.Variable(random_ops.random_uniform([k, m], dtype=dtype))
+ y = variables.VariableV1(random_ops.random_uniform([k, m], dtype=dtype))
z = math_ops.matmul(x, y, transpose_a=transpose_a, transpose_b=transpose_b)
return control_flow_ops.group(z)
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 2a1919e66f..453848fc00 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -328,7 +328,7 @@ def swish(features):
return features * math_ops.sigmoid(features)
-@tf_export("nn.l2_normalize")
+@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm.
@@ -360,7 +360,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
return math_ops.multiply(x, x_inv_norm, name=name)
-@tf_export("nn.zero_fraction")
+@tf_export("math.zero_fraction", "nn.zero_fraction")
def zero_fraction(value, name=None):
"""Returns the fraction of zeros in `value`.
@@ -689,7 +689,7 @@ def moments(
# Compute true mean while keeping the dims for proper broadcasting.
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
# sample variance, not unbiased variance
- # Note: stop_gradient does not change the gradient that gets
+ # Note: stop_gradient does not change the gradient that gets
# backpropagated to the mean from the variance calculation,
# because that gradient is zero
variance = math_ops.reduce_mean(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 78e000e458..04962da7f7 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -427,8 +427,8 @@ class _WithSpaceToBatch(object):
try:
input_shape.with_rank_at_least(expected_input_rank)
except ValueError:
- ValueError("input tensor must have rank %d at least" %
- (expected_input_rank))
+ raise ValueError(
+ "input tensor must have rank %d at least" % (expected_input_rank))
const_rate = tensor_util.constant_value(dilation_rate)
rate_or_const_rate = dilation_rate
@@ -818,12 +818,14 @@ class Convolution(object):
try:
input_shape.with_rank(num_spatial_dims + 2)
except ValueError:
- ValueError("input tensor must have rank %d" % (num_spatial_dims + 2))
+ raise ValueError(
+ "input tensor must have rank %d" % (num_spatial_dims + 2))
try:
filter_shape.with_rank(num_spatial_dims + 2)
except ValueError:
- ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
+ raise ValueError(
+ "filter tensor must have rank %d" % (num_spatial_dims + 2))
if data_format is None or not data_format.startswith("NC"):
input_channels_dim = input_shape[num_spatial_dims + 1]
@@ -1695,7 +1697,7 @@ def _softmax(logits, compute_op, dim=-1, name=None):
return output
-@tf_export("nn.softmax")
+@tf_export("nn.softmax", "math.softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
@@ -1725,7 +1727,7 @@ def softmax(logits, axis=None, name=None, dim=None):
return _softmax(logits, gen_nn_ops.softmax, axis, name)
-@tf_export("nn.log_softmax")
+@tf_export("nn.log_softmax", "math.log_softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
@@ -2332,7 +2334,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
return ret
-@tf_export("nn.top_k")
+@tf_export("math.top_k", "nn.top_k")
def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin
"""Finds values and indices of the `k` largest entries for the last dimension.
@@ -2647,7 +2649,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
name=name))
-@tf_export("nn.in_top_k")
+@tf_export("math.in_top_k", "nn.in_top_k")
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index 8fcbd7d834..002e87b411 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -24,10 +24,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export("verify_tensor_all_finite")
+@tf_export("debugging.assert_all_finite", "verify_tensor_all_finite")
+@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t, msg, name=None):
"""Assert that the tensor does not contain any NaN's or Inf's.
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index e0f6d51881..83cbe64ff2 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1987,14 +1987,12 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("RealDiv", math_ops.divide)
@RegisterPForWithArgs("Real", math_ops.real)
-@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
@RegisterPForWithArgs("Relu6", nn_ops.relu6)
@RegisterPForWithArgs("Relu", nn_ops.relu)
@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
@RegisterPForWithArgs("Rint", math_ops.rint)
@RegisterPForWithArgs("Round", math_ops.round)
-@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
@RegisterPForWithArgs("Selu", nn_ops.selu)
@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
@@ -2003,7 +2001,6 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Sin", math_ops.sin)
@RegisterPForWithArgs("Softplus", nn_ops.softplus)
@RegisterPForWithArgs("Softsign", nn_ops.softsign)
-@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
@RegisterPForWithArgs("Square", math_ops.square)
@@ -2095,6 +2092,9 @@ def _convert_biasaddgrad(pfor_input):
@RegisterPForWithArgs("SoftplusGrad")
@RegisterPForWithArgs("SoftsignGrad")
@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SqrtGrad")
+@RegisterPForWithArgs("RsqrtGrad")
+@RegisterPForWithArgs("ReciprocalGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index b3e03a0135..a2da6412ed 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -45,7 +46,7 @@ ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
-@tf_export("VarLenFeature")
+@tf_export("io.VarLenFeature", "VarLenFeature")
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
"""Configuration for parsing a variable-length input feature.
@@ -55,7 +56,7 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
pass
-@tf_export("SparseFeature")
+@tf_export("io.SparseFeature", "SparseFeature")
class SparseFeature(
collections.namedtuple(
"SparseFeature",
@@ -130,7 +131,7 @@ class SparseFeature(
cls, index_key, value_key, dtype, size, already_sorted)
-@tf_export("FixedLenFeature")
+@tf_export("io.FixedLenFeature", "FixedLenFeature")
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
@@ -150,7 +151,7 @@ class FixedLenFeature(collections.namedtuple(
cls, shape, dtype, default_value)
-@tf_export("FixedLenSequenceFeature")
+@tf_export("io.FixedLenSequenceFeature", "FixedLenSequenceFeature")
class FixedLenSequenceFeature(collections.namedtuple(
"FixedLenSequenceFeature",
["shape", "dtype", "allow_missing", "default_value"])):
@@ -216,21 +217,21 @@ def _features_to_raw_params(features, types):
feature = features[key]
if isinstance(feature, VarLenFeature):
if VarLenFeature not in types:
- raise ValueError("Unsupported VarLenFeature %s." % feature)
+ raise ValueError("Unsupported VarLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
sparse_keys.append(key)
sparse_types.append(feature.dtype)
elif isinstance(feature, SparseFeature):
if SparseFeature not in types:
- raise ValueError("Unsupported SparseFeature %s." % feature)
+ raise ValueError("Unsupported SparseFeature %s." % (feature,))
if not feature.index_key:
raise ValueError(
- "Missing index_key for SparseFeature %s." % feature)
+ "Missing index_key for SparseFeature %s." % (feature,))
if not feature.value_key:
raise ValueError(
- "Missing value_key for SparseFeature %s." % feature)
+ "Missing value_key for SparseFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
index_keys = feature.index_key
@@ -259,7 +260,7 @@ def _features_to_raw_params(features, types):
sparse_types.append(feature.dtype)
elif isinstance(feature, FixedLenFeature):
if FixedLenFeature not in types:
- raise ValueError("Unsupported FixedLenFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenFeature %s." % (feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
@@ -280,7 +281,8 @@ def _features_to_raw_params(features, types):
dense_defaults[key] = feature.default_value
elif isinstance(feature, FixedLenSequenceFeature):
if FixedLenSequenceFeature not in types:
- raise ValueError("Unsupported FixedLenSequenceFeature %s." % feature)
+ raise ValueError("Unsupported FixedLenSequenceFeature %s." % (
+ feature,))
if not feature.dtype:
raise ValueError("Missing type for feature %s." % key)
if feature.shape is None:
@@ -360,7 +362,7 @@ def _prepend_none_dimension(features):
return features
-@tf_export("parse_example")
+@tf_export("io.parse_example", "parse_example")
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@@ -761,7 +763,7 @@ def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
dense_shapes_as_proto, dense_shapes)
-@tf_export("parse_single_example")
+@tf_export("io.parse_single_example", "parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@@ -1244,7 +1246,7 @@ def _parse_sequence_example_raw(serialized,
# TODO(sundberg): rewrite this method to call the batch version, which is more
# efficient especially for large inputs.
-@tf_export("parse_single_sequence_example")
+@tf_export("io.parse_single_sequence_example", "parse_single_sequence_example")
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
example_name=None, name=None):
@@ -1564,7 +1566,8 @@ def _parse_single_sequence_example_raw(serialized,
# Swap `name` and `na_value` for backward compatibility.
-@tf_export("decode_csv")
+@tf_export("io.decode_csv", "decode_csv")
+@deprecation.deprecated_endpoints("decode_csv")
def decode_csv(records,
record_defaults,
field_delim=",",
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 4baf506385..c2eb9dfc5d 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_random_ops import *
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -43,7 +44,7 @@ def _ShapeTensor(shape):
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
-@tf_export("random_normal")
+@tf_export("random.normal", "random_normal")
def random_normal(shape,
mean=0.0,
stddev=1.0,
@@ -136,7 +137,7 @@ def parameterized_truncated_normal(shape,
return rnd
-@tf_export("truncated_normal")
+@tf_export("random.truncated_normal", "truncated_normal")
def truncated_normal(shape,
mean=0.0,
stddev=1.0,
@@ -181,7 +182,7 @@ ops.NotDifferentiable("ParameterizedTruncatedNormal")
ops.NotDifferentiable("TruncatedNormal")
-@tf_export("random_uniform")
+@tf_export("random.uniform", "random_uniform")
def random_uniform(shape,
minval=0,
maxval=None,
@@ -246,7 +247,7 @@ def random_uniform(shape,
ops.NotDifferentiable("RandomUniform")
-@tf_export("random_shuffle")
+@tf_export("random.shuffle", "random_shuffle")
def random_shuffle(value, seed=None, name=None):
"""Randomly shuffles a tensor along its first dimension.
@@ -277,7 +278,7 @@ def random_shuffle(value, seed=None, name=None):
value, seed=seed1, seed2=seed2, name=name)
-@tf_export("random_crop")
+@tf_export("image.random_crop", "random_crop")
def random_crop(value, size, seed=None, name=None):
"""Randomly crops a tensor to a given size.
@@ -320,7 +321,7 @@ def random_crop(value, size, seed=None, name=None):
return array_ops.slice(value, offset, size, name=name)
-@tf_export("multinomial")
+@tf_export("random.multinomial", "multinomial")
def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
"""Draws samples from a multinomial distribution.
@@ -356,7 +357,8 @@ def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
ops.NotDifferentiable("Multinomial")
-@tf_export("random_gamma")
+@tf_export("random.gamma", "random_gamma")
+@deprecation.deprecated_endpoints("random_gamma")
def random_gamma(shape,
alpha,
beta=None,
@@ -439,7 +441,8 @@ def random_gamma(shape,
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
-@tf_export("random_poisson")
+@tf_export("random.poisson", "random_poisson")
+@deprecation.deprecated_endpoints("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s).
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..dd4f3d7a99 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -370,7 +370,7 @@ class LayerRNNCell(RNNCell):
*args, **kwargs)
-@tf_export("nn.rnn_cell.BasicRNNCell")
+@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
"""The most basic RNN cell.
@@ -393,6 +393,8 @@ class BasicRNNCell(LayerRNNCell):
`trainable` etc when constructing the cell from configs of get_config().
"""
+ @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell,"
+ " and will be replaced by that in Tensorflow 2.0.")
def __init__(self,
num_units,
activation=None,
@@ -611,7 +613,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 400a42a3c0..7e3dbdbad4 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -185,7 +185,8 @@ def sparse_eye(num_rows,
# pylint: disable=protected-access
-@tf_export("sparse_concat")
+@tf_export("sparse.concat", "sparse_concat")
+@deprecation.deprecated_endpoints("sparse_concat")
@deprecation.deprecated_args(
None, "concat_dim is deprecated, use axis instead", "concat_dim")
def sparse_concat(axis,
@@ -317,7 +318,8 @@ def sparse_concat(axis,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_add")
+@tf_export("sparse.add", "sparse_add")
+@deprecation.deprecated_endpoints("sparse_add")
def sparse_add(a, b, thresh=0):
"""Adds two tensors, at least one of each is a `SparseTensor`.
@@ -557,7 +559,8 @@ def sparse_dense_cwise_add(sp_t, dense_t):
return sparse_tensor.SparseTensor(sp_t.indices, result, sp_t.dense_shape)
-@tf_export("sparse_reorder")
+@tf_export("sparse.reorder", "sparse_reorder")
+@deprecation.deprecated_endpoints("sparse_reorder")
def sparse_reorder(sp_input, name=None):
"""Reorders a `SparseTensor` into the canonical, row-major ordering.
@@ -607,7 +610,8 @@ def sparse_reorder(sp_input, name=None):
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
-@tf_export("sparse_reshape")
+@tf_export("sparse.reshape", "sparse_reshape")
+@deprecation.deprecated_endpoints("sparse_reshape")
def sparse_reshape(sp_input, shape, name=None):
"""Reshapes a `SparseTensor` to represent values in a new dense shape.
@@ -700,7 +704,8 @@ class KeywordRequired(object):
return "KeywordRequired()"
-@tf_export("sparse_split")
+@tf_export("sparse.split", "sparse_split")
+@deprecation.deprecated_endpoints("sparse_split")
@deprecation.deprecated_args(
None, "split_dim is deprecated, use axis instead", "split_dim")
def sparse_split(keyword_required=KeywordRequired(),
@@ -773,7 +778,8 @@ def sparse_split(keyword_required=KeywordRequired(),
return sparse_tensors
-@tf_export("sparse_slice")
+@tf_export("sparse.slice", "sparse_slice")
+@deprecation.deprecated_endpoints("sparse_slice")
def sparse_slice(sp_input, start, size, name=None):
"""Slice a `SparseTensor` based on the `start` and `size.
@@ -785,11 +791,11 @@ def sparse_slice(sp_input, start, size, name=None):
Graphically the output tensors are:
- sparse_slice([0, 0], [2, 4]) = shape = [2, 4]
+ sparse.slice([0, 0], [2, 4]) = shape = [2, 4]
[ a ]
[b c ]
- sparse_slice([0, 4], [2, 3]) = shape = [2, 3]
+ sparse.slice([0, 4], [2, 3]) = shape = [2, 3]
[ d e ]
[ ]
@@ -823,6 +829,9 @@ def sparse_slice(sp_input, start, size, name=None):
@tf_export("sparse_to_dense")
+@deprecation.deprecated(
+ None,
+ "Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.")
def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
@@ -878,7 +887,8 @@ def sparse_to_dense(sparse_indices,
name=name)
-@tf_export("sparse_reduce_max")
+@tf_export("sparse.reduce_max", "sparse_reduce_max")
+@deprecation.deprecated_endpoints("sparse_reduce_max")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_max(sp_input, axis=None, keepdims=None,
@@ -912,16 +922,16 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
# 'x' represents [[1, ?, 2]
# [?, 3, ?]]
# where ? is implicitly-zero.
- tf.sparse_reduce_max(x) ==> 3
- tf.sparse_reduce_max(x, 0) ==> [1, 3, 2]
- tf.sparse_reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis.
- tf.sparse_reduce_max(x, 1, keepdims=True) ==> [[2], [3]]
- tf.sparse_reduce_max(x, [0, 1]) ==> 3
+ tf.sparse.reduce_max(x) ==> 3
+ tf.sparse.reduce_max(x, 0) ==> [1, 3, 2]
+ tf.sparse.reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis.
+ tf.sparse.reduce_max(x, 1, keepdims=True) ==> [[2], [3]]
+ tf.sparse.reduce_max(x, [0, 1]) ==> 3
# 'y' represents [[-7, ?]
# [ 4, 3]
# [ ?, ?]
- tf.sparse_reduce_max(x, 1) ==> [-7, 4, 0]
+ tf.sparse.reduce_max(x, 1) ==> [-7, 4, 0]
```
Args:
@@ -945,7 +955,8 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)
-@tf_export("sparse_reduce_max_sparse")
+@tf_export("sparse.reduce_max_sparse", "sparse_reduce_max_sparse")
+@deprecation.deprecated_endpoints("sparse_reduce_max_sparse")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_max_sparse(sp_input,
@@ -995,7 +1006,8 @@ def sparse_reduce_max_sparse(sp_input,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_reduce_sum")
+@tf_export("sparse.reduce_sum", "sparse_reduce_sum")
+@deprecation.deprecated_endpoints("sparse_reduce_sum")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
@@ -1021,11 +1033,11 @@ def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
# 'x' represents [[1, ?, 1]
# [?, 1, ?]]
# where ? is implicitly-zero.
- tf.sparse_reduce_sum(x) ==> 3
- tf.sparse_reduce_sum(x, 0) ==> [1, 1, 1]
- tf.sparse_reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis.
- tf.sparse_reduce_sum(x, 1, keepdims=True) ==> [[2], [1]]
- tf.sparse_reduce_sum(x, [0, 1]) ==> 3
+ tf.sparse.reduce_sum(x) ==> 3
+ tf.sparse.reduce_sum(x, 0) ==> [1, 1, 1]
+ tf.sparse.reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis.
+ tf.sparse.reduce_sum(x, 1, keepdims=True) ==> [[2], [1]]
+ tf.sparse.reduce_sum(x, [0, 1]) ==> 3
```
Args:
@@ -1049,7 +1061,8 @@ def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)
-@tf_export("sparse_reduce_sum_sparse")
+@tf_export("sparse.reduce_sum_sparse", "sparse_reduce_sum_sparse")
+@deprecation.deprecated_endpoints("sparse_reduce_sum_sparse")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_sum_sparse(sp_input,
@@ -1099,7 +1112,8 @@ def sparse_reduce_sum_sparse(sp_input,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_tensor_to_dense")
+@tf_export("sparse.to_dense", "sparse_tensor_to_dense")
+@deprecation.deprecated_endpoints("sparse_tensor_to_dense")
def sparse_tensor_to_dense(sp_input,
default_value=0,
validate_indices=True,
@@ -1151,7 +1165,8 @@ def sparse_tensor_to_dense(sp_input,
name=name)
-@tf_export("sparse_to_indicator")
+@tf_export("sparse.to_indicator", "sparse_to_indicator")
+@deprecation.deprecated_endpoints("sparse_to_indicator")
def sparse_to_indicator(sp_input, vocab_size, name=None):
"""Converts a `SparseTensor` of ids into a dense bool indicator tensor.
@@ -1214,7 +1229,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
sp_new, default_value=False, validate_indices=False, name=name)
-@tf_export("sparse_merge")
+@tf_export("sparse.merge", "sparse_merge")
+@deprecation.deprecated_endpoints("sparse_merge")
def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
already_sorted=False):
"""Combines a batch of feature ids and values into a single `SparseTensor`.
@@ -1358,7 +1374,8 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
sorted_result.indices, sorted_result.values, new_shape)
-@tf_export("sparse_retain")
+@tf_export("sparse.retain", "sparse_retain")
+@deprecation.deprecated_endpoints("sparse_retain")
def sparse_retain(sp_input, to_retain):
"""Retains specified non-empty values within a `SparseTensor`.
@@ -1402,7 +1419,8 @@ def sparse_retain(sp_input, to_retain):
array_ops.identity(sp_input.dense_shape))
-@tf_export("sparse_reset_shape")
+@tf_export("sparse.reset_shape", "sparse_reset_shape")
+@deprecation.deprecated_endpoints("sparse_reset_shape")
def sparse_reset_shape(sp_input, new_shape=None):
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
@@ -1503,7 +1521,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
-@tf_export("sparse_fill_empty_rows")
+@tf_export("sparse.fill_empty_rows", "sparse_fill_empty_rows")
+@deprecation.deprecated_endpoints("sparse_fill_empty_rows")
def sparse_fill_empty_rows(sp_input, default_value, name=None):
"""Fills empty rows in the input 2-D `SparseTensor` with a default value.
@@ -1567,7 +1586,8 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
dense_shape=sp_input.dense_shape), empty_row_indicator)
-@tf_export("serialize_sparse")
+@tf_export("io.serialize_sparse", "serialize_sparse")
+@deprecation.deprecated_endpoints("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@@ -1593,7 +1613,8 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
out_type=out_type)
-@tf_export("serialize_many_sparse")
+@tf_export("io.serialize_many_sparse", "serialize_many_sparse")
+@deprecation.deprecated_endpoints("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@@ -1694,7 +1715,8 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
-@tf_export("deserialize_many_sparse")
+@tf_export("io.deserialize_many_sparse", "deserialize_many_sparse")
+@deprecation.deprecated_endpoints("deserialize_many_sparse")
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.
@@ -1712,7 +1734,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
- step run `sparse_reorder` to restore index ordering.
+ step run `sparse.reorder` to restore index ordering.
For example, if the serialized input is a `[2, 3]` matrix representing two
original `SparseTensor` objects:
@@ -1764,7 +1786,8 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
-@tf_export("sparse_tensor_dense_matmul")
+@tf_export("sparse.matmul", "sparse_tensor_dense_matmul")
+@deprecation.deprecated_endpoints("sparse_tensor_dense_matmul")
def sparse_tensor_dense_matmul(sp_a,
b,
adjoint_a=False,
@@ -1777,7 +1800,7 @@ def sparse_tensor_dense_matmul(sp_a,
following input format is recommended for optimal behavior:
* If `adjoint_a == false`: `A` should be sorted in lexicographically
- increasing order. Use `sparse_reorder` if you're not sure.
+ increasing order. Use `sparse.reorder` if you're not sure.
* If `adjoint_a == true`: `A` should be sorted in order of increasing
dimension 1 (i.e., "column major" order instead of "row major" order).
@@ -1981,7 +2004,8 @@ def sparse_tensor_dense_matmul(sp_a,
adjoint_b=adjoint_b)
-@tf_export("sparse_softmax")
+@tf_export("sparse.softmax", "sparse_softmax")
+@deprecation.deprecated_endpoints("sparse_softmax")
def sparse_softmax(sp_input, name=None):
"""Applies softmax to a batched N-D `SparseTensor`.
@@ -2036,7 +2060,8 @@ def sparse_softmax(sp_input, name=None):
sp_input.dense_shape)
-@tf_export("sparse_maximum")
+@tf_export("sparse.maximum", "sparse_maximum")
+@deprecation.deprecated_endpoints("sparse_maximum")
def sparse_maximum(sp_a, sp_b, name=None):
"""Returns the element-wise max of two SparseTensors.
@@ -2073,7 +2098,8 @@ def sparse_maximum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
-@tf_export("sparse_minimum")
+@tf_export("sparse.minimum", "sparse_minimum")
+@deprecation.deprecated_endpoints("sparse_minimum")
def sparse_minimum(sp_a, sp_b, name=None):
"""Returns the element-wise min of two SparseTensors.
@@ -2110,7 +2136,8 @@ def sparse_minimum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
-@tf_export("sparse_transpose")
+@tf_export("sparse.transpose", "sparse_transpose")
+@deprecation.deprecated_endpoints("sparse_transpose")
def sparse_transpose(sp_input, perm=None, name=None):
"""Transposes a `SparseTensor`
@@ -2259,7 +2286,7 @@ def _take_many_sparse_from_tensors_map(sparse_map_op,
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
- step run `sparse_reorder` to restore index ordering.
+ step run `sparse.reorder` to restore index ordering.
For example, if the serialized input is a `[2, 3]` matrix representing two
original `SparseTensor` objects:
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 9a10abfcf7..cfab943896 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -29,11 +29,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
-@tf_export('lbeta')
+@tf_export('math.lbeta', 'lbeta')
+@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 046a48d192..f26388efea 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -46,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
# pylint: disable=redefined-builtin
+@tf_export("strings.regex_full_match")
def regex_full_match(input, pattern, name=None):
r"""Match elements of `input` with regex `pattern`.
@@ -73,15 +74,14 @@ def regex_full_match(input, pattern, name=None):
regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
-# Expose regex_full_match in strings namespace
-tf_export("strings.regex_full_match")(regex_full_match)
-
-def regex_replace(source, pattern, rewrite, replace_global=True):
- r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
+@tf_export("strings.regex_replace", "regex_replace")
+@deprecation.deprecated_endpoints("regex_replace")
+def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
+ r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
Args:
- source: string `Tensor`, the source strings to process.
+ input: string `Tensor`, the source strings to process.
pattern: string or scalar string `Tensor`, regular expression to use,
see more details at https://github.com/google/re2/wiki/Syntax
rewrite: string or scalar string `Tensor`, value to use in match
@@ -89,9 +89,10 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
text matching corresponding parenthesized group.
replace_global: `bool`, if `True` replace all non-overlapping matches,
else replace only the first match.
+ name: A name for the operation (optional).
Returns:
- string `Tensor` of the same shape as `source` with specified replacements.
+ string `Tensor` of the same shape as `input` with specified replacements.
"""
if (isinstance(pattern, util_compat.bytes_or_text_types) and
isinstance(rewrite, util_compat.bytes_or_text_types)):
@@ -99,11 +100,13 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
# use a version which performs the expensive regex compilation once at
# creation time.
return gen_string_ops.static_regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
+ input=input, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global,
+ name=name)
return gen_string_ops.regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
+ input=input, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global,
+ name=name)
@tf_export("strings.format")
@@ -310,8 +313,9 @@ def _reduce_join_reduction_dims(x, axis, reduction_indices):
return math_ops.range(array_ops.rank(x) - 1, -1, -1)
-@tf_export("reduce_join")
-def reduce_join(inputs, axis=None,
+@tf_export("strings.reduce_join", "reduce_join")
+@deprecation.deprecated_endpoints("reduce_join")
+def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring
keep_dims=False,
separator="",
name=None,
@@ -329,6 +333,8 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
+ "tf.strings.reduce_join(")
# This wrapper provides backwards compatibility for code that predates the
@@ -341,6 +347,22 @@ def string_length(input, name=None, unit="BYTE"):
string_length.__doc__ = gen_string_ops.string_length.__doc__
+@tf_export("substr")
+@deprecation.deprecated(None, "Use `tf.strings.substr` instead of `tf.substr`.")
+def substr_deprecated(input, pos, len, name=None, unit="BYTE"):
+ return substr(input, pos, len, name=name, unit=unit)
+
+substr_deprecated.__doc__ = gen_string_ops.substr.__doc__
+
+
+@tf_export("strings.substr")
+def substr(input, pos, len, name=None, unit="BYTE"):
+ return gen_string_ops.substr(input, pos, len, unit=unit, name=name)
+
+
+substr.__doc__ = gen_string_ops.substr.__doc__
+
+
ops.NotDifferentiable("RegexReplace")
ops.NotDifferentiable("StringToHashBucket")
ops.NotDifferentiable("StringToHashBucketFast")
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index a43676cd70..5032ca79f9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -198,7 +198,7 @@ VariableSynchronization = variables.VariableSynchronization # pylint: disable=i
VariableAggregation = variables.VariableAggregation # pylint: disable=invalid-name
AUTO_REUSE = _ReuseMode.AUTO_REUSE
-tf_export("AUTO_REUSE").export_constant(__name__, "AUTO_REUSE")
+tf_export(v1=["AUTO_REUSE"]).export_constant(__name__, "AUTO_REUSE")
AUTO_REUSE.__doc__ = """
When passed in as the value for the `reuse` flag, AUTO_REUSE indicates that
get_variable() should create the requested variable if it doesn't exist or, if
@@ -515,8 +515,10 @@ class _VariableStore(object):
"synchronization": synchronization,
"aggregation": aggregation,
}
- # `fn_args` can handle functions, `functools.partial`, `lambda`.
- if "constraint" in function_utils.fn_args(custom_getter):
+ # `fn_args` and `has_kwargs` can handle functions, `functools.partial`,
+ # `lambda`.
+ if ("constraint" in function_utils.fn_args(custom_getter) or
+ function_utils.has_kwargs(custom_getter)):
custom_getter_kwargs["constraint"] = constraint
return custom_getter(**custom_getter_kwargs)
else:
@@ -906,7 +908,7 @@ class _VariableStore(object):
if use_resource is None:
# Set the default value if unspecified.
use_resource = _DEFAULT_USE_RESOURCE
- v = variable(
+ v = variables.VariableV1(
initial_value=init_val,
name=name,
trainable=trainable,
@@ -937,7 +939,8 @@ class _VariableStore(object):
if regularizer:
with ops.colocate_with(v):
with ops.name_scope(name + "/Regularizer/"):
- loss = regularizer(v)
+ with ops.init_scope():
+ loss = regularizer(v)
if loss is not None:
if context.executing_eagerly():
v_name = "v_%s" % type(v)
@@ -992,7 +995,7 @@ def no_regularizer(_):
# TODO(alive): support caching devices and partitioned variables in Eager mode.
-@tf_export("VariableScope")
+@tf_export(v1=["VariableScope"])
class VariableScope(object):
"""Variable scope object to carry defaults to provide to `get_variable`.
@@ -1340,7 +1343,7 @@ def get_variable_scope_store():
return scope_store
-@tf_export("get_variable_scope")
+@tf_export(v1=["get_variable_scope"])
def get_variable_scope():
"""Returns the current variable scope."""
return get_variable_scope_store().current_scope
@@ -1449,7 +1452,7 @@ class EagerVariableStore(object):
# The argument list for get_variable must match arguments to get_local_variable.
# So, if you are updating the arguments, also update arguments to
# get_local_variable below.
-@tf_export("get_variable")
+@tf_export(v1=["get_variable"])
def get_variable(name,
shape=None,
dtype=None,
@@ -1594,7 +1597,7 @@ get_variable.__doc__ = get_variable_or_local_docstring % (
# The argument list for get_local_variable must match arguments to get_variable.
# So, if you are updating the arguments, also update arguments to get_variable.
-@tf_export("get_local_variable")
+@tf_export(v1=["get_local_variable"])
def get_local_variable( # pylint: disable=missing-docstring
name,
shape=None,
@@ -1939,7 +1942,7 @@ def _get_unique_variable_scope(prefix):
# Named like a function for backwards compatibility with the
# @tf_contextlib.contextmanager version, which was switched to a class to avoid
# some object creation overhead.
-@tf_export("variable_scope") # pylint: disable=invalid-name
+@tf_export(v1=["variable_scope"]) # pylint: disable=invalid-name
class variable_scope(object):
"""A context manager for defining ops that creates variables (layers).
@@ -2320,7 +2323,7 @@ class variable_scope(object):
# pylint: disable=g-doc-return-or-yield
-@tf_export("variable_op_scope")
+@tf_export(v1=["variable_op_scope"])
@tf_contextlib.contextmanager
def variable_op_scope(values,
name_or_scope,
@@ -2441,7 +2444,33 @@ def default_variable_creator(next_creator=None, **kwargs):
expected_shape=expected_shape, import_scope=import_scope)
+def default_variable_creator_v2(next_creator=None, **kwargs):
+ """Default variable creator."""
+ assert next_creator is None
+ initial_value = kwargs.get("initial_value", None)
+ trainable = kwargs.get("trainable", None)
+ validate_shape = kwargs.get("validate_shape", True)
+ caching_device = kwargs.get("caching_device", None)
+ name = kwargs.get("name", None)
+ variable_def = kwargs.get("variable_def", None)
+ dtype = kwargs.get("dtype", None)
+ import_scope = kwargs.get("import_scope", None)
+ constraint = kwargs.get("constraint", None)
+
+ # Set trainable value based on synchronization value.
+ synchronization = kwargs.get("synchronization", VariableSynchronization.AUTO)
+ trainable = _get_trainable_value(
+ synchronization=synchronization, trainable=trainable)
+
+ return resource_variable_ops.ResourceVariable(
+ initial_value=initial_value, trainable=trainable,
+ validate_shape=validate_shape, caching_device=caching_device,
+ name=name, dtype=dtype, constraint=constraint, variable_def=variable_def,
+ import_scope=import_scope)
+
+
variables.default_variable_creator = default_variable_creator
+variables.default_variable_creator_v2 = default_variable_creator_v2
def _make_getter(captured_getter, captured_previous):
@@ -2450,11 +2479,12 @@ def _make_getter(captured_getter, captured_previous):
# TODO(apassos) remove forwarding symbol
-variable = variables.Variable
+variable = variables.VariableV1
+@tf_export(v1=["variable_creator_scope"])
@tf_contextlib.contextmanager
-def variable_creator_scope(variable_creator):
+def variable_creator_scope_v1(variable_creator):
"""Scope which defines a variable creation function to be used by variable().
variable_creator is expected to be a function with the following signature:
@@ -2525,3 +2555,73 @@ def variable_creator_scope(variable_creator):
"""
with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
yield
+
+
+# Note: only the docstrings differ between this and v1.
+@tf_export(v2=["variable_creator_scope"])
+@tf_contextlib.contextmanager
+def variable_creator_scope(variable_creator):
+ """Scope which defines a variable creation function to be used by variable().
+
+ variable_creator is expected to be a function with the following signature:
+
+ ```
+ def variable_creator(next_creator, **kwargs)
+ ```
+
+ The creator is supposed to eventually call the next_creator to create a
+ variable if it does want to create a variable and not call Variable or
+ ResourceVariable directly. This helps make creators composable. A creator may
+ choose to create multiple variables, return already existing variables, or
+ simply register that a variable was created and defer to the next creators in
+ line. Creators can also modify the keyword arguments seen by the next
+ creators.
+
+ Custom getters in the variable scope will eventually resolve down to these
+ custom creators when they do create variables.
+
+ The valid keyword arguments in kwds are:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, GradientTapes automatically watch
+ uses of this Variable.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ constraint: A constraint function to be applied to the variable after
+ updates by some algorithms.
+ synchronization: Indicates when a distributed a variable will be
+ aggregated. Accepted values are constants defined in the class
+ `tf.VariableSynchronization`. By default the synchronization is set to
+ `AUTO` and the current `DistributionStrategy` chooses
+ when to synchronize. If `synchronization` is set to `ON_READ`,
+ `trainable` must not be set to `True`.
+ aggregation: Indicates how a distributed variable will be aggregated.
+ Accepted values are constants defined in the class
+ `tf.VariableAggregation`.
+
+ This set may grow over time, so it's important the signature of creators is as
+ mentioned above.
+
+ Args:
+ variable_creator: the passed creator
+
+ Yields:
+ A scope in which the creator is active
+ """
+ with ops.get_default_graph()._variable_creator_scope(variable_creator): # pylint: disable=protected-access
+ yield
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 7a46157739..45c8618610 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -46,6 +46,11 @@ def default_variable_creator(_, **kwds):
raise NotImplementedError("variable_scope needs to be imported")
+def default_variable_creator_v2(_, **kwds):
+ del kwds
+ raise NotImplementedError("variable_scope needs to be imported")
+
+
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(**kwargs):
@@ -101,21 +106,21 @@ class VariableAggregation(enum.Enum):
class VariableMetaclass(type):
"""Metaclass to allow construction of tf.Variable to be overridden."""
- def _variable_call(cls,
- initial_value=None,
- trainable=None,
- collections=None,
- validate_shape=True,
- caching_device=None,
- name=None,
- variable_def=None,
- dtype=None,
- expected_shape=None,
- import_scope=None,
- constraint=None,
- use_resource=None,
- synchronization=VariableSynchronization.AUTO,
- aggregation=VariableAggregation.NONE):
+ def _variable_v1_call(cls,
+ initial_value=None,
+ trainable=None,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
"""Call on Variable class. Useful to force the signature."""
previous_getter = lambda **kwargs: default_variable_creator(None, **kwargs)
for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
@@ -140,14 +145,49 @@ class VariableMetaclass(type):
synchronization=synchronization,
aggregation=aggregation)
+ def _variable_v2_call(cls,
+ initial_value=None,
+ trainable=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ import_scope=None,
+ constraint=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Call on Variable class. Useful to force the signature."""
+ previous_getter = lambda **kws: default_variable_creator_v2(None, **kws)
+ for getter in ops.get_default_graph()._variable_creator_stack: # pylint: disable=protected-access
+ previous_getter = _make_getter(getter, previous_getter)
+
+ # Reset `aggregation` that is explicitly set as `None` to the enum NONE.
+ if aggregation is None:
+ aggregation = VariableAggregation.NONE
+ return previous_getter(
+ initial_value=initial_value,
+ trainable=trainable,
+ validate_shape=validate_shape,
+ caching_device=caching_device,
+ name=name,
+ variable_def=variable_def,
+ dtype=dtype,
+ import_scope=import_scope,
+ constraint=constraint,
+ synchronization=synchronization,
+ aggregation=aggregation)
+
def __call__(cls, *args, **kwargs):
- if cls is Variable:
- return cls._variable_call(*args, **kwargs)
+ if cls is VariableV1:
+ return cls._variable_v1_call(*args, **kwargs)
+ elif cls is Variable:
+ return cls._variable_v2_call(*args, **kwargs)
else:
return super(VariableMetaclass, cls).__call__(*args, **kwargs)
-@tf_export("Variable")
+@tf_export(v2=["Variable"])
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
@@ -267,16 +307,13 @@ class Variable(six.with_metaclass(VariableMetaclass,
def __init__(self,
initial_value=None,
trainable=True,
- collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
- expected_shape=None,
import_scope=None,
constraint=None,
- use_resource=None,
synchronization=VariableSynchronization.AUTO,
aggregation=VariableAggregation.NONE):
"""Creates a new variable with value `initial_value`.
@@ -297,11 +334,8 @@ class Variable(six.with_metaclass(VariableMetaclass,
callable with no argument that returns the initial value when called. In
that case, `dtype` must be specified. (Note that initializer functions
from init_ops.py must first be bound to a shape before being used here.)
- trainable: If `True`, the default, also adds the variable to the graph
- collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
- the default list of variables to use by the `Optimizer` classes.
- collections: List of graph collections keys. The new variable is added to
- these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ trainable: If `True`, the default, GradientTapes automatically watch uses
+ of this variable.
validate_shape: If `False`, allows the variable to be initialized with a
value of unknown shape. If `True`, the default, the shape of
`initial_value` must be known.
@@ -319,8 +353,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
dtype: If set, initial_value will be converted to the given type.
If `None`, either the datatype will be kept (if `initial_value` is
a Tensor), or `convert_to_tensor` will decide.
- expected_shape: A TensorShape. If set, initial_value is expected
- to have this shape.
import_scope: Optional `string`. Name scope to add to the
`Variable.` Only used when initializing from protocol buffer.
constraint: An optional projection function to be applied to the variable
@@ -330,9 +362,6 @@ class Variable(six.with_metaclass(VariableMetaclass,
variable and return the Tensor for the projected value
(which must have the same shape). Constraints are not safe to
use when doing asynchronous distributed training.
- use_resource: if True, a ResourceVariable is created; otherwise an
- old-style ref-based variable is created. When eager execution is enabled
- a resource variable is always created.
synchronization: Indicates when a distributed a variable will be
aggregated. Accepted values are constants defined in the class
`tf.VariableSynchronization`. By default the synchronization is set to
@@ -1009,11 +1038,207 @@ class Variable(six.with_metaclass(VariableMetaclass,
raise NotImplementedError
+@tf_export(v1=["Variable"])
+class VariableV1(Variable):
+ """See the [Variables Guide](https://tensorflow.org/guide/variables).
+
+ A variable maintains state in the graph across calls to `run()`. You add a
+ variable to the graph by constructing an instance of the class `Variable`.
+
+ The `Variable()` constructor requires an initial value for the variable,
+ which can be a `Tensor` of any type and shape. The initial value defines the
+ type and shape of the variable. After construction, the type and shape of
+ the variable are fixed. The value can be changed using one of the assign
+ methods.
+
+ If you want to change the shape of a variable later you have to use an
+ `assign` Op with `validate_shape=False`.
+
+ Just like any `Tensor`, variables created with `Variable()` can be used as
+ inputs for other Ops in the graph. Additionally, all the operators
+ overloaded for the `Tensor` class are carried over to variables, so you can
+ also add nodes to the graph by just doing arithmetic on variables.
+
+ ```python
+ import tensorflow as tf
+
+ # Create a variable.
+ w = tf.Variable(<initial-value>, name=<optional-name>)
+
+ # Use the variable in the graph like any Tensor.
+ y = tf.matmul(w, ...another variable or tensor...)
+
+ # The overloaded operators are available too.
+ z = tf.sigmoid(w + y)
+
+ # Assign a new value to the variable with `assign()` or a related method.
+ w.assign(w + 1.0)
+ w.assign_add(1.0)
+ ```
+
+ When you launch the graph, variables have to be explicitly initialized before
+ you can run Ops that use their value. You can initialize a variable by
+ running its *initializer op*, restoring the variable from a save file, or
+ simply running an `assign` Op that assigns a value to the variable. In fact,
+ the variable *initializer op* is just an `assign` Op that assigns the
+ variable's initial value to the variable itself.
+
+ ```python
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the variable initializer.
+ sess.run(w.initializer)
+ # ...you now can run ops that use the value of 'w'...
+ ```
+
+ The most common initialization pattern is to use the convenience function
+ `global_variables_initializer()` to add an Op to the graph that initializes
+ all the variables. You then run that Op after launching the graph.
+
+ ```python
+ # Add an Op to initialize global variables.
+ init_op = tf.global_variables_initializer()
+
+ # Launch the graph in a session.
+ with tf.Session() as sess:
+ # Run the Op that initializes global variables.
+ sess.run(init_op)
+ # ...you can now run any Op that uses variable values...
+ ```
+
+ If you need to create a variable with an initial value dependent on another
+ variable, use the other variable's `initialized_value()`. This ensures that
+ variables are initialized in the right order.
+
+ All variables are automatically collected in the graph where they are
+ created. By default, the constructor adds the new variable to the graph
+ collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
+ `global_variables()` returns the contents of that collection.
+
+ When building a machine learning model it is often convenient to distinguish
+ between variables holding the trainable model parameters and other variables
+ such as a `global step` variable used to count training steps. To make this
+ easier, the variable constructor supports a `trainable=<bool>` parameter. If
+ `True`, the new variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`. The convenience function
+ `trainable_variables()` returns the contents of this collection. The
+ various `Optimizer` classes use this collection as the default list of
+ variables to optimize.
+
+ WARNING: tf.Variable objects by default have a non-intuitive memory model. A
+ Variable is represented internally as a mutable Tensor which can
+ non-deterministically alias other Tensors in a graph. The set of operations
+ which consume a Variable and can lead to aliasing is undetermined and can
+ change across TensorFlow versions. Avoid writing code which relies on the
+ value of a Variable either changing or not changing as other operations
+ happen. For example, using Variable objects or simple functions thereof as
+ predicates in a `tf.cond` is dangerous and error-prone:
+
+ ```
+ v = tf.Variable(True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
+ ```
+
+ Here replacing adding `use_resource=True` when constructing the variable will
+ fix any nondeterminism issues:
+ ```
+ v = tf.Variable(True, use_resource=True)
+ tf.cond(v, lambda: v.assign(False), my_false_fn)
+ ```
+
+ To use the replacement for variables which does
+ not have these issues:
+
+ * Add `use_resource=True` when constructing `tf.Variable`;
+ * Call `tf.get_variable_scope().set_use_resource(True)` inside a
+ `tf.variable_scope` before the `tf.get_variable()` call.
+ """
+
+ def __init__(self, # pylint: disable=super-init-not-called
+ initial_value=None,
+ trainable=True,
+ collections=None,
+ validate_shape=True,
+ caching_device=None,
+ name=None,
+ variable_def=None,
+ dtype=None,
+ expected_shape=None,
+ import_scope=None,
+ constraint=None,
+ use_resource=None,
+ synchronization=VariableSynchronization.AUTO,
+ aggregation=VariableAggregation.NONE):
+ """Creates a new variable with value `initial_value`.
+
+ The new variable is added to the graph collections listed in `collections`,
+ which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+
+ If `trainable` is `True` the variable is also added to the graph collection
+ `GraphKeys.TRAINABLE_VARIABLES`.
+
+ This constructor creates both a `variable` Op and an `assign` Op to set the
+ variable to its initial value.
+
+ Args:
+ initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
+ which is the initial value for the Variable. The initial value must have
+ a shape specified unless `validate_shape` is set to False. Can also be a
+ callable with no argument that returns the initial value when called. In
+ that case, `dtype` must be specified. (Note that initializer functions
+ from init_ops.py must first be bound to a shape before being used here.)
+ trainable: If `True`, the default, also adds the variable to the graph
+ collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
+ the default list of variables to use by the `Optimizer` classes.
+ collections: List of graph collections keys. The new variable is added to
+ these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
+ validate_shape: If `False`, allows the variable to be initialized with a
+ value of unknown shape. If `True`, the default, the shape of
+ `initial_value` must be known.
+ caching_device: Optional device string describing where the Variable
+ should be cached for reading. Defaults to the Variable's device.
+ If not `None`, caches on another device. Typical use is to cache
+ on the device where the Ops using the Variable reside, to deduplicate
+ copying through `Switch` and other conditional statements.
+ name: Optional name for the variable. Defaults to `'Variable'` and gets
+ uniquified automatically.
+ variable_def: `VariableDef` protocol buffer. If not `None`, recreates
+ the Variable object with its contents, referencing the variable's nodes
+ in the graph, which must already exist. The graph is not changed.
+ `variable_def` and the other arguments are mutually exclusive.
+ dtype: If set, initial_value will be converted to the given type.
+ If `None`, either the datatype will be kept (if `initial_value` is
+ a Tensor), or `convert_to_tensor` will decide.
+ expected_shape: A TensorShape. If set, initial_value is expected
+ to have this shape.
+ import_scope: Optional `string`. Name scope to add to the
+ `Variable.` Only used when initializing from protocol buffer.
+ constraint: An optional projection function to be applied to the variable
+ after being updated by an `Optimizer` (e.g. used to implement norm
+ constraints or value constraints for layer weights). The function must
+ take as input the unprojected Tensor representing the value of the
+ variable and return the Tensor for the projected value
+ (which must have the same shape). Constraints are not safe to
+ use when doing asynchronous distributed training.
+ use_resource: whether to use resource variables.
+ synchronization: unused
+ aggregation: unused
+
+ Raises:
+ ValueError: If both `variable_def` and initial_value are specified.
+ ValueError: If the initial value is not specified, or does not have a
+ shape and `validate_shape` is `True`.
+ RuntimeError: If eager execution is enabled.
+ """
+
+ SaveSliceInfo = Variable.SaveSliceInfo
+
+
# TODO(apassos): do not repeat all comments here
-class RefVariable(Variable):
+class RefVariable(VariableV1):
"""Ref-based implementation of variables."""
- def __init__(self,
+ def __init__(self, # pylint: disable=super-init-not-called
initial_value=None,
trainable=True,
collections=None,
@@ -1873,7 +2098,7 @@ class RefVariable(Variable):
def _OverloadAllOperators(): # pylint: disable=invalid-name
"""Register overloads for all operators."""
for operator in ops.Tensor.OVERLOADABLE_OPERATORS:
- Variable._OverloadOperator(operator)
+ Variable._OverloadOperator(operator) # pylint: disable=protected-access
# For slicing, bind getitem differently than a tensor (use SliceHelperVar
# instead)
# pylint: disable=protected-access
@@ -2395,13 +2620,53 @@ class PartitionedVariable(object):
def _get_partitions(self):
return self._partitions
- def assign(self, value, use_locking=False):
- _ = value, use_locking
- raise NotImplementedError(
- "assign() has not been implemented for PartitionedVariable.")
+ def _apply_assign_fn(self, assign_fn, value):
+ partition_axes = self._partition_axes()
+ if len(partition_axes) > 1:
+ raise NotImplementedError(
+ "Cannot do assign action along more than one dimension: %s. "
+ "Multi-axis partition assign action is not supported " %
+ str(partition_axes))
+ partition_ix = partition_axes[0]
+ size_splits_list = [
+ var.shape[partition_ix].value for var in self._variable_list
+ ]
+ value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
+ op_list = [
+ assign_fn(var, value_list[idx], idx)
+ for idx, var in enumerate(self._variable_list)
+ ]
+ return op_list
+
+ def assign(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_add(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_sub(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
-@tf_export("global_variables")
+@tf_export(v1=["global_variables"])
def global_variables(scope=None):
"""Returns global variables.
@@ -2427,7 +2692,7 @@ def global_variables(scope=None):
return ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES, scope)
-@tf_export("all_variables")
+@tf_export(v1=["all_variables"])
@deprecated("2017-03-02", "Please use tf.global_variables instead.")
def all_variables():
"""See `tf.global_variables`."""
@@ -2452,7 +2717,7 @@ def _all_saveable_objects(scope=None):
ops.get_collection(ops.GraphKeys.SAVEABLE_OBJECTS, scope))
-@tf_export("local_variables")
+@tf_export(v1=["local_variables"])
def local_variables(scope=None):
"""Returns local variables.
@@ -2480,7 +2745,7 @@ def local_variables(scope=None):
return ops.get_collection(ops.GraphKeys.LOCAL_VARIABLES, scope)
-@tf_export("model_variables")
+@tf_export(v1=["model_variables"])
def model_variables(scope=None):
"""Returns all variables in the MODEL_VARIABLES collection.
@@ -2497,7 +2762,7 @@ def model_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MODEL_VARIABLES, scope)
-@tf_export("trainable_variables")
+@tf_export(v1=["trainable_variables"])
def trainable_variables(scope=None):
"""Returns all variables created with `trainable=True`.
@@ -2519,7 +2784,7 @@ def trainable_variables(scope=None):
return ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES, scope)
-@tf_export("moving_average_variables")
+@tf_export(v1=["moving_average_variables"])
def moving_average_variables(scope=None):
"""Returns all variables that maintain their moving averages.
@@ -2541,7 +2806,7 @@ def moving_average_variables(scope=None):
return ops.get_collection(ops.GraphKeys.MOVING_AVERAGE_VARIABLES, scope)
-@tf_export("initializers.variables", "variables_initializer")
+@tf_export(v1=["initializers.variables", "variables_initializer"])
def variables_initializer(var_list, name="init"):
"""Returns an Op that initializes a list of variables.
@@ -2567,7 +2832,7 @@ def variables_initializer(var_list, name="init"):
return control_flow_ops.no_op(name=name)
-@tf_export("initialize_variables")
+@tf_export(v1=["initialize_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.variables_initializer` instead.")
def initialize_variables(var_list, name="init"):
@@ -2575,7 +2840,7 @@ def initialize_variables(var_list, name="init"):
return variables_initializer(var_list, name=name)
-@tf_export("initializers.global_variables", "global_variables_initializer")
+@tf_export(v1=["initializers.global_variables", "global_variables_initializer"])
def global_variables_initializer():
"""Returns an Op that initializes global variables.
@@ -2589,7 +2854,7 @@ def global_variables_initializer():
return variables_initializer(global_variables())
-@tf_export("initialize_all_variables")
+@tf_export(v1=["initialize_all_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.global_variables_initializer` instead.")
def initialize_all_variables():
@@ -2597,7 +2862,7 @@ def initialize_all_variables():
return global_variables_initializer()
-@tf_export("initializers.local_variables", "local_variables_initializer")
+@tf_export(v1=["initializers.local_variables", "local_variables_initializer"])
def local_variables_initializer():
"""Returns an Op that initializes all local variables.
@@ -2611,7 +2876,7 @@ def local_variables_initializer():
return variables_initializer(local_variables())
-@tf_export("initialize_local_variables")
+@tf_export(v1=["initialize_local_variables"])
@tf_should_use.should_use_result
@deprecated("2017-03-02", "Use `tf.local_variables_initializer` instead.")
def initialize_local_variables():
@@ -2619,7 +2884,7 @@ def initialize_local_variables():
return local_variables_initializer()
-@tf_export("is_variable_initialized")
+@tf_export(v1=["is_variable_initialized"])
@tf_should_use.should_use_result
def is_variable_initialized(variable):
"""Tests if a variable has been initialized.
@@ -2634,7 +2899,7 @@ def is_variable_initialized(variable):
return state_ops.is_variable_initialized(variable)
-@tf_export("assert_variables_initialized")
+@tf_export(v1=["assert_variables_initialized"])
@tf_should_use.should_use_result
def assert_variables_initialized(var_list=None):
"""Returns an Op to check if variables are initialized.
@@ -2677,7 +2942,7 @@ def assert_variables_initialized(var_list=None):
return array_ops.stack(ranks)
-@tf_export("report_uninitialized_variables")
+@tf_export(v1=["report_uninitialized_variables"])
@tf_should_use.should_use_result
def report_uninitialized_variables(var_list=None,
name="report_uninitialized_variables"):
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..8e88a84d60 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -31,8 +32,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +44,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
@@ -48,8 +53,17 @@ from tensorflow.python.util import nest
# handled in the CapturingGraph itself.
-def while_loop(cond, body, loop_vars, name=None):
+def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
"""Like tf.while_loop, except emits a single While op."""
+ flattened_loop_vars = nest.flatten(loop_vars)
+ if shape_invariants is not None:
+ nest.assert_same_structure(loop_vars, shape_invariants)
+ flattened_shapes = nest.flatten(shape_invariants)
+ else:
+ flattened_shapes = [t.shape for t in flattened_loop_vars]
+
+ del shape_invariants
+
if not name:
name = "while"
@@ -58,25 +72,33 @@ def while_loop(cond, body, loop_vars, name=None):
cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
- flattened_loop_vars = nest.flatten(loop_vars)
num_outputs = len(flattened_loop_vars)
# Add loop counter needed for computing gradients.
flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
] + flattened_loop_vars
+ flattened_shapes = [tensor_shape.scalar()] + flattened_shapes
+
# Build a `cond` wrapper that can handle the extra counter loop_var.
def wrapped_cond(unused_loop_counter, *loop_vars):
return cond(*loop_vars)
- cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ cond_graph = function.func_graph_from_py_func(
+ cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)
# Add external_captures of cond to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
# beginning of the loop execution.
flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+ flattened_shapes = flattened_shapes + [
+ t.shape for t in cond_graph.external_captures
+ ]
def wrapped_body(loop_counter, *args):
"""Loop body augmented with counter update.
@@ -101,8 +123,12 @@ def while_loop(cond, body, loop_vars, name=None):
# is_constant=True for inputs that are directly passed to outputs.
return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
- body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ body_graph = function.func_graph_from_py_func(
+ body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
# Add external captures of body to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
@@ -145,10 +171,17 @@ def while_loop(cond, body, loop_vars, name=None):
# Add this modified tensor list to the list of outputs.
body_graph.outputs.append(appended_tensor_list)
+ # Make sure that the shapes of the loop outputs are compatible with the
+ # shape invariants, or the shapes of the loop vars if the invariants are not
+ # specified.
+ _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
+ flattened_shapes[1:1 + num_outputs],
+ flattened_loop_vars[1:1 + num_outputs])
outputs = gen_functional_ops._while(
flattened_loop_vars,
cond_v2._create_new_tf_function(cond_graph),
cond_v2._create_new_tf_function(body_graph),
+ output_shapes=[t.shape for t in body_graph.outputs],
name=scope)
_copy_handle_data(body_graph.outputs, outputs)
@@ -212,6 +245,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
loop_vars,
cond_v2._create_new_tf_function(cond_grad_graph),
cond_v2._create_new_tf_function(body_grad_graph),
+ output_shapes=[t.shape for t in body_grad_graph.outputs],
name=_get_unique_name("%s_grad" % op.name))
_copy_handle_data(body_grad_graph.outputs, outputs)
@@ -232,8 +266,10 @@ def _get_body_graph(while_op):
Returns:
`FuncGraph` for the while body.
"""
- extra_inputs = list(while_op.inputs)
- input_shapes = [t.shape for t in extra_inputs]
+ # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
+ input_shapes = [
+ tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
+ ]
func_name = while_op.get_attr("body").name
fdef = while_op.graph._get_function(func_name).definition
func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
@@ -531,6 +567,17 @@ class _WhileBodyGradFuncGraph(function.FuncGraph):
return captured_tensor
+def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
+ for (t, shape, input_t) in zip(output_tensors, shape_invariants,
+ input_tensors):
+ if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
+ raise ValueError(
+ "Input tensor '%s' enters the loop with shape %s, but has "
+ "shape %s after one iteration. To allow the shape to vary across "
+ "iterations, use the `shape_invariants` argument of tf.while_loop to "
+ "specify a less-specific shape." % (input_t.name, shape, t.shape))
+
+
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
function._copy_handle_data(src_t, tgt_t)
diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py
index 5962d2f220..59e60856ae 100644
--- a/tensorflow/python/platform/tf_logging.py
+++ b/tensorflow/python/platform/tf_logging.py
@@ -25,6 +25,7 @@ import logging as _logging
import os as _os
import sys as _sys
import time as _time
+import traceback as _traceback
from logging import DEBUG
from logging import ERROR
from logging import FATAL
@@ -36,13 +37,49 @@ import six
from tensorflow.python.util.tf_export import tf_export
-
# Don't use this directly. Use _get_logger() instead.
_logger = None
_logger_lock = threading.Lock()
+def _get_caller(offset=3):
+ """Returns a code and frame object for the lowest non-logging stack frame."""
+ # Use sys._getframe(). This avoids creating a traceback object.
+ # pylint: disable=protected-access
+ f = _sys._getframe(offset)
+ # pylint: enable=protected-access
+ our_file = f.f_code.co_filename
+ f = f.f_back
+ while f:
+ code = f.f_code
+ if code.co_filename != our_file:
+ return code, f
+ f = f.f_back
+ return None, None
+
+
+# The definition of `findCaller` changed in Python 3.2
+if _sys.version_info.major >= 3 and _sys.version_info.minor >= 2:
+ def _logger_find_caller(stack_info=False): # pylint: disable=g-wrong-blank-lines
+ code, frame = _get_caller(4)
+ sinfo = None
+ if stack_info:
+ sinfo = '\n'.join(_traceback.format_stack())
+ if code:
+ return (code.co_filename, frame.f_lineno, code.co_name, sinfo)
+ else:
+ return '(unknown file)', 0, '(unknown function)', sinfo
+else:
+ def _logger_find_caller(): # pylint: disable=g-wrong-blank-lines
+ code, frame = _get_caller(4)
+ if code:
+ return (code.co_filename, frame.f_lineno, code.co_name)
+ else:
+ return '(unknown file)', 0, '(unknown function)'
+
+
def _get_logger():
+ """Return TF logger instance."""
global _logger
# Use double-checked locking to avoid taking lock unnecessarily.
@@ -58,6 +95,9 @@ def _get_logger():
# Scope the TensorFlow logger to not conflict with users' loggers.
logger = _logging.getLogger('tensorflow')
+ # Override findCaller on the logger to skip internal helper functions
+ logger.findCaller = _logger_find_caller
+
# Don't further configure the TensorFlow logger if the root logger is
# already configured. This prevents double logging in those cases.
if not _logging.getLogger().handlers:
@@ -216,18 +256,10 @@ def log_if(level, msg, condition, *args):
def _GetFileAndLine():
"""Returns (filename, linenumber) for the stack frame."""
- # Use sys._getframe(). This avoids creating a traceback object.
- # pylint: disable=protected-access
- f = _sys._getframe()
- # pylint: enable=protected-access
- our_file = f.f_code.co_filename
- f = f.f_back
- while f:
- code = f.f_code
- if code.co_filename != our_file:
- return (code.co_filename, f.f_lineno)
- f = f.f_back
- return ('<unknown>', 0)
+ code, f = _get_caller()
+ if not code:
+ return ('<unknown>', 0)
+ return (code.co_filename, f.f_lineno)
def google2_log_prefix(level, timestamp=None, file_and_line=None):
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index c411a58b70..61e0abbfcb 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -67,6 +67,7 @@ limitations under the License.
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
+%rename("%s") TFE_Py_EncodeArg;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 8e7f123a85..8bf057f69d 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -36,10 +36,13 @@ from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util.tf_export import tf_export
-@tf_export("saved_model.builder.SavedModelBuilder")
+@tf_export("saved_model.Builder",
+ "saved_model.builder.SavedModelBuilder")
+@deprecated_endpoints("saved_model.builder.SavedModelBuilder")
class SavedModelBuilder(object):
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
@@ -61,7 +64,7 @@ class SavedModelBuilder(object):
Typical usage for the `SavedModelBuilder`:
```python
...
- builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+ builder = tf.saved_model.Builder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index e8536108e8..895644a030 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -34,6 +34,7 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -144,7 +145,10 @@ def _get_main_op_tensor(
return main_op_tensor
-@tf_export("saved_model.loader.maybe_saved_model_directory")
+@tf_export("saved_model.maybe_saved_model_directory",
+ "saved_model.loader.maybe_saved_model_directory")
+@deprecation.deprecated_endpoints(
+ "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
"""Checks whether the provided export directory could contain a SavedModel.
@@ -165,7 +169,7 @@ def maybe_saved_model_directory(export_dir):
return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)
-@tf_export("saved_model.loader.load")
+@tf_export("saved_model.load", "saved_model.loader.load")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
diff --git a/tensorflow/python/saved_model/loader_test.py b/tensorflow/python/saved_model/loader_test.py
index b7e217a35b..924b2e7c06 100644
--- a/tensorflow/python/saved_model/loader_test.py
+++ b/tensorflow/python/saved_model/loader_test.py
@@ -47,8 +47,8 @@ class SavedModelLoaderTest(test.TestCase):
def setUp(self):
"""Write test SavedModels to a temp directory."""
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x")
- y = variables.Variable(11, name="y")
+ x = variables.VariableV1(5, name="x")
+ y = variables.VariableV1(11, name="y")
z = x + y
sess.run(variables.global_variables_initializer())
@@ -134,8 +134,8 @@ class SavedModelLoaderTest(test.TestCase):
def test_restore_variables(self):
loader = loader_impl.SavedModelLoader(SAVED_MODEL_WITH_MAIN_OP)
with self.session(graph=ops.Graph()) as sess:
- x = variables.Variable(0, name="x")
- y = variables.Variable(0, name="y")
+ x = variables.VariableV1(0, name="x")
+ y = variables.VariableV1(0, name="y")
z = x * y
sess.run(variables.global_variables_initializer())
@@ -186,8 +186,10 @@ class SavedModelLoaderTest(test.TestCase):
"""
path = _get_export_dir("no_variable_saved_model")
with session.Session(graph=ops.Graph()) as sess:
- x = variables.Variable(5, name="x", collections=["not_global_variable"])
- y = variables.Variable(11, name="y", collections=["not_global_variable"])
+ x = variables.VariableV1(
+ 5, name="x", collections=["not_global_variable"])
+ y = variables.VariableV1(
+ 11, name="y", collections=["not_global_variable"])
self.assertFalse(variables._all_saveable_objects())
z = x + y
sess.run(variables.variables_initializer([x, y]))
diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py
index 631ee63729..ad4511b28e 100644
--- a/tensorflow/python/saved_model/main_op_impl.py
+++ b/tensorflow/python/saved_model/main_op_impl.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -42,7 +43,9 @@ def main_op():
# TODO(sukritiramesh): Integrate with Saver for complete restore functionality.
-@tf_export('saved_model.main_op.main_op_with_restore')
+@tf_export('saved_model.main_op_with_restore',
+ 'saved_model.main_op.main_op_with_restore')
+@deprecation.deprecated_endpoints('saved_model.main_op.main_op_with_restore')
def main_op_with_restore(restore_op_name):
"""Returns a main op to init variables, tables and restore the graph.
diff --git a/tensorflow/python/saved_model/saved_model_test.py b/tensorflow/python/saved_model/saved_model_test.py
index 49d52d3bee..80b75b7ee6 100644
--- a/tensorflow/python/saved_model/saved_model_test.py
+++ b/tensorflow/python/saved_model/saved_model_test.py
@@ -60,7 +60,7 @@ class SavedModelTest(test.TestCase):
return os.path.join(test.get_temp_dir(), label)
def _init_and_validate_variable(self, sess, variable_name, variable_value):
- v = variables.Variable(variable_value, name=variable_name)
+ v = variables.VariableV1(variable_value, name=variable_name)
sess.run(variables.global_variables_initializer())
self.assertEqual(variable_value, v.eval())
@@ -458,7 +458,7 @@ class SavedModelTest(test.TestCase):
# Graph with a single variable added to a collection. SavedModel invoked to:
# - add with weights.
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(42, name="v")
+ v = variables.VariableV1(42, name="v")
ops.add_to_collection("foo_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(42, v.eval())
@@ -468,7 +468,7 @@ class SavedModelTest(test.TestCase):
# SavedModel invoked to:
# - simply add the model (weights are not updated).
with self.session(graph=ops.Graph()) as sess:
- v = variables.Variable(43, name="v")
+ v = variables.VariableV1(43, name="v")
ops.add_to_collection("bar_vars", v)
sess.run(variables.global_variables_initializer())
self.assertEqual(43, v.eval())
@@ -780,13 +780,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3")
+ v3 = variables.VariableV1(42, name="v3")
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the main_op.
@@ -815,13 +815,13 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
# Initialize another variable `v3` to 42.
- v3 = variables.Variable(42, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(42, name="v3", trainable=False, collections=[])
ops.add_to_collection("v", v3)
# Set up an assignment op to be run as part of the legacy_init_op.
@@ -860,11 +860,11 @@ class SavedModelTest(test.TestCase):
g = ops.Graph()
with self.session(graph=g) as sess:
# Initialize variable `v1` to 1.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
# Initialize another variable `v2` to 42.
- v2 = variables.Variable(42, name="v2", trainable=False, collections=[])
+ v2 = variables.VariableV1(42, name="v2", trainable=False, collections=[])
ops.add_to_collection("v", v2)
# Set up an assignment op to be run as part of the init op.
@@ -889,9 +889,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -918,9 +918,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -947,9 +947,9 @@ class SavedModelTest(test.TestCase):
with self.session(graph=ops.Graph()) as sess:
# Add `v1` and `v2` variables to the graph.
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
ops.add_to_collection("v", v1)
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
ops.add_to_collection("v", v2)
sess.run(variables.global_variables_initializer())
@@ -1071,13 +1071,13 @@ class SavedModelTest(test.TestCase):
graph=ops.Graph(),
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(1, name="v1")
+ v1 = variables.VariableV1(1, name="v1")
with sess.graph.device("/cpu:1"):
- v2 = variables.Variable(2, name="v2")
+ v2 = variables.VariableV1(2, name="v2")
# v3 is an unsaved variable derived from v1 and v2. It is used to
# exercise the ability to run an init op when restoring a graph.
- v3 = variables.Variable(1, name="v3", trainable=False, collections=[])
+ v3 = variables.VariableV1(1, name="v3", trainable=False, collections=[])
assign_v3 = state_ops.assign(v3, math_ops.add(v1, v2))
init_op = control_flow_ops.group(assign_v3, name="init_op")
@@ -1140,7 +1140,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
custom_saver = training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"], saver=custom_saver)
@@ -1162,7 +1162,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
training.Saver(name="my_saver")
builder.add_meta_graph_and_variables(sess, ["tag"])
@@ -1184,7 +1184,7 @@ class SavedModelTest(test.TestCase):
builder = saved_model_builder.SavedModelBuilder(export_dir)
with self.session(graph=ops.Graph()) as sess:
- variables.Variable(1, name="v1")
+ variables.VariableV1(1, name="v1")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["tag_0"])
@@ -1293,8 +1293,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with two float32 variables and a Complex Op composing them
# with strip_default_attrs enabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(
@@ -1303,8 +1303,8 @@ class SavedModelTest(test.TestCase):
# Add a graph with the same float32 variables and a Complex Op composing
# them with strip_default_attrs disabled.
with session.Session(graph=ops.Graph()) as sess:
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph(["bar"], strip_default_attrs=False)
@@ -1366,7 +1366,7 @@ class SavedModelTest(test.TestCase):
# Add a graph with a single variable and a test op with a defaultless
# float32 attr, "test_attr".
with session.Session(graph=ops.Graph()) as sess:
- variables.Variable(1.0, dtype=dtypes.float64, name="var")
+ variables.VariableV1(1.0, dtype=dtypes.float64, name="var")
test_ops.test_attr(T=dtypes.float32, name="test_attr")
sess.run(variables.global_variables_initializer())
builder.add_meta_graph_and_variables(sess, ["foo"])
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 37f927f381..a1034416e9 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -24,10 +24,14 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export('saved_model.signature_def_utils.build_signature_def')
+@tf_export('saved_model.build_signature_def',
+ 'saved_model.signature_def_utils.build_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
"""Utility function to build a SignatureDef protocol buffer.
@@ -53,7 +57,10 @@ def build_signature_def(inputs=None, outputs=None, method_name=None):
return signature_def
-@tf_export('saved_model.signature_def_utils.regression_signature_def')
+@tf_export('saved_model.regression_signature_def',
+ 'saved_model.signature_def_utils.regression_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
"""Creates regression signature from given examples and predictions.
@@ -95,7 +102,10 @@ def regression_signature_def(examples, predictions):
return signature_def
-@tf_export('saved_model.signature_def_utils.classification_signature_def')
+@tf_export('saved_model.classification_signature_def',
+ 'saved_model.signature_def_utils.classification_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
"""Creates classification signature from given examples and predictions.
@@ -148,7 +158,10 @@ def classification_signature_def(examples, classes, scores):
return signature_def
-@tf_export('saved_model.signature_def_utils.predict_signature_def')
+@tf_export('saved_model.predict_signature_def',
+ 'saved_model.signature_def_utils.predict_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
"""Creates prediction signature from given inputs and outputs.
@@ -239,7 +252,10 @@ def _supervised_signature_def(
return signature_def
-@tf_export('saved_model.signature_def_utils.is_valid_signature')
+@tf_export('saved_model.is_valid_signature',
+ 'saved_model.signature_def_utils.is_valid_signature')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
"""Determine whether a SignatureDef can be served by TensorFlow Serving."""
if signature_def is None:
@@ -313,4 +329,3 @@ def _is_valid_classification_signature(signature_def):
return False
return True
-
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 06d09325c8..0bba7b6fac 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -27,13 +27,16 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# TensorInfo helpers.
-@tf_export("saved_model.utils.build_tensor_info")
+@tf_export("saved_model.build_tensor_info",
+ "saved_model.utils.build_tensor_info")
+@deprecation.deprecated_endpoints("saved_model.utils.build_tensor_info")
def build_tensor_info(tensor):
"""Utility function to build TensorInfo proto.
@@ -57,7 +60,10 @@ def build_tensor_info(tensor):
return tensor_info
-@tf_export("saved_model.utils.get_tensor_from_tensor_info")
+@tf_export("saved_model.get_tensor_from_tensor_info",
+ "saved_model.utils.get_tensor_from_tensor_info")
+@deprecation.deprecated_endpoints(
+ "saved_model.utils.get_tensor_from_tensor_info")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
"""Returns the Tensor or SparseTensor described by a TensorInfo proto.
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 75824d83e6..384c7a82d2 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -8,6 +8,7 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "py_test")
+load("//tensorflow:tensorflow.bzl", "py_binary")
# Transitive dependencies of this target will be included in the pip package.
py_library(
@@ -21,6 +22,13 @@ py_library(
":saved_model_cli",
":saved_model_utils",
":strip_unused",
+ # The following py_library are needed because
+ # py_binary may not depend on them when --define=no_tensorflow_py_deps=true
+ # is specified. See https://github.com/tensorflow/tensorflow/issues/22390
+ ":freeze_graph_lib",
+ ":optimize_for_inference_lib",
+ ":selective_registration_header_lib",
+ ":strip_unused_lib",
],
)
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 92446e2f8f..533a138a39 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
@@ -69,6 +70,7 @@ TENSORFLOW_API_INIT_FILES = [
"profiler/__init__.py",
"python_io/__init__.py",
"quantization/__init__.py",
+ "random/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index bc2f3516d1..0747424eab 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
@@ -69,6 +70,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"profiler/__init__.py",
"python_io/__init__.py",
"quantization/__init__.py",
+ "random/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
diff --git a/tensorflow/python/tools/freeze_graph_test.py b/tensorflow/python/tools/freeze_graph_test.py
index e38945fabc..5dc14a6961 100644
--- a/tensorflow/python/tools/freeze_graph_test.py
+++ b/tensorflow/python/tools/freeze_graph_test.py
@@ -60,7 +60,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
# We'll create an input graph that has a single variable containing 1.0,
# and that then multiplies it by 2.
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
@@ -138,7 +138,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
features = parsing_ops.parse_example(examples, feature_configs)
feature = features[feature_name]
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
scores = math_ops.multiply(variable_node, feature, name="output_node")
class_feature = array_ops.fill(array_ops.shape(feature),
"class_%s" % feature_name)
@@ -174,7 +174,7 @@ class FreezeGraphTest(test_util.TensorFlowTestCase):
output_graph_filename = os.path.join(tmp_dir, "output_graph.pb")
with ops.Graph().as_default():
- variable_node = variables.Variable(1.0, name="variable_node")
+ variable_node = variables.VariableV1(1.0, name="variable_node")
output_node = math_ops.multiply(variable_node, 2.0, name="output_node")
sess = session.Session()
init = variables.global_variables_initializer()
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 3dbccd1409..2fcb0fa029 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -267,7 +267,8 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, worker=None, tf_debug=False):
+ overwrite_flag, worker=None, init_tpu=False,
+ tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -287,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
the same name exists.
worker: If provided, the session will be run on the worker. Valid worker
specification is a bns or gRPC path.
+ init_tpu: If true, the TPU system will be initialized after the session
+ is created.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -328,6 +331,12 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
]
with session.Session(worker, graph=ops_lib.Graph()) as sess:
+ if init_tpu:
+ print('Initializing TPU System ...')
+ # This is needed for freshly started worker, or if the job
+ # restarts after a preemption.
+ sess.run(tf.contrib.tpu.initialize_system())
+
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +641,7 @@ def run(args):
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
args.overwrite, worker=args.worker,
- tf_debug=args.tf_debug)
+ init_tpu=args.init_tpu, tf_debug=args.tf_debug)
def scan(args):
@@ -775,6 +784,12 @@ def create_parser():
default=None,
help='if specified, a Session will be run on the worker. '
'Valid worker specification is a bns or gRPC path.')
+ parser_run.add_argument(
+ '--init_tpu',
+ action='store_true',
+ default=None,
+ help='if specified, tpu.initialize_system will be called on the Session. '
+ 'This option should be only used if the worker is a TPU job.')
parser_run.set_defaults(func=run)
# scan command
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3bd4bd75bd..1efabcd854 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -344,7 +344,7 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
raise ValueError("steps_per_run should be greater than 0")
self._num_steps = num_steps
self._last_step = last_step
- self._steps_per_run = steps_per_run
+ self._steps_per_run_initial_value = steps_per_run
def begin(self):
self._global_step_tensor = training_util.get_global_step()
@@ -353,7 +353,8 @@ class _MultiStepStopAtStepHook(session_run_hook.SessionRunHook):
self._steps_per_run_variable = get_or_create_steps_per_run_variable()
def _update_steps_per_run_variable(self, global_step, session):
- steps = min(self._last_step - global_step, self._steps_per_run)
+ steps = min(self._last_step - global_step,
+ self._steps_per_run_initial_value)
self._steps_per_run_variable.load(steps, session=session)
def after_create_session(self, session, coord):
diff --git a/tensorflow/python/training/checkpointable/util.py b/tensorflow/python/training/checkpointable/util.py
index 56c4043d9d..edab6cc6eb 100644
--- a/tensorflow/python/training/checkpointable/util.py
+++ b/tensorflow/python/training/checkpointable/util.py
@@ -247,7 +247,7 @@ def _default_getter(name, shape, dtype, initializer=None,
def initial_value():
return initializer(
shape_object.as_list(), dtype=dtype, partition_info=partition_info)
- return variables.Variable(
+ return variables.VariableV1(
initial_value=initial_value,
name=name,
dtype=variable_dtype,
@@ -854,6 +854,11 @@ class _LoadStatus(object):
pass
@abc.abstractmethod
+ def assert_nontrivial_match(self):
+ """Raises an exception if only the root object matched."""
+ pass
+
+ @abc.abstractmethod
def run_restore_ops(self, session=None):
"""Runs restore ops from the checkpoint. Requires a valid checkpoint."""
pass
@@ -975,6 +980,26 @@ class CheckpointLoadStatus(_LoadStatus):
% (list(unused_python_objects),))
return self
+ def assert_nontrivial_match(self):
+ """Raises an exception if only the root object matched."""
+ for checkpointable_object in list_objects(self._root_checkpointable):
+ self._checkpoint.all_python_objects.add(checkpointable_object)
+ if len(self._checkpoint.object_by_proto_id) <= 1:
+ unused_python_objects = (
+ _ObjectIdentitySet(self._checkpoint.all_python_objects)
+ - _ObjectIdentitySet(self._checkpoint.object_by_proto_id.values()))
+ if unused_python_objects:
+ raise AssertionError(
+ ("Nothing except the root object matched a checkpointed value. "
+ "Typically this means that the checkpoint does not match the "
+ "Python program. The following objects have no matching "
+ "checkpointed value: %s") % (list(unused_python_objects),))
+ else:
+ raise AssertionError(
+ "Nothing to load. No dependencies have been added to %s yet." % (
+ self._root_checkpointable,))
+ return self
+
def run_restore_ops(self, session=None):
"""Run operations to restore objects in the dependency graph."""
if context.executing_eagerly():
@@ -1039,6 +1064,11 @@ class InitializationOnlyStatus(_LoadStatus):
raise AssertionError(
"No checkpoint specified (save_path=None); nothing is being restored.")
+ def assert_nontrivial_match(self):
+ """Assertion for consistency with `CheckpointLoadStatus`. Always fails."""
+ raise AssertionError(
+ "No checkpoint specified (save_path=None); nothing is being restored.")
+
def run_restore_ops(self, session=None):
"""For consistency with `CheckpointLoadStatus`.
@@ -1122,6 +1152,14 @@ class NameBasedSaverStatus(_LoadStatus):
# useful since we don't touch Python objects or Python state).
return self.assert_consumed()
+ def assert_nontrivial_match(self):
+ """Raises an exception if currently created objects are unmatched."""
+ # For name-based checkpoints there's no object information in the
+ # checkpoint, so there's no distinction between
+ # assert_nontrivial_match and assert_consumed (and both are less
+ # useful since we don't touch Python objects or Python state).
+ return self.assert_consumed()
+
def _gather_saveable_objects(self):
"""Walk the object graph, using global names for SaveableObjects."""
objects = list_objects(self._root_checkpointable)
@@ -1779,13 +1817,15 @@ class Checkpoint(tracking.Checkpointable):
status of a checkpoint restoration and run initialization/restore ops.
The returned status object has the following methods:
- - `assert_consumed()`:
+
+ * `assert_consumed()`:
Raises an exception if any variables/objects are unmatched: either
checkpointed values which don't have a matching Python object or
Python objects in the dependency graph with no values in the
checkpoint. This method returns the status object, and so may be
chained with `initialize_or_restore` or `run_restore_ops`.
- - `assert_existing_objects_matched()`:
+
+ * `assert_existing_objects_matched()`:
Raises an exception if any existing Python objects in the dependency
graph are unmatched. Unlike `assert_consumed`, this assertion will
pass if values in the checkpoint have no corresponding Python
@@ -1796,12 +1836,20 @@ class Checkpoint(tracking.Checkpointable):
a `tf.train.Optimizer` was saved but only the state required for
inference is being loaded. This method returns the status object, and
so may be chained with `initialize_or_restore` or `run_restore_ops`.
- - `initialize_or_restore(session=None)`:
+
+ * `assert_nontrivial_match()`: Asserts that something aside from the root
+ object was matched. This is a very weak assertion, but is useful for
+ sanity checking in library code where objects may exist in the
+ checkpoint which haven't been created in Python and some Python
+ objects may not have a checkpointed value.
+
+ * `initialize_or_restore(session=None)`:
When graph building, runs variable initializers if `save_path` is
`None`, but otherwise runs restore operations. If no `session` is
explicitly specified, the default session is used. No effect when
executing eagerly (variables are initialized or restored eagerly).
- - `run_restore_ops(session=None)`:
+
+ * `run_restore_ops(session=None)`:
When graph building, runs restore operations. If no `session` is
explicitly specified, the default session is used. No effect when
executing eagerly (restore operations are run eagerly). May only be
diff --git a/tensorflow/python/training/checkpointable/util_test.py b/tensorflow/python/training/checkpointable/util_test.py
index f8b5bd8501..14b47a1940 100644
--- a/tensorflow/python/training/checkpointable/util_test.py
+++ b/tensorflow/python/training/checkpointable/util_test.py
@@ -437,6 +437,7 @@ class CheckpointingTests(test.TestCase):
optimizer=on_create_optimizer, model=on_create_model)
# Deferred restoration
status = on_create_root.restore(save_path=save_path)
+ status.assert_nontrivial_match()
status.assert_existing_objects_matched()
with self.assertRaises(AssertionError):
status.assert_consumed()
@@ -1509,6 +1510,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "OBJECT_CONFIG_JSON"):
+ status.assert_nontrivial_match()
else:
# When graph building, we haven't read any keys, so we don't know
# whether the restore will be complete.
@@ -1516,6 +1519,8 @@ class CheckpointCompatibilityTests(test.TestCase):
status.assert_consumed()
with self.assertRaisesRegexp(AssertionError, "not restored"):
status.assert_existing_objects_matched()
+ with self.assertRaisesRegexp(AssertionError, "not restored"):
+ status.assert_nontrivial_match()
status.run_restore_ops()
self._check_sentinels(root)
self._set_sentinels(root)
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..144b167170 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -437,6 +436,9 @@ class DistributionStrategy(object):
def __init__(self):
self._default_device = None
+ # This property is used to determine if we should set drop_remainder=True
+ # when creating Datasets from numpy array inputs.
+ self._require_static_shapes = False
def scope(self):
"""Returns a context manager selecting this DistributionStrategy as current.
@@ -631,7 +633,7 @@ class DistributionStrategy(object):
Args:
fn: function to run using this distribution strategy. The function must
- have the following signature: def fn(context, *inputs).
+ have the following signature: `def fn(context, *inputs)`.
`context` is an instance of `MultiStepContext` that will be passed when
`fn` is run. `context` can be used to specify the outputs to be returned
from `fn` by calling `context.set_last_step_output`. It can also be used
@@ -797,9 +799,9 @@ class DistributionStrategy(object):
return merged(results)
```
- Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.'
+ Otherwise this returns `fn(var, *args, **kwargs)` colocated with `var`.
- Neither *args nor **kwargs may contain per-device values.
+ Neither `*args` nor `**kwargs` may contain per-device values.
If they contain mirrored values, they will be unwrapped before
calling `fn`.
@@ -807,15 +809,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +834,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -890,6 +902,10 @@ class DistributionStrategy(object):
raise NotImplementedError("must be implemented in descendants")
@property
+ def require_static_shapes(self):
+ return self._require_static_shapes
+
+ @property
def num_towers(self):
"""Returns number of towers, for purposes of averaging across towers."""
raise NotImplementedError("must be implemented in descendants")
@@ -1134,17 +1150,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1214,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 998b5c35ce..ce580a406f 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -89,6 +89,7 @@ def get_tower_context():
"""Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context (this function
will return the default TowerContext object);
2. switches to cross-tower context (in which case this will return
@@ -121,6 +122,7 @@ def get_cross_tower_context():
"""Returns the current DistributionStrategy if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context;
2. switches to cross-tower context when entering a
`with DistributionStrategy.scope():` block;
diff --git a/tensorflow/python/training/evaluation.py b/tensorflow/python/training/evaluation.py
index b36444a14c..2c4eb02d53 100644
--- a/tensorflow/python/training/evaluation.py
+++ b/tensorflow/python/training/evaluation.py
@@ -18,13 +18,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import time
import math
+import time
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
@@ -77,6 +78,59 @@ def _get_latest_eval_step_value(update_ops):
return array_ops.identity(_get_or_create_eval_step().read_value())
+class _MultiStepStopAfterNEvalsHook(session_run_hook.SessionRunHook):
+ """Run hook used by the evaluation routines to run the `eval_ops` N times."""
+
+ def __init__(self, num_evals, steps_per_run=1):
+ """Constructs the run hook.
+
+ Args:
+ num_evals: The number of evaluations to run for. if set to None, will
+ iterate the dataset until all inputs are exhausted.
+ steps_per_run: Number of steps executed per run call.
+ """
+ self._num_evals = num_evals
+ self._evals_completed = None
+ self._steps_per_run_initial_value = steps_per_run
+
+ def _set_evals_completed_tensor(self, updated_eval_step):
+ self._evals_completed = updated_eval_step
+
+ def begin(self):
+ self._steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+
+ def after_create_session(self, session, coord):
+ # Update number of steps to run in the first run call
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._steps_per_run_initial_value, self._num_evals)
+ self._steps_per_run_variable.load(steps, session=session)
+
+ def before_run(self, run_context):
+ return session_run_hook.SessionRunArgs({
+ 'evals_completed': self._evals_completed
+ })
+
+ def after_run(self, run_context, run_values):
+ evals_completed = run_values.results['evals_completed']
+ # Update number of steps to run in the next iteration
+ if self._num_evals is None:
+ steps = self._steps_per_run_initial_value
+ else:
+ steps = min(self._num_evals - evals_completed,
+ self._steps_per_run_initial_value)
+ self._steps_per_run_variable.load(steps, session=run_context.session)
+
+ if self._num_evals is None:
+ logging.info('Evaluation [%d]', evals_completed)
+ else:
+ logging.info('Evaluation [%d/%d]', evals_completed, self._num_evals)
+ if self._num_evals is not None and evals_completed >= self._num_evals:
+ run_context.request_stop()
+
+
class _StopAfterNEvalsHook(session_run_hook.SessionRunHook):
"""Run hook used by the evaluation routines to run the `eval_ops` N times."""
@@ -176,7 +230,15 @@ def _evaluate_once(checkpoint_path,
hooks = list(hooks or [])
if eval_ops is not None:
- update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
+ if any([isinstance(h, _MultiStepStopAfterNEvalsHook) for h in hooks]):
+ steps_per_run_variable = \
+ basic_session_run_hooks.get_or_create_steps_per_run_variable()
+ update_eval_step = state_ops.assign_add(
+ eval_step,
+ math_ops.cast(steps_per_run_variable, dtype=eval_step.dtype),
+ use_locking=True)
+ else:
+ update_eval_step = state_ops.assign_add(eval_step, 1, use_locking=True)
if isinstance(eval_ops, dict):
eval_ops['update_eval_step'] = update_eval_step
@@ -188,7 +250,7 @@ def _evaluate_once(checkpoint_path,
eval_step_value = _get_latest_eval_step_value(eval_ops)
for h in hooks:
- if isinstance(h, _StopAfterNEvalsHook):
+ if isinstance(h, (_StopAfterNEvalsHook, _MultiStepStopAfterNEvalsHook)):
h._set_evals_completed_tensor(eval_step_value) # pylint: disable=protected-access
logging.info('Starting evaluation at ' + time.strftime('%Y-%m-%d-%H:%M:%S',
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 9d9db70890..eb131ac9f7 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -56,7 +56,8 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
# pylint: enable=protected-access
-@tf_export("train.match_filenames_once")
+@tf_export("io.match_filenames_once", "train.match_filenames_once")
+@deprecation.deprecated_endpoints("train.match_filenames_once")
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
diff --git a/tensorflow/python/training/learning_rate_decay_test.py b/tensorflow/python/training/learning_rate_decay_test.py
index 5a9215730e..03a32f6ca0 100644
--- a/tensorflow/python/training/learning_rate_decay_test.py
+++ b/tensorflow/python/training/learning_rate_decay_test.py
@@ -63,7 +63,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
def testVariables(self):
with self.cached_session():
- step = variables.Variable(1)
+ step = variables.VariableV1(1)
assign_1 = step.assign(1)
assign_2 = step.assign(2)
assign_100 = step.assign(100)
@@ -121,7 +121,7 @@ class LRDecayTest(test_util.TensorFlowTestCase):
# Test that ref types are valid.
if not context.executing_eagerly():
- x = variables.Variable(0.0)
+ x = variables.VariableV1(0.0)
x_ref = x.op.outputs[0] # float32_ref tensor should be accepted
boundaries, values = [1.0, 2.0], [1, 2, 3]
learning_rate_decay.piecewise_constant(x_ref, boundaries, values)
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index 82f0e3be52..a479f38165 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -195,8 +195,12 @@ class Scaffold(object):
default_ready_op)
if self._ready_for_local_init_op is None:
def default_ready_for_local_init_op():
- return variables.report_uninitialized_variables(
- variables.global_variables())
+ return array_ops.concat([
+ variables.report_uninitialized_variables(
+ variables.global_variables()),
+ resources.report_uninitialized_resources(
+ resources.shared_resources())
+ ], 0)
self._ready_for_local_init_op = Scaffold.get_or_default(
'ready_for_local_init_op', ops.GraphKeys.READY_FOR_LOCAL_INIT_OP,
default_ready_for_local_init_op)
diff --git a/tensorflow/python/training/monitored_session_test.py b/tensorflow/python/training/monitored_session_test.py
index 2d7799d66a..c870d99de9 100644
--- a/tensorflow/python/training/monitored_session_test.py
+++ b/tensorflow/python/training/monitored_session_test.py
@@ -69,8 +69,8 @@ class ScaffoldTest(test.TestCase):
def test_defaults_empty_graph(self):
with ops.Graph().as_default():
scaffold = monitored_session.Scaffold()
- variables.Variable(1, name='my_var')
- variables.Variable(
+ variables.VariableV1(1, name='my_var')
+ variables.VariableV1(
2, name='my_local_var', collections=[ops.GraphKeys.LOCAL_VARIABLES])
scaffold.finalize()
self.assertTrue(isinstance(scaffold.init_op, ops.Operation))
@@ -105,7 +105,7 @@ class ScaffoldTest(test.TestCase):
def test_caches_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
scaffold1 = monitored_session.Scaffold()
scaffold1.finalize()
scaffold2 = monitored_session.Scaffold()
@@ -119,7 +119,7 @@ class ScaffoldTest(test.TestCase):
def test_raise_error_if_more_than_one_cached_item(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
ops.add_to_collection(ops.GraphKeys.SAVERS, saver_lib.Saver())
with self.assertRaisesRegexp(RuntimeError, 'More than one item'):
@@ -127,7 +127,7 @@ class ScaffoldTest(test.TestCase):
def test_uses_passed_values(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold = monitored_session.Scaffold(
init_op=2,
@@ -148,7 +148,7 @@ class ScaffoldTest(test.TestCase):
def test_graph_is_finalized(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
monitored_session.Scaffold().finalize()
with self.assertRaisesRegexp(RuntimeError,
'Graph is finalized and cannot be modified'):
@@ -157,7 +157,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_default_scaffold(self):
scaffold1 = monitored_session.Scaffold()
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold2 = monitored_session.Scaffold(
init_op=2,
@@ -180,7 +180,7 @@ class ScaffoldTest(test.TestCase):
def test_new_scaffold_from_existing_scaffold(self):
with ops.Graph().as_default():
- variables.Variable([1])
+ variables.VariableV1([1])
saver = saver_lib.Saver()
scaffold1 = monitored_session.Scaffold(
init_op=2,
@@ -1374,7 +1374,7 @@ class MonitoredSessionTest(test.TestCase):
def test_defaults(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -1700,7 +1700,7 @@ class MonitoredSessionTest(test.TestCase):
def test_graph_finalized_during_run_unfinalized_after_exit(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
self.assertTrue(g.finalized)
@@ -1708,7 +1708,7 @@ class MonitoredSessionTest(test.TestCase):
def test_keep_finalized_graph_as_finalized(self):
with ops.Graph().as_default() as g:
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
monitored_session.Scaffold().finalize()
with monitored_session.MonitoredSession() as session:
self.assertEqual(0, session.run(a_var))
@@ -2032,7 +2032,7 @@ class MonitoredSessionTest(test.TestCase):
with ops.Graph().as_default():
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
def step_fn(step_context):
@@ -2088,7 +2088,7 @@ class MonitoredSessionTest(test.TestCase):
c = array_ops.placeholder(dtypes.float32)
v = array_ops.identity(c)
vv = constant_op.constant(3.2)
- graph_state = variables.Variable(0.0)
+ graph_state = variables.VariableV1(0.0)
graph_side_effect = state_ops.assign_add(graph_state, 0.31)
class Hook(session_run_hook.SessionRunHook):
@@ -2125,7 +2125,7 @@ class SingularMonitoredSessionTest(test.TestCase):
def test_handles_initialization(self):
with ops.Graph().as_default():
- a_var = variables.Variable(0)
+ a_var = variables.VariableV1(0)
with monitored_session.SingularMonitoredSession() as session:
# If it's not initialized, following statement raises an error.
self.assertEqual(0, session.run(a_var))
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 177a7ddfa5..89bfcaf4ad 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -25,6 +25,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.training import slot_creator
from tensorflow.python.util.tf_export import tf_export
@@ -36,9 +37,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The moving average of 'variable' updated with 'value' is:
variable * decay + value * (1 - decay)
- The returned Operation sets 'variable' to the newly computed moving average.
-
- The new value of 'variable' can be set with the 'AssignSub' op as:
+ The returned Operation sets 'variable' to the newly computed moving average,
+ by performing this subtraction:
variable -= (1 - decay) * (variable - value)
Since variables that are initialized to a `0` value will be `0` biased,
@@ -50,7 +50,7 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
The names of the debias shadow variables, by default, include both the scope
they were created in and the scope of the variables they debias. They are also
- given a uniqifying-suffix.
+ given a uniquifying-suffix.
E.g.:
@@ -58,8 +58,8 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
with tf.variable_scope('scope1'):
with tf.variable_scope('scope2'):
var = tf.get_variable('foo')
- tf.assign_moving_average(var, 0.0, 1.0)
- tf.assign_moving_average(var, 0.0, 0.9)
+ update_1 = tf.assign_moving_average(var, 0.0, 1.0)
+ update_2 = tf.assign_moving_average(var, 0.0, 0.9)
# var.name: 'scope1/scope2/foo'
# shadow var names: 'scope1/scope2/scope1/scope2/foo/biased'
@@ -76,20 +76,33 @@ def assign_moving_average(variable, value, decay, zero_debias=True, name=None):
name: Optional name of the returned operation.
Returns:
- A reference to the input 'variable' tensor with the newly computed
- moving average.
+ A tensor which if evaluated will compute and return the new moving average.
"""
+ def update_fn(v, value, decay=decay):
+ decay = ops.convert_to_tensor(1.0 - decay, name="decay")
+ if decay.dtype != v.dtype.base_dtype:
+ decay = math_ops.cast(decay, v.dtype.base_dtype)
+ if zero_debias:
+ update_delta = _zero_debias(v, value, decay)
+ else:
+ update_delta = (v - value) * decay
+ return state_ops.assign_sub(v, update_delta, name=scope)
+
with ops.name_scope(name, "AssignMovingAvg",
[variable, value, decay]) as scope:
- with ops.colocate_with(variable):
- decay = ops.convert_to_tensor(1.0 - decay, name="decay")
- if decay.dtype != variable.dtype.base_dtype:
- decay = math_ops.cast(decay, variable.dtype.base_dtype)
- if zero_debias:
- update_delta = _zero_debias(variable, value, decay)
- else:
- update_delta = (variable - value) * decay
- return state_ops.assign_sub(variable, update_delta, name=scope)
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ # In a tower context, we update variable using the mean of value across
+ # towers.
+ def merge_fn(strategy, v, value):
+ value = strategy.reduce(
+ variable_scope.VariableAggregation.MEAN, value, v)
+ return strategy.update(v, update_fn, value)
+
+ return tower_context.merge_call(merge_fn, variable, value)
+ else:
+ strategy = distribution_strategy_context.get_cross_tower_context()
+ return strategy.update(variable, update_fn, value)
def weighted_moving_average(value,
@@ -372,23 +385,22 @@ class ExponentialMovingAverage(object):
Args:
var_list: A list of Variable or Tensor objects. The variables
- and Tensors must be of types float16, float32, or float64.
+ and Tensors must be of types bfloat16, float16, float32, or float64.
Returns:
An Operation that updates the moving averages.
Raises:
- TypeError: If the arguments are not all float16, float32, or float64.
- ValueError: If the moving average of one of the variables is already
- being computed.
+ TypeError: If the arguments are not an allowed type.
"""
# TODO(touts): op_scope
if var_list is None:
var_list = variables.trainable_variables()
zero_debias_true = set() # set of vars to set `zero_debias=True`
for var in var_list:
- if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
- dtypes.float64]:
+ if var.dtype.base_dtype not in [
+ dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
+ ]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 93991d0e14..bb2fca66e3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -110,6 +111,32 @@ class MovingAveragesTest(test.TestCase):
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
self.assertAllClose(numerator_2 / denominator_2, wma_array)
+ def testWeightedMovingAverageBfloat16(self):
+ bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+ with self.cached_session() as sess:
+ decay = 0.5
+ weight = array_ops.placeholder(dtypes.bfloat16, [])
+ val = array_ops.placeholder(dtypes.bfloat16, [])
+
+ wma = moving_averages.weighted_moving_average(val, decay, weight)
+ variables.global_variables_initializer().run()
+
+ # Get the first weighted moving average.
+ val_1 = 3.0
+ weight_1 = 4.0
+ wma_array = sess.run(wma, feed_dict={val: val_1, weight: weight_1})
+ numerator_1 = val_1 * weight_1 * (1.0 - decay)
+ denominator_1 = weight_1 * (1.0 - decay)
+ self.assertAllClose(numerator_1 / denominator_1, wma_array)
+
+ # Get the second weighted moving average.
+ val_2 = 11.0
+ weight_2 = 22.0
+ wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
+ numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
+ denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
+ self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
+
def _Repeat(value, dim):
if dim == 1:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index 699162b30c..47034919e1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
if var_list is None:
var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
+ # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+ # to be executed.
+ with ops.control_dependencies([loss_value]):
+ grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
@@ -585,7 +588,7 @@ class Optimizer(
var_list = [v for g, v, _ in converted_grads_and_vars if g is not None]
if not var_list:
raise ValueError("No gradients provided for any variable: %s." %
- ([str(v) for _, _, v in converted_grads_and_vars],))
+ ([str(v) for _, v, _ in converted_grads_and_vars],))
with ops.init_scope():
self._create_slots(var_list)
update_ops = []
@@ -689,7 +692,7 @@ class Optimizer(
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.unwrap(distribution.update(var, update, grad))
+ for op in distribution.update(var, update, grad, grouped=False)
]
def finish(self, update_ops):
@@ -697,13 +700,13 @@ class Optimizer(
non_slot_devices = distribution.non_slot_devices(var_list)
finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops)
+ non_slot_devices, finish, self, update_ops, grouped=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
- apply_updates = distribution.group(distribution.update(
- global_step, state_ops.assign_add, 1, name=name))
+ with ops.control_dependencies(finish_updates):
+ apply_updates = distribution.update(
+ global_step, state_ops.assign_add, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):
diff --git a/tensorflow/python/training/quantize_training_test.py b/tensorflow/python/training/quantize_training_test.py
index 9754adea85..6edbf7665f 100644
--- a/tensorflow/python/training/quantize_training_test.py
+++ b/tensorflow/python/training/quantize_training_test.py
@@ -58,7 +58,8 @@ class PywrapQuantizeTrainingTest(test.TestCase):
g = ops.Graph()
with session.Session(graph=g) as sess:
a = constant_op.constant(6.0, shape=[1, 1], name='a')
- b = variables.Variable(constant_op.constant(7.0, shape=[1, 1]), name='b')
+ b = variables.VariableV1(
+ constant_op.constant(7.0, shape=[1, 1]), name='b')
c = math_ops.matmul(a, b, name='matmul')
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/queue_runner_test.py b/tensorflow/python/training/queue_runner_test.py
index 9b9e28af2b..15fe42bbd8 100644
--- a/tensorflow/python/training/queue_runner_test.py
+++ b/tensorflow/python/training/queue_runner_test.py
@@ -44,7 +44,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -64,9 +64,9 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var0 = variables.Variable(zero64)
+ var0 = variables.VariableV1(zero64)
count_up_to_3 = var0.count_up_to(3)
- var1 = variables.Variable(zero64)
+ var1 = variables.VariableV1(zero64)
count_up_to_30 = var1.count_up_to(30)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
qr = queue_runner_impl.QueueRunner(queue, [count_up_to_3, count_up_to_30])
@@ -131,7 +131,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -184,7 +184,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
with session.Session() as other_sess:
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -199,7 +199,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -215,7 +215,7 @@ class QueueRunnerTest(test.TestCase):
with self.cached_session() as sess:
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
variables.global_variables_initializer().run()
@@ -250,7 +250,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunners(self):
# CountUpTo will raise OUT_OF_RANGE when it reaches the count.
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -267,7 +267,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersRaisesIfNotASession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -280,7 +280,7 @@ class QueueRunnerTest(test.TestCase):
def testStartQueueRunnersIgnoresMonitoredSession(self):
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
@@ -297,7 +297,7 @@ class QueueRunnerTest(test.TestCase):
graph = ops.Graph()
with graph.as_default():
zero64 = constant_op.constant(0, dtype=dtypes.int64)
- var = variables.Variable(zero64)
+ var = variables.VariableV1(zero64)
count_up_to = var.count_up_to(3)
queue = data_flow_ops.FIFOQueue(10, dtypes.float32)
init_op = variables.global_variables_initializer()
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index 69b1055ebe..49e6e6546d 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -311,8 +311,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver(
@@ -350,8 +350,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the parameter nodes
# have not been initialized either.
with self.cached_session() as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver({"v0": v0, "v1": v1, "v2": v2.saveable})
@@ -370,7 +370,7 @@ class SaverTest(test.TestCase):
self.assertEqual(30.0, v2.values().eval())
def testFilenameTensor(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
filename = b"somerandomfilename"
save = saver_module.Saver({"v0": v0}, filename=filename)
with self.cached_session() as sess:
@@ -379,7 +379,7 @@ class SaverTest(test.TestCase):
self.assertEqual(sess.run(tensor), filename)
def testInvalidPath(self):
- v0 = variables.Variable(0, name="v0")
+ v0 = variables.VariableV1(0, name="v0")
for ver in (saver_pb2.SaverDef.V1, saver_pb2.SaverDef.V2):
with self.cached_session() as sess:
save = saver_module.Saver({"v0": v0}, write_version=ver)
@@ -392,7 +392,7 @@ class SaverTest(test.TestCase):
with self.cached_session() as sess:
# Build a graph with 1 node, and save and restore for them.
- v = variables.Variable(np.int64(15), name="v")
+ v = variables.VariableV1(np.int64(15), name="v")
save = saver_module.Saver({"v": v}, restore_sequentially=True)
variables.global_variables_initializer().run()
@@ -402,7 +402,7 @@ class SaverTest(test.TestCase):
self.assertEqual(save_path, val)
with self.cached_session() as sess:
- v = variables.Variable(np.int64(-1), name="v")
+ v = variables.VariableV1(np.int64(-1), name="v")
save = saver_module.Saver({"v": v})
with self.assertRaisesWithPredicateMatch(
@@ -416,9 +416,9 @@ class SaverTest(test.TestCase):
def testSomeErrors(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v0 = variables.VariableV1([10.0], name="v0")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -446,7 +446,7 @@ class SaverTest(test.TestCase):
def testSameName(self):
with ops_lib.Graph().as_default():
- v0 = variables.Variable([10.0], name="v0")
+ v0 = variables.VariableV1([10.0], name="v0")
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saving one variable under two names raises an error.
@@ -468,8 +468,8 @@ class SaverTest(test.TestCase):
with self.session(graph=ops_lib.Graph()) as sess:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
v2_init = v2.insert("k1", 30.0)
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -490,8 +490,8 @@ class SaverTest(test.TestCase):
# Start a second session. In that session the variables
# have not been initialized either.
with self.session(graph=ops_lib.Graph()) as sess:
- v0 = variables.Variable(-1.0, name="v0")
- v1 = variables.Variable(-1.0, name="v1")
+ v0 = variables.VariableV1(-1.0, name="v0")
+ v1 = variables.VariableV1(-1.0, name="v1")
v2 = saver_test_utils.CheckpointedOp(name="v2")
save = saver_module.Saver([v0, v1, v2.saveable])
@@ -515,8 +515,8 @@ class SaverTest(test.TestCase):
# Build another graph with 2 nodes, initialized
# differently, and a Restore node for them.
with self.session(graph=ops_lib.Graph()) as sess:
- v0_2 = variables.Variable(1000.0, name="v0")
- v1_2 = variables.Variable(2000.0, name="v1")
+ v0_2 = variables.VariableV1(1000.0, name="v0")
+ v1_2 = variables.VariableV1(2000.0, name="v1")
v2_2 = saver_test_utils.CheckpointedOp(name="v2")
save2 = saver_module.Saver([v0_2, v1_2, v2_2.saveable])
v2_2.insert("k1000", 3000.0).run()
@@ -574,14 +574,14 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1})
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2})
variables.global_variables_initializer().run()
@@ -591,22 +591,22 @@ class SaverTest(test.TestCase):
save_path = os.path.join(self.get_temp_dir(), "gpu")
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_1 = variables.Variable(123.45)
+ v0_1 = variables.VariableV1(123.45)
save = saver_module.Saver({"v0": v0_1}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
with sess.graph.device(test.gpu_device_name()):
- v0_2 = variables.Variable(543.21)
+ v0_2 = variables.VariableV1(543.21)
save = saver_module.Saver({"v0": v0_2}, sharded=True, allow_empty=True)
variables.global_variables_initializer().run()
def testVariables(self):
save_path = os.path.join(self.get_temp_dir(), "variables")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
- twos = variables.Variable([2.0, 2.0, 2.0])
+ one = variables.VariableV1(1.0)
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
init = variables.global_variables_initializer()
save = saver_module.Saver()
@@ -615,8 +615,8 @@ class SaverTest(test.TestCase):
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
v2 = saver_test_utils.CheckpointedOp(name="v2")
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
@@ -628,14 +628,14 @@ class SaverTest(test.TestCase):
def testVarListShouldBeEmptyInDeferredBuild(self):
with ops_lib.Graph().as_default():
- v = variables.Variable(1.0)
+ v = variables.VariableV1(1.0)
with self.assertRaisesRegexp(ValueError, "defer_build"):
saver_module.Saver([v], defer_build=True)
def testBuildShouldBeCalledBeforeSaveInCaseOfDeferBuild(self):
save_path = os.path.join(self.get_temp_dir(), "error_deferred_build")
with ops_lib.Graph().as_default(), session.Session() as sess:
- variables.Variable(1.0)
+ variables.VariableV1(1.0)
saver = saver_module.Saver(defer_build=True)
with self.assertRaisesRegexp(RuntimeError, "build"):
saver.save(sess, save_path)
@@ -643,18 +643,18 @@ class SaverTest(test.TestCase):
def testDeferredBuild(self):
save_path = os.path.join(self.get_temp_dir(), "deferred_build")
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(1.0)
+ one = variables.VariableV1(1.0)
save = saver_module.Saver(defer_build=True)
# if build is not deferred, saver cannot save the `twos`.
- twos = variables.Variable([2.0, 2.0, 2.0])
+ twos = variables.VariableV1([2.0, 2.0, 2.0])
init = variables.global_variables_initializer()
save.build()
init.run()
save.save(sess, save_path)
with session.Session("", graph=ops_lib.Graph()) as sess:
- one = variables.Variable(0.0)
- twos = variables.Variable([0.0, 0.0, 0.0])
+ one = variables.VariableV1(0.0)
+ twos = variables.VariableV1([0.0, 0.0, 0.0])
# Saver with no arg, defaults to 'all variables'.
save = saver_module.Saver()
save.restore(sess, save_path)
@@ -664,7 +664,7 @@ class SaverTest(test.TestCase):
def testReshape(self):
save_path = os.path.join(self.get_temp_dir(), "variables_reshape")
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+ var = variables.VariableV1([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
init = variables.global_variables_initializer()
save = saver_module.Saver()
init.run()
@@ -672,7 +672,7 @@ class SaverTest(test.TestCase):
# Error when restoring with default reshape=False
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver()
with self.assertRaisesRegexp(
errors_impl.InvalidArgumentError,
@@ -681,7 +681,7 @@ class SaverTest(test.TestCase):
# Restored to new shape with reshape=True
with session.Session("", graph=ops_lib.Graph()) as sess:
- var = variables.Variable([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
+ var = variables.VariableV1([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]])
save = saver_module.Saver(reshape=True)
save.restore(sess, save_path)
self.assertAllClose([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], var.eval())
@@ -731,8 +731,8 @@ class SaverTest(test.TestCase):
for save_path in paths:
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -770,8 +770,8 @@ class SaverTest(test.TestCase):
# Build a graph with 2 parameter nodes, and Save and
# Restore nodes for them.
- v0 = variables.Variable(10.0, name="v0")
- v1 = variables.Variable(20.0, name="v1")
+ v0 = variables.VariableV1(10.0, name="v0")
+ v1 = variables.VariableV1(20.0, name="v1")
save = saver_module.Saver({"v0": v0, "v1": v1}, restore_sequentially=True)
init_all_op = variables.global_variables_initializer()
@@ -859,10 +859,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(10, name="v0")
+ v0 = variables.VariableV1(10, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(20, name="v1")
+ v1 = variables.VariableV1(20, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -890,7 +890,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
save = saver_module.Saver(
{
@@ -914,7 +914,7 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v1 = variables.Variable(222)
+ v1 = variables.VariableV1(222)
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -938,10 +938,10 @@ class SaveRestoreShardedTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
t1 = saver_test_utils.CheckpointedOp(name="t1")
save = saver_module.Saver(
{
@@ -984,7 +984,7 @@ class SaveRestoreShardedTest(test.TestCase):
def testSaverDef(self):
with self.cached_session():
- v0 = variables.Variable(123, name="v0")
+ v0 = variables.VariableV1(123, name="v0")
save = saver_module.Saver({"v0": v0}, sharded=True)
sd = save.as_saver_def()
self.assertTrue(sd.sharded)
@@ -1023,7 +1023,7 @@ class SaveRestoreShardedTest(test.TestCase):
if use_resource:
vs = [resource_variable_ops.ResourceVariable(rnd, name=var_name)]
else:
- vs = [variables.Variable(rnd, name=var_name)]
+ vs = [variables.VariableV1(rnd, name=var_name)]
variables.global_variables_initializer().run()
if call_saver_with_dict:
@@ -1054,7 +1054,7 @@ class SaveRestoreShardedTest(test.TestCase):
]
else:
new_vs = [
- variables.Variable(
+ variables.VariableV1(
array_ops.zeros(
shape=var_full_shape), # != original contents.
name=var_name)
@@ -1210,7 +1210,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("max_to_keep_non_sharded")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v}, max_to_keep=2)
variables.global_variables_initializer().run()
self.assertEqual([], save.last_checkpoints)
@@ -1389,9 +1389,9 @@ class MaxToKeepTest(test.TestCase):
target="",
config=config_pb2.ConfigProto(device_count={"CPU": 2})) as sess:
with sess.graph.device("/cpu:0"):
- v0 = variables.Variable(111, name="v0")
+ v0 = variables.VariableV1(111, name="v0")
with sess.graph.device("/cpu:1"):
- v1 = variables.Variable(222, name="v1")
+ v1 = variables.VariableV1(222, name="v1")
save = saver_module.Saver(
{
"v0": v0,
@@ -1448,7 +1448,7 @@ class MaxToKeepTest(test.TestCase):
save_dir2 = self._get_test_dir("max_to_keep_0")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
variables.global_variables_initializer().run()
# Test max_to_keep being None.
@@ -1475,7 +1475,7 @@ class MaxToKeepTest(test.TestCase):
save_dir = self._get_test_dir("no_meta_graph")
with self.cached_session() as sess:
- v = variables.Variable(10.0, name="v")
+ v = variables.VariableV1(10.0, name="v")
save = saver_module.Saver({"v": v})
variables.global_variables_initializer().run()
@@ -1632,13 +1632,13 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
control_flow_ops.cond(
math_ops.less(v0, 10), lambda: math_ops.add(v0, 1),
lambda: math_ops.subtract(v0, 1))
control_flow_ops.while_loop(lambda i: math_ops.less(i, 10),
lambda i: math_ops.add(i, 1), [v0])
- var = variables.Variable(constant_op.constant(0, dtype=dtypes.int64))
+ var = variables.VariableV1(constant_op.constant(0, dtype=dtypes.int64))
count_up_to = var.count_up_to(3)
input_queue = data_flow_ops.FIFOQueue(
30, dtypes.float32, shared_name="collection_queue")
@@ -1687,7 +1687,7 @@ class MetaGraphTest(test.TestCase):
def testAddCollectionDefFails(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(10.0, name="v0")
+ v0 = variables.VariableV1(10.0, name="v0")
# Creates a saver.
save = saver_module.Saver({"v0": v0})
# Generates MetaGraphDef.
@@ -1711,8 +1711,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
saver1 = saver_module.Saver({"v1": v1}, name="saver1")
@@ -1788,8 +1788,8 @@ class MetaGraphTest(test.TestCase):
saver1_ckpt = os.path.join(test_dir, "saver1.ckpt")
with self.session(graph=ops_lib.Graph()) as sess:
# Creates a graph.
- v0 = variables.Variable([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
- v1 = variables.Variable(11.0, name="v1")
+ v0 = variables.VariableV1([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], name="v0")
+ v1 = variables.VariableV1(11.0, name="v1")
# Creates 2 savers.
saver0 = saver_module.Saver({"v0": v0}, name="saver0")
@@ -1840,7 +1840,7 @@ class MetaGraphTest(test.TestCase):
filename = os.path.join(test_dir, "metafile")
with self.session(graph=ops_lib.Graph()):
# Creates a graph.
- variables.Variable(10.0, name="v0")
+ variables.VariableV1(10.0, name="v0")
# Exports the graph as binary format.
saver_module.export_meta_graph(filename, as_text=False)
with self.session(graph=ops_lib.Graph()):
@@ -1871,8 +1871,8 @@ class MetaGraphTest(test.TestCase):
test_dir = self._get_test_dir("slice_saver")
filename = os.path.join(test_dir, "metafile")
with self.cached_session():
- v1 = variables.Variable([20.0], name="v1")
- v2 = variables.Variable([20.0], name="v2")
+ v1 = variables.VariableV1([20.0], name="v1")
+ v2 = variables.VariableV1([20.0], name="v2")
v2._set_save_slice_info(
variables.Variable.SaveSliceInfo("v1", [1], [0], [1]))
@@ -1899,7 +1899,7 @@ class MetaGraphTest(test.TestCase):
# Hidden 1
images = constant_op.constant(1.2, dtypes.float32, shape=[100, 28])
with ops_lib.name_scope("hidden1"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -1907,7 +1907,7 @@ class MetaGraphTest(test.TestCase):
# the save and restore of control flow context (which doesn't make any
# sense here from a machine learning perspective). The typical biases is
# a simple Variable without the conditions.
- biases = variables.Variable(
+ biases = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -1915,7 +1915,7 @@ class MetaGraphTest(test.TestCase):
hidden1 = nn_ops.relu(math_ops.matmul(images, weights) + biases)
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -1933,15 +1933,16 @@ class MetaGraphTest(test.TestCase):
_, biases = control_flow_ops.while_loop(
loop_cond, loop_body,
- [constant_op.constant(0), variables.Variable(array_ops.zeros([32]))])
+ [constant_op.constant(0),
+ variables.VariableV1(array_ops.zeros([32]))])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
init_all_op = variables.global_variables_initializer()
@@ -2028,7 +2029,7 @@ class MetaGraphTest(test.TestCase):
# Create while loop using `outer_body_fn`.
with ops_lib.Graph().as_default():
- var = variables.Variable(0.0)
+ var = variables.VariableV1(0.0)
var_name = var.name
output = graph_fn(var)
output_name = output.name
@@ -2122,8 +2123,8 @@ class MetaGraphTest(test.TestCase):
def testStrippedOpListDef(self):
with self.cached_session():
# Creates a graph.
- v0 = variables.Variable(0.0)
- var = variables.Variable(10.0)
+ v0 = variables.VariableV1(0.0)
+ var = variables.VariableV1(10.0)
math_ops.add(v0, var)
@function.Defun(dtypes.float32)
@@ -2161,8 +2162,8 @@ class MetaGraphTest(test.TestCase):
# With strip_default_attrs enabled, attributes "T" (float32) and "Tout"
# (complex64) in the "Complex" op must be removed.
with self.cached_session():
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2178,8 +2179,8 @@ class MetaGraphTest(test.TestCase):
# (complex64) in the "Complex" op must *not* be removed, even if they map
# to their defaults.
with self.session(graph=ops_lib.Graph()):
- real_num = variables.Variable(1.0, dtype=dtypes.float32, name="real")
- imag_num = variables.Variable(2.0, dtype=dtypes.float32, name="imag")
+ real_num = variables.VariableV1(1.0, dtype=dtypes.float32, name="real")
+ imag_num = variables.VariableV1(2.0, dtype=dtypes.float32, name="imag")
math_ops.complex(real_num, imag_num, name="complex")
save = saver_module.Saver({"real_num": real_num, "imag_num": imag_num})
@@ -2198,9 +2199,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2243,7 +2244,7 @@ class MetaGraphTest(test.TestCase):
self.assertIsNone(new_saver_1)
# Create a variable in graph_2 under scope "my_scope".
- variables.Variable(array_ops.zeros([10]), name="my_scope/my_var")
+ variables.VariableV1(array_ops.zeros([10]), name="my_scope/my_var")
sess.run(variables.global_variables_initializer())
# Restore the checkpoint into a different scope "subgraph_2".
new_saver_2 = saver_module.import_meta_graph(
@@ -2268,9 +2269,9 @@ class MetaGraphTest(test.TestCase):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
with session.Session() as sess:
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias, name="logits")
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2299,9 +2300,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2332,9 +2333,9 @@ class MetaGraphTest(test.TestCase):
with ops_lib.device("/job:ps/replica:0/task:0/device:GPU:0"):
image = array_ops.placeholder(dtypes.float32, [None, 784], name="image")
label = array_ops.placeholder(dtypes.float32, [None, 10], name="label")
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.random_uniform([784, 10]), name="weights")
- bias = variables.Variable(array_ops.zeros([10]), name="bias")
+ bias = variables.VariableV1(array_ops.zeros([10]), name="bias")
logit = nn_ops.relu(math_ops.matmul(image, weights) + bias)
nn_ops.softmax(logit, name="prediction")
cost = nn_ops.softmax_cross_entropy_with_logits(labels=label,
@@ -2385,9 +2386,9 @@ class CheckpointReaderTest(test.TestCase):
def testDebugString(self):
# Builds a graph.
- v0 = variables.Variable(
+ v0 = variables.VariableV1(
[[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
- v1 = variables.Variable(
+ v1 = variables.VariableV1(
[[[1], [2]], [[3], [4]], [[5], [6]]], dtype=dtypes.float32, name="v1")
init_all_op = variables.global_variables_initializer()
save = saver_module.Saver(
@@ -2444,7 +2445,8 @@ class WriteGraphTest(test.TestCase):
def testWriteGraph(self):
test_dir = self._get_test_dir("write_graph_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph(),
os.path.join(test_dir, "l1"), "graph.pbtxt")
truth = os.path.join(test_dir, "l1", "graph.pbtxt")
@@ -2453,7 +2455,8 @@ class WriteGraphTest(test.TestCase):
def testRecursiveCreate(self):
test_dir = self._get_test_dir("deep_dir")
- variables.Variable([[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
+ variables.VariableV1(
+ [[1, 2, 3], [4, 5, 6]], dtype=dtypes.float32, name="v0")
path = graph_io.write_graph(ops_lib.get_default_graph().as_graph_def(),
os.path.join(test_dir, "l1", "l2", "l3"),
"graph.pbtxt")
@@ -2477,7 +2480,7 @@ class ScopedGraphTest(test.TestCase):
images = constant_op.constant(
1.2, dtypes.float32, shape=[100, 28], name="images")
with ops_lib.name_scope("hidden1"):
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
random_ops.truncated_normal(
[28, 128], stddev=1.0 / math.sqrt(float(28))),
name="weights")
@@ -2485,7 +2488,7 @@ class ScopedGraphTest(test.TestCase):
# coverage the save and restore of control flow context (which doesn't
# make any sense here from a machine learning perspective). The typical
# biases is a simple Variable without the conditions.
- biases1 = variables.Variable(
+ biases1 = variables.VariableV1(
control_flow_ops.cond(
math_ops.less(random.random(), 0.5),
lambda: array_ops.ones([128]), lambda: array_ops.zeros([128])),
@@ -2494,7 +2497,7 @@ class ScopedGraphTest(test.TestCase):
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights2 = variables.Variable(
+ weights2 = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2511,16 +2514,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases2
_, biases2 = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights2) + biases2)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights3 = variables.Variable(
+ weights3 = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases3 = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases3 = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights3) + biases3
ops_lib.add_to_collection("logits", logits)
@@ -2566,7 +2569,7 @@ class ScopedGraphTest(test.TestCase):
with graph.as_default():
# Hidden 2
with ops_lib.name_scope("hidden2"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[128, 32], stddev=1.0 / math.sqrt(float(128))),
name="weights")
@@ -2583,16 +2586,16 @@ class ScopedGraphTest(test.TestCase):
return it + 1, biases
_, biases = control_flow_ops.while_loop(loop_cond, loop_body, [
- constant_op.constant(0), variables.Variable(array_ops.zeros([32]))
+ constant_op.constant(0), variables.VariableV1(array_ops.zeros([32]))
])
hidden2 = nn_ops.relu(math_ops.matmul(hidden1, weights) + biases)
# Linear
with ops_lib.name_scope("softmax_linear"):
- weights = variables.Variable(
+ weights = variables.VariableV1(
random_ops.truncated_normal(
[32, 10], stddev=1.0 / math.sqrt(float(32))),
name="weights")
- biases = variables.Variable(array_ops.zeros([10]), name="biases")
+ biases = variables.VariableV1(array_ops.zeros([10]), name="biases")
logits = math_ops.matmul(hidden2, weights) + biases
ops_lib.add_to_collection("logits", logits)
@@ -2629,9 +2632,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2685,9 +2688,9 @@ class ScopedGraphTest(test.TestCase):
with ops_lib.name_scope("hidden1"):
images = constant_op.constant(
1.0, dtypes.float32, shape=[3, 2], name="images")
- weights1 = variables.Variable(
+ weights1 = variables.VariableV1(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], name="weights")
- biases1 = variables.Variable([0.1] * 3, name="biases")
+ biases1 = variables.VariableV1([0.1] * 3, name="biases")
nn_ops.relu(math_ops.matmul(images, weights1) + biases1, name="relu")
# Run the graph and save scoped checkpoint.
@@ -2720,12 +2723,12 @@ class ScopedGraphTest(test.TestCase):
graph = ops_lib.Graph()
with graph.as_default():
with ops_lib.name_scope("hidden1"):
- variable1 = variables.Variable([1.0], name="variable1")
+ variable1 = variables.VariableV1([1.0], name="variable1")
saver1 = saver_module.Saver(var_list=[variable1])
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver1)
with ops_lib.name_scope("hidden2"):
- variable2 = variables.Variable([2.0], name="variable2")
+ variable2 = variables.VariableV1([2.0], name="variable2")
saver2 = saver_module.Saver(var_list=[variable2], name="hidden2/")
graph.add_to_collection(ops_lib.GraphKeys.SAVERS, saver2)
@@ -2978,7 +2981,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with ops_lib.Graph().as_default() as g:
- a = variables.Variable(1., name="a")
+ a = variables.VariableV1(1., name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
@@ -2986,7 +2989,7 @@ class CheckpointableCompatibilityTests(test.TestCase):
save_path = a_saver.save(sess=sess, save_path=checkpoint_prefix)
with ops_lib.Graph().as_default() as g:
- a = variables.Variable([1.], name="a")
+ a = variables.VariableV1([1.], name="a")
a_saver = saver_module.Saver([a])
with self.session(graph=g) as sess:
with self.assertRaisesRegexp(
diff --git a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
index c7e84e9ba1..5aa7f45c2b 100644
--- a/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
+++ b/tensorflow/python/training/server_lib_same_variables_no_clear_test.py
@@ -37,8 +37,8 @@ class SameVariablesNoClearTest(test.TestCase):
server = server_lib.Server.create_local_server()
with session.Session(server.target) as sess_1:
- v0 = variables.Variable([[2, 1]], name="v0")
- v1 = variables.Variable([[1], [2]], name="v1")
+ v0 = variables.VariableV1([[2, 1]], name="v0")
+ v1 = variables.VariableV1([[1], [2]], name="v1")
v2 = math_ops.matmul(v0, v1)
sess_1.run([v0.initializer, v1.initializer])
self.assertAllEqual([[4]], sess_1.run(v2))
diff --git a/tensorflow/python/training/server_lib_test.py b/tensorflow/python/training/server_lib_test.py
index 063044f0d0..cf995707fc 100644
--- a/tensorflow/python/training/server_lib_test.py
+++ b/tensorflow/python/training/server_lib_test.py
@@ -76,9 +76,9 @@ class GrpcServerTest(test.TestCase):
def testResetFails(self):
# Creates variable with container name.
with ops.container("test0"):
- v0 = variables.Variable(1.0, name="v0")
+ v0 = variables.VariableV1(1.0, name="v0")
# Creates variable with default container.
- v1 = variables.Variable(2.0, name="v1")
+ v1 = variables.VariableV1(2.0, name="v1")
# Verifies resetting the non-existent target returns error.
with self.assertRaises(errors_impl.NotFoundError):
session.Session.reset("nonexistent", ["test0"])
@@ -234,8 +234,8 @@ class GrpcServerTest(test.TestCase):
[0.], dtype=dtypes.float32))
self.assertIsNotNone(input_queue)
- var = variables.Variable(1., dtype=dtypes.float32, trainable=False,
- name="var")
+ var = variables.VariableV1(1., dtype=dtypes.float32, trainable=False,
+ name="var")
sess.run(variables.global_variables_initializer())
queue_runner_impl.start_queue_runners(sess)
@@ -245,7 +245,7 @@ class GrpcServerTest(test.TestCase):
server = self._cached_server
init_value = array_ops.placeholder(dtypes.int32)
- v = variables.Variable(init_value, validate_shape=False, name="v")
+ v = variables.VariableV1(init_value, validate_shape=False, name="v")
sharing_config = config_pb2.ConfigProto(isolate_session_state=False)
sharing_sess_0 = session.Session(server.target, config=sharing_config)
@@ -302,7 +302,7 @@ class GrpcServerTest(test.TestCase):
isolate_config = config_pb2.ConfigProto(isolate_session_state=True)
with ops.Graph().as_default():
- w_vector = variables.Variable([1, 2, 3], name="w")
+ w_vector = variables.VariableV1([1, 2, 3], name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_vector)
@@ -310,20 +310,20 @@ class GrpcServerTest(test.TestCase):
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
with ops.Graph().as_default():
- w_vector = variables.Variable([4, 5, 6], name="w")
+ w_vector = variables.VariableV1([4, 5, 6], name="w")
with session.Session(server.target, config=sharing_config) as sess:
self.assertAllEqual([1, 2, 3], sess.run(w_vector))
sess.run(w_vector.initializer)
self.assertAllEqual([4, 5, 6], sess.run(w_vector))
with ops.Graph().as_default():
- w_scalar = variables.Variable(86, name="w")
+ w_scalar = variables.VariableV1(86, name="w")
with session.Session(server.target, config=sharing_config) as sess:
with self.assertRaises(errors_impl.InvalidArgumentError):
sess.run(w_scalar.initializer)
with ops.Graph().as_default():
- w_scalar = variables.Variable(37, name="w")
+ w_scalar = variables.VariableV1(37, name="w")
with session.Session(server.target, config=isolate_config) as sess:
with self.assertRaises(errors_impl.FailedPreconditionError):
sess.run(w_scalar)
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..cd313c2ce0 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -182,6 +183,12 @@ class SessionManager(object):
"""
self._target = master
sess = session.Session(self._target, graph=self._graph, config=config)
+ # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+ initialize_ops = (
+ distribution_strategy_context.get_distribution_strategy().initialize()
+ )
+ if initialize_ops:
+ sess.run(initialize_ops)
if checkpoint_dir and checkpoint_filename_with_path:
raise ValueError("Can not provide both checkpoint_dir and "
diff --git a/tensorflow/python/training/session_manager_test.py b/tensorflow/python/training/session_manager_test.py
index f1d18f7704..2b5c3b01de 100644
--- a/tensorflow/python/training/session_manager_test.py
+++ b/tensorflow/python/training/session_manager_test.py
@@ -40,7 +40,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -50,7 +50,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -61,7 +61,7 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
sess = sm.prepare_session(
@@ -79,7 +79,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -97,7 +97,7 @@ class SessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -134,7 +134,7 @@ class SessionManagerTest(test.TestCase):
checkpoint_filename_with_path=None):
# Create a new Graph and SessionManager and recover from a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with session_lib.Session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -162,7 +162,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -186,7 +186,7 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables(),
recovery_wait_secs=1)
@@ -217,7 +217,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -230,8 +230,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -275,7 +275,7 @@ class SessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.report_uninitialized_variables())
saver = saver_lib.Saver({"v": v})
@@ -288,8 +288,8 @@ class SessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -321,7 +321,7 @@ class SessionManagerTest(test.TestCase):
# local_init_op exactly once, regardless of whether the session was
# successfully recovered.
with ops.Graph().as_default():
- w = variables.Variable(
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -356,8 +356,8 @@ class SessionManagerTest(test.TestCase):
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(2, name="v")
+ w = variables.VariableV1(
1,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -389,8 +389,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionLocalInit(self):
server = server_lib.Server.create_local_server()
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -420,8 +420,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionWithReadyForLocalInitOpFailsToReadyLocal(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -439,8 +439,8 @@ class SessionManagerTest(test.TestCase):
def testWaitForSessionInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default() as graph:
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -456,13 +456,13 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyForLocalInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -495,25 +495,25 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithPartialInitOp(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w")
- x = variables.Variable(
+ x = variables.VariableV1(
3 * v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="x")
# TODO(b/70206927): Use ResourceVariables once they are handled properly.
- v_res = variables.Variable(1, name="v_res")
- w_res = variables.Variable(
+ v_res = variables.VariableV1(1, name="v_res")
+ w_res = variables.VariableV1(
v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
name="w_res")
- x_res = variables.Variable(
+ x_res = variables.VariableV1(
3 * v_res,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -565,7 +565,7 @@ class SessionManagerTest(test.TestCase):
# cyclic dependencies.
with ops.Graph().as_default():
i = control_flow_ops.while_loop(lambda i: i < 1, lambda i: i + 1, [0])
- v = variables.Variable(array_ops.identity(i), name="v")
+ v = variables.VariableV1(array_ops.identity(i), name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm = session_manager.SessionManager(
@@ -579,8 +579,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariable(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -596,8 +596,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionDidNotInitLocalVariableList(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -613,8 +613,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithReadyNotReadyForLocal(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -634,8 +634,8 @@ class SessionManagerTest(test.TestCase):
def testPrepareSessionWithInsufficientReadyForLocalInitCheck(self):
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
- w = variables.Variable(
+ v = variables.VariableV1(1, name="v")
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -656,7 +656,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceeds(self):
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -666,7 +666,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFeedDict(self):
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -677,7 +677,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testPrepareSessionSucceedsWithInitFn(self):
with ops.Graph().as_default():
- v = variables.Variable([125], name="v")
+ v = variables.VariableV1([125], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
sess = sm.prepare_session(
@@ -695,7 +695,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -713,7 +713,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
# Renames the checkpoint directory.
os.rename(checkpoint_dir, checkpoint_dir2)
gfile.MakeDirs(checkpoint_dir)
- v = variables.Variable([6.0, 7.0, 8.0], name="v")
+ v = variables.VariableV1([6.0, 7.0, 8.0], name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
session_manager.SessionManager(
@@ -755,7 +755,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
gfile.MakeDirs(checkpoint_dir)
with ops.Graph().as_default():
- v = variables.Variable(1, name="v")
+ v = variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized())
saver = saver_lib.Saver({"v": v})
@@ -768,7 +768,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
os.path.join(checkpoint_dir, "recover_session_checkpoint"))
# Create a new Graph and SessionManager and recover.
with ops.Graph().as_default():
- v = variables.Variable(2, name="v")
+ v = variables.VariableV1(2, name="v")
with self.cached_session():
self.assertEqual(False, variables.is_variable_initialized(v).eval())
sm2 = session_manager.SessionManager(
@@ -785,7 +785,7 @@ class ObsoleteSessionManagerTest(test.TestCase):
def testWaitForSessionReturnsNoneAfterTimeout(self):
with ops.Graph().as_default():
- variables.Variable(1, name="v")
+ variables.VariableV1(1, name="v")
sm = session_manager.SessionManager(
ready_op=variables.assert_variables_initialized(),
recovery_wait_secs=1)
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index 0755364bbe..a5e626d320 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -242,10 +242,9 @@ class Supervisor(object):
ready_for_local_init_op: 1-D string `Tensor`. This tensor is evaluated by
supervisors in `prepare_or_wait_for_session()` to check if the model is
ready to run the local_init_op.
- The model is considered ready if it returns an empty array. Defaults to
- the tensor returned from
- `tf.report_uninitialized_variables(tf.global_variables())`. If `None`,
- the model is not checked for readiness before running local_init_op.
+ The model is considered ready if it returns an empty array. Defaults to
+ `None`. If `None`, the model is not checked for readiness before running
+ local_init_op.
is_chief: If True, create a chief supervisor in charge of initializing
and restoring the model. If False, create a supervisor that relies
on a chief supervisor for inits and restore.
diff --git a/tensorflow/python/training/supervisor_test.py b/tensorflow/python/training/supervisor_test.py
index caf6eba3e0..7cd99d8680 100644
--- a/tensorflow/python/training/supervisor_test.py
+++ b/tensorflow/python/training/supervisor_test.py
@@ -423,7 +423,7 @@ class SupervisorTest(test.TestCase):
def testLogdirButExplicitlyNoSummaryWriter(self):
logdir = self._test_dir("explicit_no_summary_writer")
with ops.Graph().as_default():
- variables.Variable([1.0], name="foo")
+ variables.VariableV1([1.0], name="foo")
summary.scalar("c1", constant_op.constant(1))
summary.scalar("c2", constant_op.constant(2))
summary.scalar("c3", constant_op.constant(3))
@@ -491,7 +491,7 @@ class SupervisorTest(test.TestCase):
def testNoLogdirSucceeds(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir="", summary_op=None)
sess = sv.prepare_or_wait_for_session("")
sess.close()
@@ -499,7 +499,7 @@ class SupervisorTest(test.TestCase):
def testUseSessionManager(self):
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0])
+ variables.VariableV1([1.0, 2.0, 3.0])
sm = session_manager_lib.SessionManager()
# Pass in session_manager. The additional init_op is ignored.
sv = supervisor.Supervisor(logdir="", session_manager=sm)
@@ -508,7 +508,7 @@ class SupervisorTest(test.TestCase):
def testInitOp(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertAllClose([1.0, 2.0, 3.0], sess.run(v))
@@ -517,7 +517,7 @@ class SupervisorTest(test.TestCase):
def testInitFn(self):
logdir = self._test_dir("default_init_op")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0])
+ v = variables.VariableV1([1.0, 2.0, 3.0])
def _init_fn(sess):
sess.run(v.initializer)
@@ -531,7 +531,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("feed_dict_init_op")
with ops.Graph().as_default():
p = array_ops.placeholder(dtypes.float32, shape=(3,))
- v = variables.Variable(p, name="v")
+ v = variables.VariableV1(p, name="v")
sv = supervisor.Supervisor(
logdir=logdir,
init_op=variables.global_variables_initializer(),
@@ -550,10 +550,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1, name="default_ready_for_local_init_op_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -590,7 +590,7 @@ class SupervisorTest(test.TestCase):
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
10.0, name="ready_for_local_init_op_restore_v_" + str(uid))
summary.scalar("ready_for_local_init_op_restore_v_" + str(uid), v)
sv = supervisor.Supervisor(logdir=logdir)
@@ -607,10 +607,10 @@ class SupervisorTest(test.TestCase):
g = ops.Graph()
with g.as_default():
with ops.device("/job:local"):
- v = variables.Variable(
+ v = variables.VariableV1(
1.0, name="ready_for_local_init_op_restore_v_" + str(uid))
vadd = v.assign_add(1)
- w = variables.Variable(
+ w = variables.VariableV1(
v,
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES],
@@ -642,13 +642,13 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("default_local_init_op")
with ops.Graph().as_default():
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
# An entity which is initialized through a TABLE_INITIALIZER.
- w = variables.Variable([4, 5, 6], trainable=False, collections=[])
+ w = variables.VariableV1([4, 5, 6], trainable=False, collections=[])
ops.add_to_collection(ops.GraphKeys.TABLE_INITIALIZERS, w.initializer)
# This shouldn't add a variable to the VARIABLES collection responsible
@@ -668,7 +668,7 @@ class SupervisorTest(test.TestCase):
with ops.Graph().as_default():
with ops.device("/job:localhost"):
# A local variable.
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
trainable=False,
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -687,8 +687,8 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails")
with ops.Graph().as_default():
- v = variables.Variable([1.0, 2.0, 3.0], name="v")
- variables.Variable([4.0, 5.0, 6.0], name="w")
+ v = variables.VariableV1([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([4.0, 5.0, 6.0], name="w")
# w will not be initialized.
sv = supervisor.Supervisor(logdir=logdir, init_op=v.initializer)
with self.assertRaisesRegexp(RuntimeError,
@@ -699,11 +699,11 @@ class SupervisorTest(test.TestCase):
server = server_lib.Server.create_local_server()
logdir = self._test_dir("default_init_op_fails_for_local_variable")
with ops.Graph().as_default():
- v = variables.Variable(
+ v = variables.VariableV1(
[1.0, 2.0, 3.0],
name="v",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
- variables.Variable(
+ variables.VariableV1(
[1.0, 2.0, 3.0],
name="w",
collections=[ops.GraphKeys.LOCAL_VARIABLES])
@@ -716,17 +716,17 @@ class SupervisorTest(test.TestCase):
def testSetupFail(self):
logdir = self._test_dir("setup_fail")
with ops.Graph().as_default():
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
with self.assertRaisesRegexp(ValueError, "must have their device set"):
supervisor.Supervisor(logdir=logdir, is_chief=False)
with ops.Graph().as_default(), ops.device("/job:ps"):
- variables.Variable([1.0, 2.0, 3.0], name="v")
+ variables.VariableV1([1.0, 2.0, 3.0], name="v")
supervisor.Supervisor(logdir=logdir, is_chief=False)
def testDefaultGlobalStep(self):
logdir = self._test_dir("default_global_step")
with ops.Graph().as_default():
- variables.Variable(287, name="global_step")
+ variables.VariableV1(287, name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
self.assertEquals(287, sess.run(sv.global_step))
@@ -735,7 +735,7 @@ class SupervisorTest(test.TestCase):
def testRestoreFromMetaGraph(self):
logdir = self._test_dir("restore_from_meta_graph")
with ops.Graph().as_default():
- variables.Variable(1, name="v0")
+ variables.VariableV1(1, name="v0")
sv = supervisor.Supervisor(logdir=logdir)
sess = sv.prepare_or_wait_for_session("")
filename = sv.saver.save(sess, sv.save_path)
@@ -757,7 +757,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_without_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([1.0], name="foo")
+ v = variables.VariableV1([1.0], name="foo")
summary.scalar("v", v[0])
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
@@ -796,7 +796,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([10.10], name="foo")
+ v = variables.VariableV1([10.10], name="foo")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(1.0, v.eval()[0])
@@ -807,7 +807,7 @@ class SupervisorTest(test.TestCase):
logdir = self._test_dir("standard_services_with_global_step")
# Create a checkpoint.
with ops.Graph().as_default():
- v = variables.Variable([123], name="global_step")
+ v = variables.VariableV1([123], name="global_step")
sv = supervisor.Supervisor(logdir=logdir)
meta_graph_def = meta_graph.create_meta_graph_def(
saver_def=sv.saver.saver_def)
@@ -860,7 +860,7 @@ class SupervisorTest(test.TestCase):
self.assertRaises(StopIteration, lambda: next(rr))
# There should be a checkpoint file with the variable "foo"
with ops.Graph().as_default(), self.cached_session() as sess:
- v = variables.Variable([-12], name="global_step")
+ v = variables.VariableV1([-12], name="global_step")
sav = saver_lib.Saver([v])
sav.restore(sess, save_path)
self.assertEqual(123, v.eval()[0])
diff --git a/tensorflow/python/training/sync_replicas_optimizer.py b/tensorflow/python/training/sync_replicas_optimizer.py
index 7afaa92699..6a3756fba9 100644
--- a/tensorflow/python/training/sync_replicas_optimizer.py
+++ b/tensorflow/python/training/sync_replicas_optimizer.py
@@ -78,7 +78,11 @@ class SyncReplicasOptimizer(optimizer.Optimizer):
4. Only after all variables have been updated, increment the global step.
5. Only after step 4, pushes `global_step` in the `token_queue`, once for
each worker replica. The workers can now fetch the global step, use it to
- update its local_step variable and start the next batch.
+ update its local_step variable and start the next batch. Please note that
+ some workers can consume multiple minibatches, while some may not consume
+ even one. This is because each worker fetches minibatches as long as
+ a token exists. If one worker is stuck for some reason and does not
+ consume a token, another worker can use it.
For the replicas:
diff --git a/tensorflow/python/training/sync_replicas_optimizer_test.py b/tensorflow/python/training/sync_replicas_optimizer_test.py
index fff17402e2..1ef8756ef6 100644
--- a/tensorflow/python/training/sync_replicas_optimizer_test.py
+++ b/tensorflow/python/training/sync_replicas_optimizer_test.py
@@ -40,11 +40,12 @@ def get_workers(num_workers, replicas_to_aggregate, workers):
is_chief = (worker_id == 0)
with graph.as_default():
with ops.device("/job:ps/task:0"):
- global_step = variables.Variable(0, name="global_step", trainable=False)
- var_0 = variables.Variable(0.0, name="v0")
+ global_step = variables.VariableV1(
+ 0, name="global_step", trainable=False)
+ var_0 = variables.VariableV1(0.0, name="v0")
with ops.device("/job:ps/task:1"):
- var_1 = variables.Variable(1.0, name="v1")
- var_sparse = variables.Variable([[3.0], [4.0]], name="v_sparse")
+ var_1 = variables.VariableV1(1.0, name="v1")
+ var_sparse = variables.VariableV1([[3.0], [4.0]], name="v_sparse")
with ops.device("/job:worker/task:" + str(worker_id)):
grads_0 = constant_op.constant(0.1 + worker_id * 0.2)
@@ -272,8 +273,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
replicas_to_aggregate=1,
total_num_replicas=1)
hook = opt.make_session_run_hook(True)
- v = variables.Variable([0.])
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.])
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
hook.begin()
@@ -282,8 +283,8 @@ class SyncReplicasOptimizerHookTest(test.TestCase):
opt=adam.AdamOptimizer(0.01),
replicas_to_aggregate=1,
total_num_replicas=1)
- v = variables.Variable([0.], name="fetch_variable_test")
- global_step = variables.Variable(0, name="global_step", trainable=False)
+ v = variables.VariableV1([0.], name="fetch_variable_test")
+ global_step = variables.VariableV1(0, name="global_step", trainable=False)
opt.minimize(v, global_step=global_step)
opt_variables = opt.variables()
beta1_power, beta2_power = opt._opt._get_beta_accumulators()
diff --git a/tensorflow/python/training/training_ops_test.py b/tensorflow/python/training/training_ops_test.py
index d131a11067..f410ceaaff 100644
--- a/tensorflow/python/training/training_ops_test.py
+++ b/tensorflow/python/training/training_ops_test.py
@@ -51,7 +51,7 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypes(self, x, alpha, delta, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
+ var = variables.VariableV1(x)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
apply_sgd = training_ops.apply_gradient_descent(var, alpha, delta)
@@ -70,8 +70,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdagrad(self, x, y, lr, grad, use_gpu=None):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -94,9 +94,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -148,8 +148,8 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForSparseAdagrad(self, x, y, lr, grad, indices):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -178,9 +178,9 @@ class TrainingOpsTest(TensorFlowTestCase):
lr_power=-0.5):
self.setUp()
with self.test_session(use_gpu=False):
- var = variables.Variable(x)
- accum = variables.Variable(y)
- linear = variables.Variable(z)
+ var = variables.VariableV1(x)
+ accum = variables.VariableV1(y)
+ linear = variables.VariableV1(z)
variables.global_variables_initializer().run()
self.assertAllCloseAccordingToType(x, var.eval())
@@ -257,9 +257,9 @@ class TrainingOpsTest(TensorFlowTestCase):
def _testTypesForAdam(self, var, m, v, grad, use_gpu):
self.setUp()
with self.test_session(use_gpu=use_gpu):
- var_t = variables.Variable(var)
- m_t = variables.Variable(m)
- v_t = variables.Variable(v)
+ var_t = variables.VariableV1(var)
+ m_t = variables.VariableV1(m)
+ v_t = variables.VariableV1(v)
t = 1
beta1 = np.array(0.9, dtype=var.dtype)
@@ -270,8 +270,8 @@ class TrainingOpsTest(TensorFlowTestCase):
epsilon = np.array(1e-8, dtype=var.dtype)
beta1_t = constant_op.constant(beta1, self._toType(var.dtype), [])
beta2_t = constant_op.constant(beta2, self._toType(var.dtype), [])
- beta1_power_t = variables.Variable(beta1_power)
- beta2_power_t = variables.Variable(beta2_power)
+ beta1_power_t = variables.VariableV1(beta1_power)
+ beta2_power_t = variables.VariableV1(beta2_power)
lr_t = constant_op.constant(lr, self._toType(var.dtype), [])
epsilon_t = constant_op.constant(epsilon, self._toType(var.dtype), [])
variables.global_variables_initializer().run()
diff --git a/tensorflow/python/training/training_util_test.py b/tensorflow/python/training/training_util_test.py
index 6cc177e0e8..ba64e785ac 100644
--- a/tensorflow/python/training/training_util_test.py
+++ b/tensorflow/python/training/training_util_test.py
@@ -49,7 +49,7 @@ class GlobalStepTest(test.TestCase):
def test_invalid_shape(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
[0],
trainable=False,
dtype=dtypes.int32,
@@ -73,7 +73,7 @@ class GlobalStepTest(test.TestCase):
def test_get_global_step(self):
with ops.Graph().as_default() as g:
self.assertIsNone(training_util.get_global_step())
- variables.Variable(
+ variables.VariableV1(
0,
trainable=False,
dtype=dtypes.int32,
diff --git a/tensorflow/python/util/function_utils.py b/tensorflow/python/util/function_utils.py
index 4e9b07e20a..a56dfbff8e 100644
--- a/tensorflow/python/util/function_utils.py
+++ b/tensorflow/python/util/function_utils.py
@@ -59,6 +59,29 @@ def fn_args(fn):
return tuple(args)
+def has_kwargs(fn):
+ """Returns whether the passed callable has **kwargs in its signature.
+
+ Args:
+ fn: Function, or function-like object (e.g., result of `functools.partial`).
+
+ Returns:
+ `bool`: if `fn` has **kwargs in its signature.
+
+ Raises:
+ `TypeError`: If fn is not a Function, or function-like object.
+ """
+ if isinstance(fn, functools.partial):
+ fn = fn.func
+ elif _is_callable_object(fn):
+ fn = fn.__call__
+ elif not callable(fn):
+ raise TypeError(
+ 'fn should be a function-like object, but is of type {}.'.format(
+ type(fn)))
+ return tf_inspect.getfullargspec(fn).varkw is not None
+
+
def get_func_name(func):
"""Returns name of passed callable."""
_, func = tf_decorator.unwrap(func)
diff --git a/tensorflow/python/util/function_utils_test.py b/tensorflow/python/util/function_utils_test.py
index 1588328c26..e5b0843e4b 100644
--- a/tensorflow/python/util/function_utils_test.py
+++ b/tensorflow/python/util/function_utils_test.py
@@ -135,6 +135,101 @@ class FnArgsTest(test.TestCase):
self.assertEqual(3, double_wrapped_fn(a=3))
+class HasKwargsTest(test.TestCase):
+
+ def test_simple_function(self):
+
+ fn_has_kwargs = lambda **x: x
+ self.assertTrue(function_utils.has_kwargs(fn_has_kwargs))
+
+ fn_has_no_kwargs = lambda x: x
+ self.assertFalse(function_utils.has_kwargs(fn_has_no_kwargs))
+
+ def test_callable(self):
+
+ class FooHasKwargs(object):
+
+ def __call__(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs()))
+
+ class FooHasNoKwargs(object):
+
+ def __call__(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs()))
+
+ def test_bounded_method(self):
+
+ class FooHasKwargs(object):
+
+ def fn(self, **x):
+ del x
+ self.assertTrue(function_utils.has_kwargs(FooHasKwargs().fn))
+
+ class FooHasNoKwargs(object):
+
+ def fn(self, x):
+ del x
+ self.assertFalse(function_utils.has_kwargs(FooHasNoKwargs().fn))
+
+ def test_partial_function(self):
+ expected_test_arg = 123
+
+ def fn_has_kwargs(test_arg, **x):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg=123)
+ self.assertTrue(function_utils.has_kwargs(wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg):
+ if test_arg != expected_test_arg:
+ return ValueError('partial fn does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg=123)
+ self.assertFalse(function_utils.has_kwargs(wrapped_fn))
+ some_arg = 1
+ self.assertEqual(wrapped_fn(some_arg), some_arg)
+
+ def test_double_partial(self):
+ expected_test_arg1 = 123
+ expected_test_arg2 = 456
+
+ def fn_has_kwargs(test_arg1, test_arg2, **x):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertTrue(function_utils.has_kwargs(double_wrapped_fn))
+ some_kwargs = dict(x=1, y=2, z=3)
+ self.assertEqual(double_wrapped_fn(**some_kwargs), some_kwargs)
+
+ def fn_has_no_kwargs(x, test_arg1, test_arg2):
+ if test_arg1 != expected_test_arg1 or test_arg2 != expected_test_arg2:
+ return ValueError('partial does not work correctly')
+ return x
+
+ wrapped_fn = functools.partial(fn_has_no_kwargs, test_arg2=456)
+ double_wrapped_fn = functools.partial(wrapped_fn, test_arg1=123)
+
+ self.assertFalse(function_utils.has_kwargs(double_wrapped_fn))
+ some_arg = 1
+ self.assertEqual(double_wrapped_fn(some_arg), some_arg)
+
+ def test_raises_type_error(self):
+ with self.assertRaisesRegexp(
+ TypeError, 'fn should be a function-like object'):
+ function_utils.has_kwargs('not a function')
+
+
class GetFuncNameTest(test.TestCase):
def testWithSimpleFunction(self):
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 653ca525dc..d67dbde304 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -19,6 +19,9 @@ This module can perform operations on nested structures. A nested structure is a
Python sequence, tuple (including `namedtuple`), or dict that can contain
further sequences, tuples, and dicts.
+attr.s decorated classes (http://www.attrs.org) are also supported, in the
+same way as `namedtuple`.
+
The utilities here assume (and do not check) that the nested structures form a
'tree', i.e., no references in the structure of the input of these functions
should be recursive.
@@ -38,6 +41,12 @@ import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
+def _get_attrs_values(obj):
+ """Returns the list of values from an attrs instance."""
+ attrs = getattr(obj.__class__, "__attrs_attrs__")
+ return [getattr(obj, a.name) for a in attrs]
+
+
def _sorted(dict_):
"""Returns a sorted list of the dict keys, with error if keys not sortable."""
try:
@@ -64,6 +73,7 @@ def _is_namedtuple(instance, strict=False):
# See the swig file (util.i) for documentation.
_is_mapping = _pywrap_tensorflow.IsMapping
+_is_attrs = _pywrap_tensorflow.IsAttrs
def _sequence_like(instance, args):
@@ -85,7 +95,7 @@ def _sequence_like(instance, args):
# corresponding `OrderedDict` to pack it back).
result = dict(zip(_sorted(instance), args))
return type(instance)((key, result[key]) for key in _six.iterkeys(instance))
- elif _is_namedtuple(instance):
+ elif _is_namedtuple(instance) or _is_attrs(instance):
return type(instance)(*args)
else:
# Not a namedtuple
@@ -93,6 +103,7 @@ def _sequence_like(instance, args):
def _yield_value(iterable):
+ """Yields the next value from the given iterable."""
if _is_mapping(iterable):
# Iterate through dictionaries in a deterministic order by sorting the
# keys. Notice this means that we ignore the original order of `OrderedDict`
@@ -101,6 +112,9 @@ def _yield_value(iterable):
# corresponding `OrderedDict` to pack it back).
for key in _sorted(iterable):
yield iterable[key]
+ elif _is_attrs(iterable):
+ for value in _get_attrs_values(iterable):
+ yield value
else:
for value in iterable:
yield value
@@ -805,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"):
return list(zip(flat_string_paths, flatten(structure)))
-_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
-_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
+_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
+_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence)
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index bfb4c6f910..e03a8daaa1 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -33,6 +33,11 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
from tensorflow.python.util import nest
+try:
+ import attr # pylint:disable=g-import-not-at-top
+except ImportError:
+ attr = None
+
class _CustomMapping(collections.Mapping):
@@ -53,6 +58,35 @@ class NestTest(parameterized.TestCase, test.TestCase):
PointXY = collections.namedtuple("Point", ["x", "y"]) # pylint: disable=invalid-name
+ if attr:
+ class BadAttr(object):
+ """Class that has a non-iterable __attrs_attrs__."""
+ __attrs_attrs__ = None
+
+ @attr.s
+ class SampleAttr(object):
+ field1 = attr.ib()
+ field2 = attr.ib()
+
+ @test_util.assert_no_new_pyobjects_executing_eagerly
+ def testAttrsFlattenAndPack(self):
+ if attr is None:
+ self.skipTest("attr module is unavailable.")
+
+ field_values = [1, 2]
+ sample_attr = NestTest.SampleAttr(*field_values)
+ self.assertFalse(nest._is_attrs(field_values))
+ self.assertTrue(nest._is_attrs(sample_attr))
+ flat = nest.flatten(sample_attr)
+ self.assertEqual(field_values, flat)
+ restructured_from_flat = nest.pack_sequence_as(sample_attr, flat)
+ self.assertIsInstance(restructured_from_flat, NestTest.SampleAttr)
+ self.assertEqual(restructured_from_flat, sample_attr)
+
+ # Check that flatten fails if attributes are not iterable
+ with self.assertRaisesRegexp(TypeError, "object is not iterable"):
+ flat = nest.flatten(NestTest.BadAttr())
+
@test_util.assert_no_new_pyobjects_executing_eagerly
def testFlattenAndPack(self):
structure = ((3, 4), 5, (6, 7, (9, 10), 8))
diff --git a/tensorflow/python/util/protobuf/compare.py b/tensorflow/python/util/protobuf/compare.py
index a0e6bf65cf..3a3af4bffa 100644
--- a/tensorflow/python/util/protobuf/compare.py
+++ b/tensorflow/python/util/protobuf/compare.py
@@ -63,6 +63,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import difflib
import six
@@ -101,10 +102,19 @@ def assertProtoEqual(self, a, b, check_initialized=True, # pylint: disable=inva
if normalize_numbers:
NormalizeNumberFields(pb)
- self.assertMultiLineEqual(
- text_format.MessageToString(a, descriptor_pool=pool),
- text_format.MessageToString(b, descriptor_pool=pool),
- msg=msg)
+ a_str = text_format.MessageToString(a, descriptor_pool=pool)
+ b_str = text_format.MessageToString(b, descriptor_pool=pool)
+
+ # Some Python versions would perform regular diff instead of multi-line
+ # diff if string is longer than 2**16. We substitute this behavior
+ # with a call to unified_diff instead to have easier-to-read diffs.
+ # For context, see: https://bugs.python.org/issue11763.
+ if len(a_str) < 2**16 and len(b_str) < 2**16:
+ self.assertMultiLineEqual(a_str, b_str, msg=msg)
+ else:
+ diff = '\n' + ''.join(difflib.unified_diff(a_str.splitlines(True),
+ b_str.splitlines(True)))
+ self.fail('%s : %s' % (msg, diff))
def NormalizeNumberFields(pb):
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 967c872c2a..444e44eaf1 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -36,6 +36,55 @@ else:
'annotations'
])
+if hasattr(_inspect, 'getfullargspec'):
+ _getfullargspec = _inspect.getfullargspec # pylint: disable=invalid-name
+
+ def _getargspec(target):
+ """A python3 version of getargspec.
+
+ Calls `getfullargspec` and assigns args, varargs,
+ varkw, and defaults to a python 2/3 compatible `ArgSpec`.
+
+ The parameter name 'varkw' is changed to 'keywords' to fit the
+ `ArgSpec` struct.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ An ArgSpec with args, varargs, keywords, and defaults parameters
+ from FullArgSpec.
+ """
+ fullargspecs = getfullargspec(target)
+ argspecs = ArgSpec(
+ args=fullargspecs.args,
+ varargs=fullargspecs.varargs,
+ keywords=fullargspecs.varkw,
+ defaults=fullargspecs.defaults)
+ return argspecs
+else:
+ _getargspec = _inspect.getargspec
+
+ def _getfullargspec(target):
+ """A python2 version of getfullargspec.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
+ """
+ argspecs = getargspec(target)
+ fullargspecs = FullArgSpec(
+ args=argspecs.args,
+ varargs=argspecs.varargs,
+ varkw=argspecs.keywords,
+ defaults=argspecs.defaults,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ return fullargspecs
+
def currentframe():
"""TFDecorator-aware replacement for inspect.currentframe."""
@@ -43,16 +92,18 @@ def currentframe():
def getargspec(obj):
- """TFDecorator-aware replacement for inspect.getargspec.
+ """TFDecorator-aware replacement for `inspect.getargspec`.
+
+ Note: `getfullargspec` is recommended as the python 2/3 compatible
+ replacement for this function.
Args:
- obj: A function, partial function, or callable object, possibly
- decorated.
+ obj: A function, partial function, or callable object, possibly decorated.
Returns:
The `ArgSpec` that describes the signature of the outermost decorator that
- changes the callable's signature. If the callable is not decorated,
- `inspect.getargspec()` will be called directly on the object.
+ changes the callable's signature, or the `ArgSpec` that describes
+ the object if not decorated.
Raises:
ValueError: When callable's signature can not be expressed with
@@ -72,24 +123,24 @@ def getargspec(obj):
try:
# Python3 will handle most callables here (not partial).
- return _inspect.getargspec(target)
+ return _getargspec(target)
except TypeError:
pass
if isinstance(target, type):
try:
- return _inspect.getargspec(target.__init__)
+ return _getargspec(target.__init__)
except TypeError:
pass
try:
- return _inspect.getargspec(target.__new__)
+ return _getargspec(target.__new__)
except TypeError:
pass
# The `type(target)` ensures that if a class is received we don't return
# the signature of it's __call__ method.
- return _inspect.getargspec(type(target).__call__)
+ return _getargspec(type(target).__call__)
def _get_argspec_for_partial(obj):
@@ -172,30 +223,6 @@ def _get_argspec_for_partial(obj):
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
-if hasattr(_inspect, 'getfullargspec'):
- _getfullargspec = _inspect.getfullargspec
-else:
-
- def _getfullargspec(target):
- """A python2 version of getfullargspec.
-
- Args:
- target: the target object to inspect.
- Returns:
- A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
- """
- argspecs = getargspec(target)
- fullargspecs = FullArgSpec(
- args=argspecs.args,
- varargs=argspecs.varargs,
- varkw=argspecs.keywords,
- defaults=argspecs.defaults,
- kwonlyargs=[],
- kwonlydefaults=None,
- annotations={})
- return fullargspecs
-
-
def getfullargspec(obj):
"""TFDecorator-aware replacement for `inspect.getfullargspec`.
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index d3b7e4b969..02d075cdff 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -122,18 +122,6 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
- def testGetFullArgsSpecForPartial(self):
-
- def func(a, b):
- del a, b
-
- partial_function = functools.partial(func, 1)
- argspec = tf_inspect.FullArgSpec(
- args=['b'], varargs=None, varkw=None, defaults=None,
- kwonlyargs=[], kwonlydefaults=None, annotations={})
-
- self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
-
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""
@@ -303,6 +291,193 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(NewClass))
+ def testGetFullArgSpecOnDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.getfullargspec(test_decorated_function_with_defaults)
+ self.assertEqual(['a', 'b', 'c'], argspec.args)
+ self.assertEqual((2, 'Hello'), argspec.defaults)
+
+ def testGetFullArgSpecOnDecoratorThatChangesFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+ argspec)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(decorator))
+
+ def testGetFullArgSpecIgnoresDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgSpecReturnsOutermostDecoratorThatChangesFullArgSpec(self):
+ outer_argspec = tf_inspect.FullArgSpec(
+ args=['a'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ inner_argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', inner_argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '',
+ outer_argspec)
+ self.assertEqual(outer_argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgsSpecForPartial(self):
+
+ def func(a, b):
+ del a, b
+
+ partial_function = functools.partial(func, 1)
+ argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
+
+ def testGetFullArgSpecOnPartialNoArgumentsLeft(self):
+ """Tests getfullargspec on partial function that prunes all arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, 7, 10)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarargs(self):
+ """Tests getfullargspec on partial function with variable arguments."""
+
+ def func(m, *arg):
+ return m + len(arg)
+
+ partial_func = functools.partial(func, 7, 8)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs='arg',
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarkwargs(self):
+ """Tests getfullargspec.
+
+ Tests on partial function with variable keyword arguments.
+ """
+
+ def func(m, n, **kwarg):
+ return m * n + len(kwarg)
+
+ partial_func = functools.partial(func, 7)
+ argspec = tf_inspect.FullArgSpec(
+ args=['n'],
+ varargs=None,
+ varkw='kwarg',
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnCallableObject(self):
+
+ class Callable(object):
+
+ def __call__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ test_obj = Callable()
+ self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj))
+
+ def testGetFullArgSpecOnInitClass(self):
+
+ class InitClass(object):
+
+ def __init__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(InitClass))
+
+ def testGetFullArgSpecOnNewClass(self):
+
+ class NewClass(object):
+
+ def __new__(cls, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['cls', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass))
+
def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 562bbdcfeb..7b3e618e84 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -15,9 +15,11 @@ limitations under the License.
#include "tensorflow/python/util/util.h"
#include <functional>
+#include <memory>
#include <unordered_map>
#include <vector>
+#include "absl/memory/memory.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
@@ -27,14 +29,51 @@ limitations under the License.
namespace tensorflow {
namespace swig {
-namespace {
+std::unordered_map<string, PyObject*>* PythonTypesMap() {
+ static auto* m = new std::unordered_map<string, PyObject*>();
+ return m;
+}
+
+PyObject* GetRegisteredType(const string& key) {
+ auto* m = PythonTypesMap();
+ auto it = m->find(key);
+ if (it == m->end()) return nullptr;
+ return it->second;
+}
-// Type object for collections.Sequence. This is set by RegisterSequenceClass.
-PyObject* CollectionsSequenceType = nullptr;
-// Type object for collections.Mapping, set by RegisterMappingClass.
-PyObject* CollectionsMappingType = nullptr;
-PyTypeObject* SparseTensorValueType = nullptr;
+PyObject* RegisterType(PyObject* type_name, PyObject* type) {
+ if (!PyType_Check(type)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat("Expecting a type, got ",
+ Py_TYPE(type)->tp_name)
+ .c_str());
+ return nullptr;
+ }
+
+ string key;
+ if (PyBytes_Check(type_name)) {
+ key = PyBytes_AsString(type_name);
+ }
+#if PY_MAJOR_VERSION >= 3
+ if (PyUnicode_Check(type_name)) {
+ key = PyUnicode_AsUTF8(type_name);
+ }
+#endif
+ if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
+ PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
+ "Type already registered for ", key)
+ .c_str());
+ return nullptr;
+ }
+
+ Py_INCREF(type);
+ PythonTypesMap()->emplace(key, type);
+
+ Py_RETURN_NONE;
+}
+
+namespace {
const int kMaxItemsInCache = 1024;
bool WarnedThatSetIsNotSequence = false;
@@ -175,240 +214,320 @@ class CachedTypeCheck {
// Returns -1 if an error occurred.
int IsMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- return PyObject_IsInstance(to_check, CollectionsMappingType);
+ PyObject* collections_mapping_type = GetRegisteredType("Mapping");
+ if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please register the type with the identifier "
+ "\"Mapping\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, collections_mapping_type);
});
if (PyDict_Check(o)) return true;
- if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Mapping type has not been set. "
- "Please call RegisterMappingClass before using this module")
- .c_str());
- return -1;
- }
return check_cache->CachedLookup(o);
}
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 1 if `o` is an instance of attrs-decorated class.
// Returns 0 otherwise.
-// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
+int IsAttrsHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ }
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
+ // PyObject_GetAttrString returns null on error
+ PyErr_Clear();
+ return 0;
+ });
+ return check_cache->CachedLookup(o);
+}
- return static_cast<int>(is_instance != 0 && !IsString(to_check));
+// Returns 1 if `o` is an object of type IndexedSlices.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsIndexedSlicesHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
+ if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "IndexedSlices type has not been set. "
+ "Please register the type with the identifier "
+ "\"IndexedSlices\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, indexed_slices_type);
+ });
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is a Tensor.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsTensorHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* tensor_type = GetRegisteredType("Tensor");
+ if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "Tensor type has not been set. "
+ "Please register the type with the identifier "
+ "\"Tensor\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, tensor_type);
});
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
+ if (IsAttrsHelper(o)) return true;
if (PySet_Check(o) && !WarnedThatSetIsNotSequence) {
LOG(WARNING) << "Sets are not currently considered sequences, "
"but this may change in the future, "
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
return check_cache->CachedLookup(o);
}
-// Implements the same idea as tensorflow.util.nest._yield_value
-// During construction we check if the iterable is a dictionary.
-// If so, we construct a sequence from its sorted keys that will be used
-// for iteration.
-// If not, we construct a sequence directly from the iterable.
-// At each step, we get the next element from the sequence and use it
-// either as a key or return it directly.
-//
-// 'iterable' must not be modified while ValIterator is used.
-class ValIterator {
+// ValueIterator interface
+class ValueIterator {
public:
- explicit ValIterator(PyObject* iterable)
- : dict_(nullptr),
- mapping_(nullptr),
- last_mapping_element_(nullptr),
- seq_(nullptr),
- index_(0) {
- if (PyDict_Check(iterable)) {
- dict_ = iterable;
- // PyDict_Keys returns a list, which can be used with
- // PySequence_Fast_GET_ITEM.
- seq_ = PyDict_Keys(iterable);
- // Iterate through dictionaries in a deterministic order by sorting the
- // keys. Notice this means that we ignore the original order of
- // `OrderedDict` instances. This is intentional, to avoid potential
- // bugs caused by mixing ordered and plain dicts (e.g., flattening
- // a dict but using a corresponding `OrderedDict` to pack it back).
- PyList_Sort(seq_);
- } else if (IsMappingHelper(iterable)) {
- mapping_ = iterable;
- seq_ = MappingKeys(iterable);
- PyList_Sort(seq_);
+ virtual ~ValueIterator() {}
+ virtual Safe_PyObjectPtr next() = 0;
+
+ bool valid() const { return is_valid_; }
+
+ protected:
+ void invalidate() { is_valid_ = false; }
+
+ private:
+ bool is_valid_ = true;
+};
+
+using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
+
+// Iterate through dictionaries in a deterministic order by sorting the
+// keys. Notice this means that we ignore the original order of
+// `OrderedDict` instances. This is intentional, to avoid potential
+// bugs caused by mixing ordered and plain dicts (e.g., flattening
+// a dict but using a corresponding `OrderedDict` to pack it back).
+class DictValueIterator : public ValueIterator {
+ public:
+ explicit DictValueIterator(PyObject* dict)
+ : dict_(dict), keys_(PyDict_Keys(dict)) {
+ if (PyList_Sort(keys_.get()) == -1) {
+ invalidate();
} else {
- seq_ = PySequence_Fast(iterable, "");
+ iter_.reset(PyObject_GetIter(keys_.get()));
}
- size_ = PySequence_Fast_GET_SIZE(seq_);
}
- ~ValIterator() { Py_DECREF(seq_); }
-
- // Return a borrowed reference to the next element from iterable.
- // Return nullptr when iteration is over.
- PyObject* next() {
- if (TF_PREDICT_FALSE(seq_ == nullptr)) {
- return nullptr;
- }
- PyObject* element = nullptr;
- if (index_ < size_) {
- // Both PySequence_Fast_GET_ITEM and PyDict_GetItem return borrowed
- // references. For general mappings, ValIterator keeps a reference to the
- // last retrieved element (and decrefs it before producing the next
- // element) to abstract away the borrowed/new difference.
- element = PySequence_Fast_GET_ITEM(seq_, index_);
- ++index_;
- if (dict_ != nullptr) {
- element = PyDict_GetItem(dict_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Dictionary was modified during iteration over it");
- return nullptr;
- }
- } else if (mapping_ != nullptr) {
- element = PyObject_GetItem(mapping_, element);
- if (element == nullptr) {
- PyErr_SetString(PyExc_RuntimeError,
- "Mapping was modified during iteration over it");
- return nullptr;
- }
- last_mapping_element_.reset(element);
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // PyDict_GetItem returns a borrowed reference.
+ PyObject* elem = PyDict_GetItem(dict_, key.get());
+ if (elem) {
+ Py_INCREF(elem);
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Dictionary was modified during iteration over it");
}
}
- return element;
+ return result;
}
private:
- // Special casing for things that pass PyDict_Check (faster, no Python calls)
PyObject* dict_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
+
+// Iterate over mapping objects by sorting the keys first
+class MappingValueIterator : public ValueIterator {
+ public:
+ explicit MappingValueIterator(PyObject* mapping)
+ : mapping_(mapping), keys_(MappingKeys(mapping)) {
+ if (!keys_ || PyList_Sort(keys_.get()) == -1) {
+ invalidate();
+ } else {
+ iter_.reset(PyObject_GetIter(keys_.get()));
+ }
+ }
- // General mappings which have custom Python logic
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
+ if (key) {
+ // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
+ PyObject* elem = PyObject_GetItem(mapping_, key.get());
+ if (elem) {
+ result.reset(elem);
+ } else {
+ PyErr_SetString(PyExc_RuntimeError,
+ "Mapping was modified during iteration over it");
+ }
+ }
+ return result;
+ }
+
+ private:
PyObject* mapping_;
- Safe_PyObjectPtr last_mapping_element_;
+ Safe_PyObjectPtr keys_;
+ Safe_PyObjectPtr iter_;
+};
- PyObject* seq_;
- Py_ssize_t size_;
+// Iterate over a sequence, by index.
+class SequenceValueIterator : public ValueIterator {
+ public:
+ explicit SequenceValueIterator(PyObject* iterable)
+ : seq_(PySequence_Fast(iterable, "")),
+ size_(PySequence_Fast_GET_SIZE(seq_.get())),
+ index_(0) {}
+
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ if (index_ < size_) {
+ // PySequence_Fast_GET_ITEM returns a borrowed reference.
+ PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
+ ++index_;
+ Py_INCREF(elem);
+ result.reset(elem);
+ }
+
+ return result;
+ }
+
+ private:
+ Safe_PyObjectPtr seq_;
+ const Py_ssize_t size_;
Py_ssize_t index_;
};
-bool IsSparseTensorValueType(PyObject* o) {
- if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
- return false;
+// Just return itself as a single item.
+class SparseTensorValueIterator : public ValueIterator {
+ public:
+ explicit SparseTensorValueIterator(PyObject* tensor) : tensor_(tensor) {
+ Py_INCREF(tensor);
}
- return PyObject_TypeCheck(o, SparseTensorValueType) == 1;
-}
+ Safe_PyObjectPtr next() override { return std::move(tensor_); }
-int IsSequenceForDataHelper(PyObject* o) {
- return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
- !IsSparseTensorValueType(o);
-}
+ private:
+ Safe_PyObjectPtr tensor_;
+};
-bool GetNextValuesForDict(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(PyDict_Keys(nested));
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- // We know that key and item will not be deleted because nested owns
- // a reference to them and callers of flatten must not modify nested
- // while the method is running.
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- PyObject* item = PyDict_GetItem(nested, key);
- Py_INCREF(item);
- next_values->emplace_back(item);
+class AttrsValueIterator : public ValueIterator {
+ public:
+ explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
+ Py_INCREF(nested);
+ cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
+ if (cls_) {
+ attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
+ if (attrs_) {
+ iter_.reset(PyObject_GetIter(attrs_.get()));
+ }
+ }
+ if (!iter_ || PyErr_Occurred()) invalidate();
}
- return true;
-}
-bool GetNextValuesForMapping(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- Safe_PyObjectPtr keys(MappingKeys(nested));
- if (keys.get() == nullptr) {
- return false;
- }
- if (PyList_Sort(keys.get()) == -1) return false;
- Py_ssize_t size = PyList_Size(keys.get());
- for (Py_ssize_t i = 0; i < size; ++i) {
- PyObject* key = PyList_GET_ITEM(keys.get(), i);
- // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
- PyObject* item = PyObject_GetItem(nested, key);
- next_values->emplace_back(item);
+ Safe_PyObjectPtr next() override {
+ Safe_PyObjectPtr result;
+ Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
+ if (item) {
+ Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
+ result.reset(PyObject_GetAttr(nested_.get(), name.get()));
+ }
+
+ return result;
}
- return true;
-}
-bool GetNextValuesForIterable(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
- PyObject* item;
- PyObject* iterator = PyObject_GetIter(nested);
- if (iterator == nullptr || PyErr_Occurred()) {
+ private:
+ Safe_PyObjectPtr nested_;
+ Safe_PyObjectPtr cls_;
+ Safe_PyObjectPtr attrs_;
+ Safe_PyObjectPtr iter_;
+};
+
+bool IsSparseTensorValueType(PyObject* o) {
+ PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
+ if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}
- while ((item = PyIter_Next(iterator)) != nullptr) {
- next_values->emplace_back(item);
- }
- Py_DECREF(iterator);
- return true;
+
+ return PyObject_TypeCheck(
+ o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
+}
+
+int IsSequenceForDataHelper(PyObject* o) {
+ return IsSequenceHelper(o) == 1 && !PyList_Check(o) &&
+ !IsSparseTensorValueType(o);
}
-// GetNextValues returns the values that the FlattenHelper function will recurse
-// over next.
-bool GetNextValues(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+ValueIteratorPtr GetValueIterator(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
-// Similar to above, just specialized for the functions in the data pacakage.
-bool GetNextValuesForData(PyObject* nested,
- std::vector<Safe_PyObjectPtr>* next_values) {
+// Similar to above, just specialized for the functions in the data package.
+ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
if (PyDict_Check(nested)) {
- // if nested is dictionary, sort it by key and recurse on each value
- return GetNextValuesForDict(nested, next_values);
+ return absl::make_unique<DictValueIterator>(nested);
} else if (IsMappingHelper(nested)) {
- // same treatment as dictionaries, but for custom mapping types
- return GetNextValuesForMapping(nested, next_values);
+ return absl::make_unique<MappingValueIterator>(nested);
+ } else if (IsAttrsHelper(nested)) {
+ return absl::make_unique<AttrsValueIterator>(nested);
} else if (IsSparseTensorValueType(nested)) {
- // if nested is a SparseTensorValue, just return itself as a single item
- Py_INCREF(nested);
- next_values->emplace_back(nested);
- return true;
+ return absl::make_unique<SparseTensorValueIterator>(nested);
+ } else {
+ return absl::make_unique<SequenceValueIterator>(nested);
}
- // iterate and recurse
- return GetNextValuesForIterable(nested, next_values);
}
bool FlattenHelper(
PyObject* nested, PyObject* list,
const std::function<int(PyObject*)>& is_sequence_helper,
- const std::function<bool(PyObject*, std::vector<Safe_PyObjectPtr>*)>&
- next_values_getter) {
+ const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
// if nested is not a sequence, append itself and exit
int is_seq = is_sequence_helper(nested);
if (is_seq == -1) return false;
@@ -416,16 +535,15 @@ bool FlattenHelper(
return PyList_Append(list, nested) != -1;
}
- std::vector<Safe_PyObjectPtr> next_values;
- // Get the next values to recurse over.
- if (!next_values_getter(nested, &next_values)) return false;
+ ValueIteratorPtr iter = value_iterator_getter(nested);
+ if (!iter->valid()) return false;
- for (const auto& item : next_values) {
+ for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
if (Py_EnterRecursiveCall(" in flatten")) {
return false;
}
- const bool success =
- FlattenHelper(item.get(), list, is_sequence_helper, next_values_getter);
+ const bool success = FlattenHelper(item.get(), list, is_sequence_helper,
+ value_iterator_getter);
Py_LeaveRecursiveCall();
if (!success) {
return false;
@@ -579,22 +697,25 @@ bool AssertSameStructureHelper(
}
}
- ValIterator iter1(o1);
- ValIterator iter2(o2);
+ ValueIteratorPtr iter1 = GetValueIterator(o1);
+ ValueIteratorPtr iter2 = GetValueIterator(o2);
+
+ if (!iter1->valid() || !iter2->valid()) return false;
while (true) {
- PyObject* v1 = iter1.next();
- PyObject* v2 = iter2.next();
- if (v1 != nullptr && v2 != nullptr) {
+ Safe_PyObjectPtr v1 = iter1->next();
+ Safe_PyObjectPtr v2 = iter2->next();
+ if (v1 && v2) {
if (Py_EnterRecursiveCall(" in assert_same_structure")) {
return false;
}
- bool no_internal_errors = AssertSameStructureHelper(
- v1, v2, check_types, error_msg, is_type_error, is_sequence_helper);
+ bool no_internal_errors =
+ AssertSameStructureHelper(v1.get(), v2.get(), check_types, error_msg,
+ is_type_error, is_sequence_helper);
Py_LeaveRecursiveCall();
if (!no_internal_errors) return false;
if (!error_msg->empty()) return true;
- } else if (v1 == nullptr && v2 == nullptr) {
+ } else if (!v1 && !v2) {
// Done with all recursive calls. Structure matched.
return true;
} else {
@@ -610,52 +731,15 @@ bool AssertSameStructureHelper(
} // namespace
-void RegisterSequenceClass(PyObject* sequence_class) {
- if (!PyType_Check(sequence_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Sequence`. Got ",
- Py_TYPE(sequence_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsSequenceType = sequence_class;
-}
-
-void RegisterMappingClass(PyObject* mapping_class) {
- if (!PyType_Check(mapping_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Mapping`. Got ",
- Py_TYPE(mapping_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsMappingType = mapping_class;
-}
-
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
- if (!PyType_Check(sparse_tensor_value_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `SparseTensorValue`. Got ",
- Py_TYPE(sparse_tensor_value_class)->tp_name)
- .c_str());
- return;
- }
- SparseTensorValueType =
- reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class);
-}
-
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
+bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
+bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
+bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
- if (FlattenHelper(nested, list, IsSequenceHelper, GetNextValues)) {
+ if (FlattenHelper(nested, list, IsSequenceHelper, GetValueIterator)) {
return list;
} else {
Py_DECREF(list);
@@ -668,7 +752,7 @@ bool IsSequenceForData(PyObject* o) { return IsSequenceForDataHelper(o) == 1; }
PyObject* FlattenForData(PyObject* nested) {
PyObject* list = PyList_New(0);
if (FlattenHelper(nested, list, IsSequenceForDataHelper,
- GetNextValuesForData)) {
+ GetValueIteratorForData)) {
return list;
} else {
Py_DECREF(list);
@@ -699,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
return nullptr;
}
@@ -717,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
- int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ int is_instance =
+ PyObject_IsInstance(fields.get(), collections_sequence_type);
if (is_instance == 0) {
Py_RETURN_FALSE;
} else if (is_instance == -1) {
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 343605285e..f37cd527d8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -56,6 +56,33 @@ PyObject* IsNamedtuple(PyObject* o, bool strict);
// True if the sequence subclasses mapping.
bool IsMapping(PyObject* o);
+// Returns a true if its input is an instance of an attr.s decorated class.
+//
+// Args:
+// o: the input to be checked.
+//
+// Returns:
+// True if the object is an instance of an attr.s decorated class.
+bool IsAttrs(PyObject* o);
+
+// Returns a true if its input is an ops.Tensor.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is a tensor.
+bool IsTensor(PyObject* o);
+
+// Returns a true if its input is an ops.IndexesSlices.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is an ops.IndexedSlices.
+bool IsIndexedSlices(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
@@ -121,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
// TypeError: The nest is or contains a dict with non-sortable keys.
PyObject* Flatten(PyObject* nested);
-// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence
-// (which is defined in python) into the C++ world.
-// Alternative approach could be to import the collections modules and retrieve
-// the type from the module. This approach also requires some trigger from
-// Python so that we know that Python interpreter had been initialzied.
-void RegisterSequenceClass(PyObject* sequence_class);
-// Like RegisterSequenceClass, but for collections.Mapping.
-void RegisterMappingClass(PyObject* mapping_class);
-// Similar to the above functions, except for the
-// sparse_tensor.SparseTensorValue class.
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
-
// The tensorflow.python.data package has its own nest utility that follows very
// slightly different semantics for its functions than the tensorflow.python
// nest utility. Returns a true if its input is a collections.Sequence (except
@@ -158,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
+// RegisterType is used to pass PyTypeObject (which is defined in python) for an
+// arbitrary identifier `type_name` into C++.
+PyObject* RegisterType(PyObject* type_name, PyObject* type);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 104a615636..3c0ec87fa4 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -28,14 +28,8 @@ limitations under the License.
// for functions in this module because they use python methods that need GIL.
// TODO(iga): Find a way not to leak such definitions across files.
-%unignore tensorflow::swig::RegisterSequenceClass;
-%noexception tensorflow::swig::RegisterSequenceClass;
-
-%unignore tensorflow::swig::RegisterMappingClass;
-%noexception tensorflow::swig::RegisterMappingClass;
-
-%unignore tensorflow::swig::RegisterSparseTensorValueClass;
-%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%unignore tensorflow::swig::RegisterType;
+%noexception tensorflow::swig::RegisterType;
%feature("docstring") tensorflow::swig::IsSequence
"""Returns a true if its input is a collections.Sequence (except strings).
@@ -65,6 +59,18 @@ Returns:
%unignore tensorflow::swig::IsMapping;
%noexception tensorflow::swig::IsMapping;
+%feature("docstring") tensorflow::swig::IsAttrs
+"""Returns True iff `instance` is an instance of an `attr.s` decorated class.
+
+Args:
+ instance: An instance of a Python object.
+
+Returns:
+ True if `instance` is an instance of an `attr.s` decorated class.
+"""
+%unignore tensorflow::swig::IsAttrs;
+%noexception tensorflow::swig::IsAttrs;
+
%feature("docstring") tensorflow::swig::SameNamedtuples
"Returns True if the two namedtuples have the same name and fields."
%unignore tensorflow::swig::SameNamedtuples;