aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--.bazelrc2
-rw-r--r--configure.py174
-rw-r--r--tensorflow/BUILD41
-rwxr-xr-xtensorflow/c/eager/c_api.cc8
-rwxr-xr-xtensorflow/c/eager/c_api.h3
-rw-r--r--tensorflow/c/eager/tape.h7
-rw-r--r--tensorflow/cc/framework/cc_op_gen.cc10
-rw-r--r--tensorflow/cc/framework/scope.cc33
-rw-r--r--tensorflow/cc/framework/scope.h4
-rw-r--r--tensorflow/cc/framework/scope_internal.h5
-rw-r--r--tensorflow/compiler/jit/BUILD11
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass.cc180
-rw-r--r--tensorflow/compiler/jit/build_xla_ops_pass_test.cc32
-rw-r--r--tensorflow/compiler/jit/deadness_analysis.cc137
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_internal.h4
-rw-r--r--tensorflow/compiler/jit/deadness_analysis_test.cc31
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc56
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc12
-rw-r--r--tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc11
-rw-r--r--tensorflow/compiler/jit/kernels/BUILD1
-rw-r--r--tensorflow/compiler/jit/kernels/xla_ops.cc3
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass.cc6
-rw-r--r--tensorflow/compiler/jit/mark_for_compilation_pass_test.cc11
-rw-r--r--tensorflow/compiler/jit/partially_decluster_pass.cc7
-rw-r--r--tensorflow/compiler/jit/resource_operation_safety_analysis.cc5
-rw-r--r--tensorflow/compiler/jit/xla_compilation_cache.h6
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc15
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h3
-rw-r--r--tensorflow/compiler/tests/BUILD1
-rw-r--r--tensorflow/compiler/tests/randomized_tests.cc14
-rw-r--r--tensorflow/compiler/tests/stateless_random_ops_test.py2
-rw-r--r--tensorflow/compiler/tf2xla/BUILD2
-rw-r--r--tensorflow/compiler/tf2xla/cc/BUILD7
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc84
-rw-r--r--tensorflow/compiler/tf2xla/resource_operation_table.cc14
-rw-r--r--tensorflow/compiler/tf2xla/resource_operation_table_test.cc3
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.cc30
-rw-r--r--tensorflow/compiler/tf2xla/tf2xla_util.h51
-rw-r--r--tensorflow/compiler/tf2xla/type_util.h8
-rw-r--r--tensorflow/compiler/xla/client/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.cc4
-rw-r--r--tensorflow/compiler/xla/client/xla_builder.h8
-rw-r--r--tensorflow/compiler/xla/literal.cc20
-rw-r--r--tensorflow/compiler/xla/literal_test.cc10
-rw-r--r--tensorflow/compiler/xla/service/BUILD80
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h10
-rw-r--r--tensorflow/compiler/xla/service/batchnorm_expander.cc1
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.cc9
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_propagation.h17
-rw-r--r--tensorflow/compiler/xla/service/bfloat16_support.cc2
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.cc68
-rw-r--r--tensorflow/compiler/xla/service/buffer_assignment.h39
-rw-r--r--tensorflow/compiler/xla/service/buffer_liveness.h5
-rw-r--r--tensorflow/compiler/xla/service/buffer_value_containers.h4
-rw-r--r--tensorflow/compiler/xla/service/call_graph.cc9
-rw-r--r--tensorflow/compiler/xla/service/call_graph.h16
-rw-r--r--tensorflow/compiler/xla/service/copy_insertion.cc17
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD5
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc3
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_runtime.cc4
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc12
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h10
-rw-r--r--tensorflow/compiler/xla/service/cpu/target_machine_features.cc1
-rw-r--r--tensorflow/compiler/xla/service/cpu/target_machine_features.h5
-rw-r--r--tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc2
-rw-r--r--tensorflow/compiler/xla/service/defuser.cc3
-rw-r--r--tensorflow/compiler/xla/service/dfs_hlo_visitor.h1
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.cc2
-rw-r--r--tensorflow/compiler/xla/service/generic_transfer_manager.h7
-rw-r--r--tensorflow/compiler/xla/service/gpu/BUILD6
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc2
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc3
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h4
-rw-r--r--tensorflow/compiler/xla/service/gpu/instruction_fusion.cc5
-rw-r--r--tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc6
-rw-r--r--tensorflow/compiler/xla/service/gpu/nvptx_compiler.h9
-rw-r--r--tensorflow/compiler/xla/service/gpu/partition_assignment.cc9
-rw-r--r--tensorflow/compiler/xla/service/gpu/stream_assignment.h4
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.cc26
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator.h27
-rw-r--r--tensorflow/compiler/xla/service/heap_simulator_test.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto6
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_alias_analysis.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_buffer.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_clone_context.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.cc80
-rw-r--r--tensorflow/compiler/xla/service/hlo_computation.h12
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_map.h13
-rw-r--r--tensorflow/compiler/xla/service/hlo_domain_metadata.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.cc2
-rw-r--r--tensorflow/compiler/xla/service/hlo_evaluator.h15
-rw-r--r--tensorflow/compiler/xla/service/hlo_execution_profile.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.cc72
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h22
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.cc30
-rw-r--r--tensorflow/compiler/xla/service/hlo_instructions.h3
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.cc42
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc5
-rw-r--r--tensorflow/compiler/xla/service/hlo_module.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_metadata.h14
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.cc6
-rw-r--r--tensorflow/compiler/xla/service/hlo_module_group_util.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_opcode.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.cc10
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_ordering_test.cc20
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.cc276
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser.h5
-rw-r--r--tensorflow/compiler/xla/service/hlo_parser_test.cc227
-rw-r--r--tensorflow/compiler/xla/service/hlo_pass_pipeline.cc9
-rw-r--r--tensorflow/compiler/xla/service/hlo_reachability.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_rematerialization.h7
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.cc24
-rw-r--r--tensorflow/compiler/xla/service/hlo_schedule.h6
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc8
-rw-r--r--tensorflow/compiler/xla/service/hlo_value.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_verifier.cc6
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.cc5
-rw-r--r--tensorflow/compiler/xla/service/indexed_array_analysis.h4
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.cc8
-rw-r--r--tensorflow/compiler/xla/service/instruction_fusion.h5
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc65
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h13
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment_test.cc44
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/BUILD2
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc6
-rw-r--r--tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h11
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.cc6
-rw-r--r--tensorflow/compiler/xla/service/multi_output_fusion.h3
-rw-r--r--tensorflow/compiler/xla/service/name_uniquer.h8
-rw-r--r--tensorflow/compiler/xla/service/reduce_precision_insertion.h1
-rw-r--r--tensorflow/compiler/xla/service/shape_inference.cc8
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc4
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h16
-rw-r--r--tensorflow/compiler/xla/service/tuple_points_to_analysis.h2
-rw-r--r--tensorflow/compiler/xla/service/while_loop_constant_sinking.cc1
-rw-r--r--tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc16
-rw-r--r--tensorflow/compiler/xla/service/while_loop_simplifier.cc9
-rw-r--r--tensorflow/compiler/xla/shape_util.cc3
-rw-r--r--tensorflow/compiler/xla/tests/BUILD3
-rw-r--r--tensorflow/compiler/xla/tests/dot_operation_test.cc19
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc8
-rw-r--r--tensorflow/compiler/xrt/ops/xrt_execute_op.cc2
-rw-r--r--tensorflow/compiler/xrt/tests/raw_api_test.cc41
-rw-r--r--tensorflow/contrib/BUILD13
-rw-r--r--tensorflow/contrib/batching/BUILD58
-rw-r--r--tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/basic_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/serial_device_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/shared_batch_scheduler.h21
-rw-r--r--tensorflow/contrib/batching/test_util/BUILD19
-rw-r--r--tensorflow/contrib/batching/test_util/fake_clock_env.h21
-rw-r--r--tensorflow/contrib/batching/util/BUILD28
-rw-r--r--tensorflow/contrib/batching/util/periodic_function.h20
-rw-r--r--tensorflow/contrib/bigtable/README.md4
-rw-r--r--tensorflow/contrib/bigtable/python/ops/bigtable_api.py4
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py5
-rw-r--r--tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py4
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py18
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py24
-rw-r--r--tensorflow/contrib/cmake/CMakeLists.txt11
-rw-r--r--tensorflow/contrib/cmake/external/jemalloc.cmake50
-rw-r--r--tensorflow/contrib/cmake/external/protobuf.cmake2
-rw-r--r--tensorflow/contrib/cmake/make.bat38
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt4
-rw-r--r--tensorflow/contrib/cmake/tf_core_framework.cmake23
-rw-r--r--tensorflow/contrib/crf/python/ops/crf.py10
-rw-r--r--tensorflow/contrib/data/BUILD38
-rw-r--r--tensorflow/contrib/data/README.md18
-rw-r--r--tensorflow/contrib/data/__init__.py11
-rw-r--r--tensorflow/contrib/data/ops/indexed_dataset_ops.cc80
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/BUILD552
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py226
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py3
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py62
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py8
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/test_utils.py73
-rw-r--r--tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py526
-rw-r--r--tensorflow/contrib/data/python/ops/BUILD213
-rw-r--r--tensorflow/contrib/data/python/ops/batching.py549
-rw-r--r--tensorflow/contrib/data/python/ops/counter.py13
-rw-r--r--tensorflow/contrib/data/python/ops/enumerate_ops.py15
-rw-r--r--tensorflow/contrib/data/python/ops/error_ops.py38
-rw-r--r--tensorflow/contrib/data/python/ops/get_single_element.py29
-rw-r--r--tensorflow/contrib/data/python/ops/grouping.py441
-rw-r--r--tensorflow/contrib/data/python/ops/interleave_ops.py148
-rw-r--r--tensorflow/contrib/data/python/ops/iterator_ops.py167
-rw-r--r--tensorflow/contrib/data/python/ops/parsing_ops.py107
-rw-r--r--tensorflow/contrib/data/python/ops/prefetching_ops.py485
-rw-r--r--tensorflow/contrib/data/python/ops/random_ops.py34
-rw-r--r--tensorflow/contrib/data/python/ops/readers.py678
-rw-r--r--tensorflow/contrib/data/python/ops/resampling.py260
-rw-r--r--tensorflow/contrib/data/python/ops/scan_ops.py137
-rw-r--r--tensorflow/contrib/data/python/ops/shuffle_ops.py56
-rw-r--r--tensorflow/contrib/data/python/ops/threadpool.py89
-rw-r--r--tensorflow/contrib/data/python/ops/unique.py44
-rw-r--r--tensorflow/contrib/data/python/ops/writers.py40
-rw-r--r--tensorflow/contrib/decision_trees/proto/BUILD1
-rw-r--r--tensorflow/contrib/distribute/README.md3
-rw-r--r--tensorflow/contrib/distribute/python/BUILD30
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py2
-rw-r--r--tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/keras_test.py88
-rw-r--r--tensorflow/contrib/distribute/python/metrics_v1_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/minimize_loss_test.py26
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy.py23
-rw-r--r--tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py14
-rw-r--r--tensorflow/contrib/distribute/python/monitor.py1
-rw-r--r--tensorflow/contrib/distribute/python/one_device_strategy.py17
-rw-r--r--tensorflow/contrib/distribute/python/optimizer_v2_test.py8
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy.py22
-rw-r--r--tensorflow/contrib/distribute/python/parameter_server_strategy_test.py3
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2.py232
-rw-r--r--tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py90
-rw-r--r--tensorflow/contrib/distribute/python/step_fn.py7
-rw-r--r--tensorflow/contrib/distribute/python/step_fn_test.py1
-rw-r--r--tensorflow/contrib/distribute/python/strategy_test_lib.py6
-rw-r--r--tensorflow/contrib/distribute/python/tpu_strategy.py191
-rw-r--r--tensorflow/contrib/distribute/python/values.py477
-rw-r--r--tensorflow/contrib/distribute/python/values_test.py31
-rw-r--r--tensorflow/contrib/eager/python/datasets.py4
-rw-r--r--tensorflow/contrib/eager/python/datasets_test.py6
-rw-r--r--tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py12
-rw-r--r--tensorflow/contrib/eager/python/remote_test.py5
-rw-r--r--tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py2
-rw-r--r--tensorflow/contrib/estimator/python/estimator/rnn_test.py2
-rw-r--r--tensorflow/contrib/factorization/python/ops/gmm_ops.py14
-rw-r--r--tensorflow/contrib/factorization/python/ops/wals_test.py16
-rw-r--r--tensorflow/contrib/fused_conv/BUILD2
-rw-r--r--tensorflow/contrib/ignite/BUILD139
-rw-r--r--tensorflow/contrib/ignite/README.md167
-rw-r--r--tensorflow/contrib/ignite/__init__.py42
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc334
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h81
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h126
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_client.h84
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.cc81
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset.h63
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc422
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h99
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc198
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client.h43
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc123
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc142
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc151
-rw-r--r--tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h51
-rw-r--r--tensorflow/contrib/ignite/ops/dataset_ops.cc56
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py772
-rw-r--r--tensorflow/contrib/ignite/python/ops/ignite_op_loader.py (renamed from tensorflow/contrib/data/python/ops/contrib_op_loader.py)2
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/bin/start-plain.sh24
-rw-r--r--tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml39
-rw-r--r--tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py118
-rw-r--r--tensorflow/contrib/ignite/python/tests/sql/init.sql20
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/start_ignite.sh22
-rwxr-xr-xtensorflow/contrib/ignite/python/tests/stop_ignite.sh19
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.cc2
-rw-r--r--tensorflow/contrib/image/kernels/image_ops.h7
-rw-r--r--tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc1
-rw-r--r--tensorflow/contrib/image/python/kernel_tests/image_ops_test.py9
-rw-r--r--tensorflow/contrib/lite/BUILD19
-rw-r--r--tensorflow/contrib/lite/delegates/flex/BUILD6
-rw-r--r--tensorflow/contrib/lite/examples/android/BUILD1
-rw-r--r--tensorflow/contrib/lite/g3doc/performance.md10
-rw-r--r--tensorflow/contrib/lite/java/aar_with_jni.bzl53
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/BUILD1
-rw-r--r--tensorflow/contrib/lite/java/ovic/BUILD61
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/BUILD6
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java77
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml27
-rw-r--r--tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml3
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java68
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java152
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java (renamed from tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java)12
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java10
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java142
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java91
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java184
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java160
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java2
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java6
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java149
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/testdata/BUILD5
-rw-r--r--tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt91
-rw-r--r--tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD1
-rw-r--r--tensorflow/contrib/lite/python/convert.py8
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/quantize.cc30
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc6
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc14
-rw-r--r--tensorflow/contrib/lite/toco/model_cmdline_flags.cc18
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc32
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py2
-rw-r--r--tensorflow/contrib/makefile/Makefile3
-rw-r--r--tensorflow/contrib/model_pruning/README.md1
-rw-r--r--tensorflow/contrib/opt/python/training/shampoo_test.py40
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py32
-rw-r--r--tensorflow/contrib/quantize/BUILD1
-rw-r--r--tensorflow/contrib/quantize/python/quant_ops.py28
-rw-r--r--tensorflow/contrib/stateless/BUILD8
-rw-r--r--tensorflow/contrib/stateless/__init__.py5
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/head_test.py2
-rw-r--r--tensorflow/contrib/tpu/BUILD6
-rw-r--r--tensorflow/contrib/tpu/__init__.py3
-rw-r--r--tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc138
-rw-r--r--tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc6
-rw-r--r--tensorflow/contrib/tpu/profiler/op_profile.proto4
-rw-r--r--tensorflow/contrib/tpu/profiler/pip_package/setup.py2
-rw-r--r--tensorflow/contrib/tpu/profiler/version.h2
-rw-r--r--tensorflow/contrib/tpu/proto/optimization_parameters.proto37
-rw-r--r--tensorflow/contrib/tpu/python/ops/tpu_ops.py175
-rw-r--r--tensorflow/contrib/tpu/python/tpu/async_checkpoint.py12
-rw-r--r--tensorflow/contrib/tpu/python/tpu/datasets.py4
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py289
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py55
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py11
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_context.py14
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu_estimator.py17
-rw-r--r--tensorflow/contrib/tpu/tpu_estimator.md2
-rw-r--r--tensorflow/contrib/training/BUILD3
-rw-r--r--tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py2
-rw-r--r--tensorflow/core/BUILD27
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt21
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt58
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt25
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt4
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt13
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt35
-rw-r--r--tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt8
-rw-r--r--tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt2
-rw-r--r--tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt8
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt4
-rw-r--r--tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt4
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc49
-rw-r--r--tensorflow/core/common_runtime/direct_session.h3
-rw-r--r--tensorflow/core/common_runtime/direct_session_test.cc28
-rw-r--r--tensorflow/core/common_runtime/eager/BUILD7
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.cc43
-rw-r--r--tensorflow/core/common_runtime/lower_if_op.h5
-rw-r--r--tensorflow/core/common_runtime/lower_if_op_test.cc4
-rw-r--r--tensorflow/core/framework/dataset.h12
-rw-r--r--tensorflow/core/framework/function.cc8
-rw-r--r--tensorflow/core/framework/function.h5
-rw-r--r--tensorflow/core/framework/model.cc83
-rw-r--r--tensorflow/core/framework/model.h42
-rw-r--r--tensorflow/core/framework/node_def_util.h1
-rw-r--r--tensorflow/core/framework/op.h20
-rw-r--r--tensorflow/core/framework/op_def_builder.cc24
-rw-r--r--tensorflow/core/framework/op_def_builder.h14
-rw-r--r--tensorflow/core/framework/resource_mgr.h11
-rw-r--r--tensorflow/core/framework/run_handler.cc249
-rw-r--r--tensorflow/core/framework/run_handler.h95
-rw-r--r--tensorflow/core/framework/run_handler_util.cc57
-rw-r--r--tensorflow/core/framework/run_handler_util.h43
-rw-r--r--tensorflow/core/framework/run_handler_util_test.cc93
-rw-r--r--tensorflow/core/graph/graph.cc4
-rw-r--r--tensorflow/core/graph/graph.h8
-rw-r--r--tensorflow/core/graph/node_builder.cc7
-rw-r--r--tensorflow/core/graph/node_builder.h4
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt7
-rw-r--r--tensorflow/core/grappler/grappler_item.cc1
-rw-r--r--tensorflow/core/grappler/grappler_item.h9
-rw-r--r--tensorflow/core/grappler/op_types.cc4
-rw-r--r--tensorflow/core/grappler/op_types.h1
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD2
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/data/BUILD7
-rw-r--r--tensorflow/core/grappler/optimizers/data/graph_utils.h4
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization.cc28
-rw-r--r--tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc112
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/BUILD3
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc36
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h23
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc16
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.cc451
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils.h35
-rw-r--r--tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc205
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc29
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc126
-rw-r--r--tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc8
-rw-r--r--tensorflow/core/grappler/utils/functions.cc55
-rw-r--r--tensorflow/core/grappler/utils/functions.h5
-rw-r--r--tensorflow/core/kernels/BUILD47
-rw-r--r--tensorflow/core/kernels/batching_util/BUILD15
-rw-r--r--tensorflow/core/kernels/collective_ops.cc2
-rw-r--r--tensorflow/core/kernels/data/BUILD1
-rw-r--r--tensorflow/core/kernels/data/experimental/BUILD (renamed from tensorflow/contrib/data/kernels/BUILD)90
-rw-r--r--tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/assert_next_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/csv_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/csv_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc)5
-rw-r--r--tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/identity_indexed_dataset.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc)6
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.cc (renamed from tensorflow/contrib/data/kernels/indexed_dataset.cc)14
-rw-r--r--tensorflow/core/kernels/data/experimental/indexed_dataset.h (renamed from tensorflow/contrib/data/kernels/indexed_dataset.h)6
-rw-r--r--tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/lmdb_dataset_op.cc)3
-rw-r--r--tensorflow/core/kernels/data/experimental/prefetching_kernels.cc (renamed from tensorflow/contrib/data/kernels/prefetching_kernels.cc)23
-rw-r--r--tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/threadpool_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/experimental/unique_dataset_op.cc (renamed from tensorflow/contrib/data/kernels/unique_dataset_op.cc)7
-rw-r--r--tensorflow/core/kernels/data/map_and_batch_dataset_op.cc92
-rw-r--r--tensorflow/core/kernels/data/multi_device_iterator_ops.cc34
-rw-r--r--tensorflow/core/kernels/data/optimize_dataset_op.cc16
-rw-r--r--tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc89
-rw-r--r--tensorflow/core/kernels/data/parallel_map_iterator.cc119
-rw-r--r--tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc10
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_grad_input_ops.cc2
-rw-r--r--tensorflow/core/kernels/mkl_conv_ops.cc2
-rw-r--r--tensorflow/core/kernels/resource_variable_ops.cc18
-rw-r--r--tensorflow/core/kernels/slice_op.cc195
-rw-r--r--tensorflow/core/kernels/strided_slice_op.cc2
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt448
-rw-r--r--tensorflow/core/ops/experimental_dataset_ops.cc (renamed from tensorflow/contrib/data/ops/dataset_ops.cc)161
-rw-r--r--tensorflow/core/ops/functional_ops.cc23
-rw-r--r--tensorflow/core/ops/ops.pbtxt423
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.cc15
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client.h10
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc6
-rw-r--r--tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc8
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc25
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h7
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc1280
-rw-r--r--tensorflow/core/platform/cloud/google_auth_provider_test.cc20
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system.h67
-rw-r--r--tensorflow/core/platform/cloud/retrying_file_system_test.cc102
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils.cc35
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils.h29
-rw-r--r--tensorflow/core/platform/cloud/retrying_utils_test.cc32
-rw-r--r--tensorflow/core/platform/default/build_config.bzl20
-rw-r--r--tensorflow/core/platform/posix/port.cc36
-rw-r--r--tensorflow/core/platform/windows/port.cc51
-rw-r--r--tensorflow/core/profiler/BUILD1
-rw-r--r--tensorflow/core/protobuf/config.proto5
-rw-r--r--tensorflow/core/protobuf/rewriter_config.proto4
-rw-r--r--tensorflow/core/util/mkl_util.h12
-rw-r--r--tensorflow/core/util/tensor_bundle/BUILD4
-rw-r--r--tensorflow/docs_src/BUILD14
-rw-r--r--tensorflow/docs_src/__init__.py0
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md2426
-rw-r--r--tensorflow/examples/get_started/regression/test.py2
-rw-r--r--tensorflow/go/op/wrappers.go1782
-rwxr-xr-xtensorflow/go/test.sh3
-rw-r--r--tensorflow/python/BUILD44
-rw-r--r--tensorflow/python/autograph/CONTRIBUTING.md2
-rw-r--r--tensorflow/python/autograph/converters/BUILD6
-rw-r--r--tensorflow/python/autograph/converters/function_scopes.py (renamed from tensorflow/python/autograph/converters/name_scopes.py)32
-rw-r--r--tensorflow/python/autograph/converters/function_scopes_test.py (renamed from tensorflow/python/autograph/converters/name_scopes_test.py)40
-rw-r--r--tensorflow/python/autograph/core/BUILD51
-rw-r--r--tensorflow/python/autograph/core/converter.py53
-rw-r--r--tensorflow/python/autograph/core/converter_test.py124
-rw-r--r--tensorflow/python/autograph/core/converter_testing.py2
-rw-r--r--tensorflow/python/autograph/core/function_wrapping.py30
-rw-r--r--tensorflow/python/autograph/core/function_wrapping_test.py34
-rw-r--r--tensorflow/python/autograph/impl/conversion.py6
-rw-r--r--tensorflow/python/autograph/lang/special_functions.py24
-rw-r--r--tensorflow/python/autograph/lang/special_functions_test.py37
-rw-r--r--tensorflow/python/autograph/operators/data_structures.py17
-rw-r--r--tensorflow/python/autograph/operators/data_structures_test.py31
-rw-r--r--tensorflow/python/compat/compat.py2
-rw-r--r--tensorflow/python/data/BUILD1
-rw-r--r--tensorflow/python/data/__init__.py1
-rw-r--r--tensorflow/python/data/experimental/BUILD16
-rw-r--r--tensorflow/python/data/experimental/__init__.py109
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/BUILD662
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py)324
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/bucketing_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/bucketing_test.py)11
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py)47
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py)7
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py)33
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py)17
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py)9
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/BUILD (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/BUILD)82
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py)16
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py)13
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py)32
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py)11
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py)30
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py)27
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py)9
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py)50
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py)9
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py)7
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py)7
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py)12
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/resample_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/resample_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/BUILD (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/BUILD)242
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py692
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py)6
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py)15
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py85
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py)2
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py)7
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py (renamed from tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py)4
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py)8
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py)5
-rw-r--r--tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py (renamed from tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py)5
-rw-r--r--tensorflow/python/data/experimental/ops/BUILD377
-rw-r--r--tensorflow/python/data/experimental/ops/batching.py669
-rw-r--r--tensorflow/python/data/experimental/ops/counter.py55
-rw-r--r--tensorflow/python/data/experimental/ops/enumerate_ops.py60
-rw-r--r--tensorflow/python/data/experimental/ops/error_ops.py78
-rw-r--r--tensorflow/python/data/experimental/ops/get_single_element.py72
-rw-r--r--tensorflow/python/data/experimental/ops/grouping.py551
-rw-r--r--tensorflow/python/data/experimental/ops/indexed_dataset_ops.py (renamed from tensorflow/contrib/data/python/ops/indexed_dataset_ops.py)25
-rw-r--r--tensorflow/python/data/experimental/ops/interleave_ops.py262
-rw-r--r--tensorflow/python/data/experimental/ops/iterator_ops.py268
-rw-r--r--tensorflow/python/data/experimental/ops/map_defun.py (renamed from tensorflow/contrib/data/python/ops/map_defun.py)0
-rw-r--r--tensorflow/python/data/experimental/ops/optimization.py (renamed from tensorflow/contrib/data/python/ops/optimization.py)68
-rw-r--r--tensorflow/python/data/experimental/ops/parsing_ops.py152
-rw-r--r--tensorflow/python/data/experimental/ops/prefetching_ops.py531
-rw-r--r--tensorflow/python/data/experimental/ops/random_ops.py54
-rw-r--r--tensorflow/python/data/experimental/ops/readers.py904
-rw-r--r--tensorflow/python/data/experimental/ops/resampling.py296
-rw-r--r--tensorflow/python/data/experimental/ops/scan_ops.py177
-rw-r--r--tensorflow/python/data/experimental/ops/shuffle_ops.py102
-rw-r--r--tensorflow/python/data/experimental/ops/stats_ops.py (renamed from tensorflow/contrib/data/python/ops/stats_ops.py)14
-rw-r--r--tensorflow/python/data/experimental/ops/threadpool.py104
-rw-r--r--tensorflow/python/data/experimental/ops/unique.py79
-rw-r--r--tensorflow/python/data/experimental/ops/writers.py60
-rw-r--r--tensorflow/python/data/kernel_tests/BUILD21
-rw-r--r--tensorflow/python/data/kernel_tests/dataset_ops_test.py158
-rw-r--r--tensorflow/python/data/kernel_tests/test_base.py80
-rw-r--r--tensorflow/python/data/ops/dataset_ops.py272
-rw-r--r--tensorflow/python/data/ops/optional_ops.py4
-rw-r--r--tensorflow/python/data/ops/readers.py4
-rw-r--r--tensorflow/python/debug/examples/debug_tflearn_iris.py14
-rwxr-xr-xtensorflow/python/debug/examples/examples_test.sh2
-rw-r--r--tensorflow/python/distribute/distribute_coordinator.py4
-rw-r--r--tensorflow/python/distribute/estimator_training.py2
-rw-r--r--tensorflow/python/eager/BUILD1
-rw-r--r--tensorflow/python/eager/backprop.py2
-rw-r--r--tensorflow/python/eager/function.py148
-rw-r--r--tensorflow/python/eager/function_test.py41
-rwxr-xr-xtensorflow/python/eager/pywrap_tfe.h4
-rw-r--r--tensorflow/python/eager/pywrap_tfe_src.cc226
-rw-r--r--tensorflow/python/estimator/BUILD1
-rw-r--r--tensorflow/python/estimator/canned/dnn.py3
-rw-r--r--tensorflow/python/estimator/canned/dnn_linear_combined_test.py268
-rw-r--r--tensorflow/python/estimator/canned/linear.py83
-rw-r--r--tensorflow/python/estimator/canned/linear_test.py138
-rw-r--r--tensorflow/python/estimator/canned/linear_testing_utils.py184
-rw-r--r--tensorflow/python/estimator/estimator.py48
-rw-r--r--tensorflow/python/estimator/estimator_test.py94
-rw-r--r--tensorflow/python/estimator/keras.py48
-rw-r--r--tensorflow/python/estimator/keras_test.py6
-rw-r--r--tensorflow/python/estimator/util.py8
-rw-r--r--tensorflow/python/feature_column/BUILD2
-rw-r--r--tensorflow/python/feature_column/feature_column_v2.py100
-rw-r--r--tensorflow/python/feature_column/feature_column_v2_test.py4
-rw-r--r--tensorflow/python/framework/device.py12
-rw-r--r--tensorflow/python/framework/dtypes.py4
-rw-r--r--tensorflow/python/framework/errors_impl.py6
-rw-r--r--tensorflow/python/framework/function.py2
-rw-r--r--tensorflow/python/framework/graph_io.py2
-rw-r--r--tensorflow/python/framework/importer.py2
-rw-r--r--tensorflow/python/framework/random_seed.py6
-rw-r--r--tensorflow/python/framework/sparse_tensor.py4
-rw-r--r--tensorflow/python/framework/test_util.py84
-rw-r--r--tensorflow/python/keras/backend.py33
-rw-r--r--tensorflow/python/keras/engine/base_layer.py157
-rw-r--r--tensorflow/python/keras/engine/distributed_training_utils.py134
-rw-r--r--tensorflow/python/keras/engine/training.py48
-rw-r--r--tensorflow/python/keras/engine/training_distributed.py30
-rw-r--r--tensorflow/python/keras/engine/training_eager_test.py14
-rw-r--r--tensorflow/python/keras/engine/training_test.py12
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent.py24
-rw-r--r--tensorflow/python/keras/layers/cudnn_recurrent_test.py27
-rw-r--r--tensorflow/python/keras/layers/embeddings.py10
-rw-r--r--tensorflow/python/keras/layers/recurrent.py65
-rw-r--r--tensorflow/python/keras/layers/recurrent_test.py90
-rw-r--r--tensorflow/python/keras/models.py5
-rw-r--r--tensorflow/python/keras/preprocessing/image_test.py37
-rw-r--r--tensorflow/python/kernel_tests/BUILD25
-rw-r--r--tensorflow/python/kernel_tests/array_ops_test.py8
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py201
-rw-r--r--tensorflow/python/kernel_tests/depthwise_conv_op_test.py4
-rw-r--r--tensorflow/python/kernel_tests/distributions/beta_test.py5
-rw-r--r--tensorflow/python/kernel_tests/distributions/dirichlet_test.py17
-rw-r--r--tensorflow/python/kernel_tests/distributions/exponential_test.py23
-rw-r--r--tensorflow/python/kernel_tests/distributions/gamma_test.py8
-rw-r--r--tensorflow/python/kernel_tests/rnn_test.py39
-rw-r--r--tensorflow/python/kernel_tests/variables_test.py42
-rw-r--r--tensorflow/python/layers/base.py16
-rw-r--r--tensorflow/python/layers/convolutional_test.py36
-rw-r--r--tensorflow/python/layers/core.py16
-rw-r--r--tensorflow/python/layers/core_test.py40
-rw-r--r--tensorflow/python/lib/io/tf_record.py13
-rw-r--r--tensorflow/python/ops/array_ops.py44
-rw-r--r--tensorflow/python/ops/candidate_sampling_ops.py8
-rw-r--r--tensorflow/python/ops/check_ops.py63
-rw-r--r--tensorflow/python/ops/clip_ops.py8
-rw-r--r--tensorflow/python/ops/confusion_matrix.py4
-rw-r--r--tensorflow/python/ops/control_flow_ops.py19
-rw-r--r--tensorflow/python/ops/conv2d_benchmark.py4
-rw-r--r--tensorflow/python/ops/data_flow_ops.py17
-rw-r--r--tensorflow/python/ops/distributions/BUILD7
-rw-r--r--tensorflow/python/ops/distributions/bernoulli.py9
-rw-r--r--tensorflow/python/ops/distributions/beta.py18
-rw-r--r--tensorflow/python/ops/distributions/categorical.py9
-rw-r--r--tensorflow/python/ops/distributions/dirichlet.py11
-rw-r--r--tensorflow/python/ops/distributions/dirichlet_multinomial.py9
-rw-r--r--tensorflow/python/ops/distributions/distribution.py17
-rw-r--r--tensorflow/python/ops/distributions/exponential.py16
-rw-r--r--tensorflow/python/ops/distributions/gamma.py16
-rw-r--r--tensorflow/python/ops/distributions/identity_bijector.py9
-rw-r--r--tensorflow/python/ops/distributions/kullback_leibler.py25
-rw-r--r--tensorflow/python/ops/distributions/laplace.py14
-rw-r--r--tensorflow/python/ops/distributions/multinomial.py9
-rw-r--r--tensorflow/python/ops/distributions/normal.py14
-rw-r--r--tensorflow/python/ops/distributions/special_math.py61
-rw-r--r--tensorflow/python/ops/distributions/student_t.py14
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py9
-rw-r--r--tensorflow/python/ops/distributions/uniform.py9
-rw-r--r--tensorflow/python/ops/init_ops.py5
-rw-r--r--tensorflow/python/ops/linalg_ops.py15
-rw-r--r--tensorflow/python/ops/lookup_ops.py2
-rw-r--r--tensorflow/python/ops/manip_ops.py4
-rw-r--r--tensorflow/python/ops/math_ops.py145
-rw-r--r--tensorflow/python/ops/nn_impl.py6
-rw-r--r--tensorflow/python/ops/nn_ops.py18
-rw-r--r--tensorflow/python/ops/numerics.py4
-rw-r--r--tensorflow/python/ops/parallel_for/pfor.py6
-rw-r--r--tensorflow/python/ops/parsing_ops.py18
-rw-r--r--tensorflow/python/ops/random_ops.py19
-rw-r--r--tensorflow/python/ops/rnn_cell_impl.py6
-rw-r--r--tensorflow/python/ops/sparse_ops.py107
-rw-r--r--tensorflow/python/ops/special_math_ops.py4
-rw-r--r--tensorflow/python/ops/string_ops.py32
-rw-r--r--tensorflow/python/ops/variable_scope.py3
-rw-r--r--tensorflow/python/ops/variables.py48
-rw-r--r--tensorflow/python/ops/while_v2.py63
-rwxr-xr-xtensorflow/python/pywrap_tfe.i1
-rw-r--r--tensorflow/python/saved_model/builder_impl.py7
-rw-r--r--tensorflow/python/saved_model/loader_impl.py8
-rw-r--r--tensorflow/python/saved_model/main_op_impl.py5
-rw-r--r--tensorflow/python/saved_model/signature_def_utils_impl.py27
-rw-r--r--tensorflow/python/saved_model/utils_impl.py10
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files.bzl2
-rw-r--r--tensorflow/python/tools/api/generator/api_init_files_v1.bzl2
-rw-r--r--tensorflow/python/tools/saved_model_cli.py19
-rw-r--r--tensorflow/python/training/distribute.py51
-rw-r--r--tensorflow/python/training/distribution_strategy_context.py2
-rw-r--r--tensorflow/python/training/input.py3
-rw-r--r--tensorflow/python/training/moving_averages.py9
-rw-r--r--tensorflow/python/training/moving_averages_test.py27
-rw-r--r--tensorflow/python/training/optimizer.py15
-rw-r--r--tensorflow/python/training/session_manager.py7
-rw-r--r--tensorflow/python/util/nest.py4
-rw-r--r--tensorflow/python/util/tf_inspect.py93
-rw-r--r--tensorflow/python/util/tf_inspect_test.py199
-rw-r--r--tensorflow/python/util/util.cc223
-rw-r--r--tensorflow/python/util/util.h34
-rw-r--r--tensorflow/python/util/util.i10
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.cc222
-rw-r--r--tensorflow/stream_executor/cuda/cuda_gpu_executor.h11
-rw-r--r--tensorflow/stream_executor/device_description.cc76
-rw-r--r--tensorflow/stream_executor/device_description.h64
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt24
-rw-r--r--tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt148
-rw-r--r--tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt58
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt58
-rw-r--r--tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt105
-rw-r--r--tensorflow/tools/api/golden/tensorflow.image.pbtxt251
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt268
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt289
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt55
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt268
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt289
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt57
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt30
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt139
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt96
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt77
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-feature.pbtxt27
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-sequence-feature.pbtxt31
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-padding-f-i-f-o-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-priority-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-queue-base.pbtxt65
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-random-shuffle-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-sparse-feature.pbtxt35
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-compression-type.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-options.pbtxt17
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.-var-len-feature.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt84
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt188
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt47
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt44
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt112
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt6
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt57
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt30
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt28
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt135
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt13
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt139
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt96
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt77
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-feature.pbtxt27
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-sequence-feature.pbtxt31
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-padding-f-i-f-o-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-priority-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-queue-base.pbtxt65
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-random-shuffle-queue.pbtxt66
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-sparse-feature.pbtxt35
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-compression-type.pbtxt20
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-options.pbtxt17
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.-var-len-feature.pbtxt19
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt84
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt188
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt12
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt202
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.pbtxt8
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt47
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt21
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt44
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt46
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt54
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt112
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt4
-rw-r--r--tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt4
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.cmake4
-rw-r--r--tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.042
-rwxr-xr-xtensorflow/tools/ci_build/builds/android.sh1
-rwxr-xr-xtensorflow/tools/ci_build/builds/run_pip_tests.sh1
-rwxr-xr-xtensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh2
-rwxr-xr-xtensorflow/tools/ci_build/install/install_pip_packages.sh8
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh4
-rwxr-xr-xtensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh4
-rw-r--r--tensorflow/tools/docker/Dockerfile4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel4
-rw-r--r--tensorflow/tools/docker/Dockerfile.devel-gpu4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.devel-mkl-horovod4
-rw-r--r--tensorflow/tools/docker/Dockerfile.gpu4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.mkl4
-rwxr-xr-xtensorflow/tools/docker/Dockerfile.mkl-horovod4
-rw-r--r--tensorflow/tools/docs/BUILD1
-rw-r--r--tensorflow/tools/docs/build_docs_test.py6
-rw-r--r--tensorflow/tools/lib_package/BUILD16
-rw-r--r--tensorflow/tools/pip_package/BUILD14
-rw-r--r--tensorflow/tools/pip_package/setup.py7
-rwxr-xr-xtensorflow/workspace.bzl20
-rw-r--r--third_party/gpus/crosstool/BUILD.tpl14
-rw-r--r--third_party/gpus/cuda_configure.bzl33
-rw-r--r--third_party/jemalloc.BUILD356
-rw-r--r--third_party/nccl/nccl_configure.bzl13
-rw-r--r--third_party/nccl/system.BUILD.tpl4
-rw-r--r--third_party/py/python_configure.bzl4
-rw-r--r--third_party/systemlibs/jemalloc.BUILD30
-rw-r--r--third_party/systemlibs/syslibs_configure.bzl1
-rw-r--r--third_party/toolchains/BUILD4
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD2
-rwxr-xr-xthird_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD14
872 files changed, 31825 insertions, 14492 deletions
diff --git a/.bazelrc b/.bazelrc
index 9f09fdff97..d5d20309df 100644
--- a/.bazelrc
+++ b/.bazelrc
@@ -29,7 +29,7 @@ build:mkl -c opt
# This config option is used to enable MKL-DNN open source library only,
# without depending on MKL binary version.
-build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
+build:mkl_open_source_only --define=build_with_mkl_dnn_only=true
build:mkl_open_source_only --define=build_with_mkl=true --define=enable_mkl=true
build:download_clang --crosstool_top=@local_config_download_clang//:toolchain
diff --git a/configure.py b/configure.py
index 129d9c5fe7..2d2da11700 100644
--- a/configure.py
+++ b/configure.py
@@ -48,10 +48,13 @@ _SUPPORTED_ANDROID_NDK_VERSIONS = [10, 11, 12, 13, 14, 15, 16]
_DEFAULT_PROMPT_ASK_ATTEMPTS = 10
-_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__))
_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc'
-_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
-_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE')
+_TF_WORKSPACE_ROOT = ''
+_TF_BAZELRC = ''
+
+NCCL_LIB_PATHS = [
+ 'lib64/', 'lib/powerpc64le-linux-gnu/', 'lib/x86_64-linux-gnu/', ''
+]
if platform.machine() == 'ppc64le':
_DEFAULT_TENSORRT_PATH_LINUX = '/usr/lib/powerpc64le-linux-gnu/'
@@ -224,7 +227,7 @@ def setup_python(environ_cp):
python_lib_path = default_python_lib_path
environ_cp['PYTHON_LIB_PATH'] = python_lib_path
- python_major_version = get_python_major_version(python_bin_path)
+ _ = get_python_major_version(python_bin_path)
# Convert python path to Windows style before writing into bazel.rc
if is_windows() or is_cygwin():
@@ -243,10 +246,10 @@ def setup_python(environ_cp):
f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
-def reset_tf_configure_bazelrc(workspace_path):
+def reset_tf_configure_bazelrc():
"""Reset file that contains customized config settings."""
open(_TF_BAZELRC, 'w').close()
- bazelrc_path = os.path.join(workspace_path, '.bazelrc')
+ bazelrc_path = os.path.join(_TF_WORKSPACE_ROOT, '.bazelrc')
data = []
if os.path.exists(bazelrc_path):
@@ -259,7 +262,6 @@ def reset_tf_configure_bazelrc(workspace_path):
f.write('%s\n' % l)
f.write('import %%workspace%%/%s\n' % _TF_BAZELRC_FILENAME)
-
def cleanup_makefile():
"""Delete any leftover BUILD files from the Makefile build.
@@ -881,7 +883,7 @@ def set_tf_cudnn_version(environ_cp):
"""Set CUDNN_INSTALL_PATH and TF_CUDNN_VERSION."""
ask_cudnn_version = (
'Please specify the cuDNN version you want to use. '
- '[Leave empty to default to cuDNN %s.0]: ') % _DEFAULT_CUDNN_VERSION
+ '[Leave empty to default to cuDNN %s]: ') % _DEFAULT_CUDNN_VERSION
for _ in range(_DEFAULT_PROMPT_ASK_ATTEMPTS):
tf_cudnn_version = get_from_env_or_user_or_default(
@@ -1038,7 +1040,7 @@ def set_tf_tensorrt_install_path(environ_cp):
for lib_file in possible_files:
if is_cuda_compatible(lib_file, cuda_ver, cudnn_ver):
matches = nvinfer_pattern.search(lib_file)
- if len(matches.groups()) == 0:
+ if not matches.groups():
continue
ver_str = matches.group(1)
ver = convert_version_to_int(ver_str) if len(ver_str) else 0
@@ -1094,7 +1096,7 @@ def set_tf_tensorrt_install_path(environ_cp):
def set_tf_nccl_install_path(environ_cp):
- """Set NCCL_INSTALL_PATH and TF_NCCL_VERSION.
+ """Set NCCL_INSTALL_PATH, NCCL_HDR_PATH and TF_NCCL_VERSION.
Args:
environ_cp: copy of the os.environ.
@@ -1120,46 +1122,107 @@ def set_tf_nccl_install_path(environ_cp):
if tf_nccl_version == '1':
break # No need to get install path, NCCL 1 is a GitHub repo.
- # TODO(csigg): Look with ldconfig first if we can find the library in paths
+ # Look with ldconfig first if we can find the library in paths
# like /usr/lib/x86_64-linux-gnu and the header file in the corresponding
# include directory. This is where the NCCL .deb packages install them.
- # Then ask the user if we should use that. Instead of a single
- # NCCL_INSTALL_PATH, pass separate NCCL_LIB_PATH and NCCL_HDR_PATH to
- # nccl_configure.bzl
- default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
- ask_nccl_path = (r'Please specify the location where NCCL %s library is '
- 'installed. Refer to README.md for more details. [Default '
- 'is %s]:') % (tf_nccl_version, default_nccl_path)
- nccl_install_path = get_from_env_or_user_or_default(
- environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
-
- # Result returned from "read" will be used unexpanded. That make "~"
- # unusable. Going through one more level of expansion to handle that.
- nccl_install_path = os.path.realpath(os.path.expanduser(nccl_install_path))
- if is_windows() or is_cygwin():
- nccl_install_path = cygpath(nccl_install_path)
- if is_windows():
- nccl_lib_path = 'lib/x64/nccl.lib'
- elif is_linux():
- nccl_lib_path = 'lib/libnccl.so.%s' % tf_nccl_version
- elif is_macos():
- nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
-
- nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
- nccl_hdr_path = os.path.join(nccl_install_path, 'include/nccl.h')
- if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
- # Set NCCL_INSTALL_PATH
- environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
- write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
- break
-
- # Reset and Retry
- print('Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
+ # First check to see if NCCL is in the ldconfig.
+ # If its found, use that location.
+ if is_linux():
+ ldconfig_bin = which('ldconfig') or '/sbin/ldconfig'
+ nccl2_path_from_ldconfig = run_shell([ldconfig_bin, '-p'])
+ nccl2_path_from_ldconfig = re.search('.*libnccl.so .* => (.*)',
+ nccl2_path_from_ldconfig)
+ if nccl2_path_from_ldconfig:
+ nccl2_path_from_ldconfig = nccl2_path_from_ldconfig.group(1)
+ if os.path.exists('%s.%s' % (nccl2_path_from_ldconfig, tf_nccl_version)):
+ nccl_install_path = os.path.dirname(nccl2_path_from_ldconfig)
+ print('NCCL libraries found in ' + nccl2_path_from_ldconfig)
+
+ # Check if this is the main system lib location
+ if re.search('.*linux-gnu', nccl_install_path):
+ trunc_nccl_install_path = '/usr'
+ print('This looks like a system path.')
+ else:
+ trunc_nccl_install_path = nccl_install_path + '/..'
+
+ # Look for header
+ nccl_hdr_path = trunc_nccl_install_path + '/include'
+ print('Assuming NCCL header path is ' + nccl_hdr_path)
+ if os.path.exists(nccl_hdr_path + '/nccl.h'):
+ # Set NCCL_INSTALL_PATH
+ environ_cp['NCCL_INSTALL_PATH'] = nccl_install_path
+ write_action_env_to_bazelrc('NCCL_INSTALL_PATH', nccl_install_path)
+
+ # Set NCCL_HDR_PATH
+ environ_cp['NCCL_HDR_PATH'] = nccl_hdr_path
+ write_action_env_to_bazelrc('NCCL_HDR_PATH', nccl_hdr_path)
+ break
+ else:
+ print(
+ 'The header for NCCL2 cannot be found. Please install the libnccl-dev package.'
+ )
+ else:
+ print('NCCL2 is listed by ldconfig but the library is not found. '
+ 'Your ldconfig is out of date. Please run sudo ldconfig.')
+ else:
+ # NCCL is not found in ldconfig. Ask the user for the location.
+ default_nccl_path = environ_cp.get('CUDA_TOOLKIT_PATH')
+ ask_nccl_path = (
+ r'Please specify the location where NCCL %s library is '
+ 'installed. Refer to README.md for more details. [Default '
+ 'is %s]:') % (tf_nccl_version, default_nccl_path)
+ nccl_install_path = get_from_env_or_user_or_default(
+ environ_cp, 'NCCL_INSTALL_PATH', ask_nccl_path, default_nccl_path)
+
+ # Result returned from "read" will be used unexpanded. That make "~"
+ # unusable. Going through one more level of expansion to handle that.
+ nccl_install_path = os.path.realpath(
+ os.path.expanduser(nccl_install_path))
+ if is_windows() or is_cygwin():
+ nccl_install_path = cygpath(nccl_install_path)
+
+ if is_windows():
+ nccl_lib_path = 'lib/x64/nccl.lib'
+ elif is_linux():
+ nccl_lib_filename = 'libnccl.so.%s' % tf_nccl_version
+ nccl_lpath = '%s/lib/%s' % (nccl_install_path, nccl_lib_filename)
+ if not os.path.exists(nccl_lpath):
+ for relative_path in NCCL_LIB_PATHS:
+ path = '%s/%s%s' % (nccl_install_path, relative_path,
+ nccl_lib_filename)
+ if os.path.exists(path):
+ print('NCCL found at ' + path)
+ nccl_lib_path = path
+ break
+ else:
+ nccl_lib_path = nccl_lpath
+ elif is_macos():
+ nccl_lib_path = 'lib/libnccl.%s.dylib' % tf_nccl_version
+
+ nccl_lib_path = os.path.join(nccl_install_path, nccl_lib_path)
+ nccl_hdr_path = os.path.join(
+ os.path.dirname(nccl_lib_path), '../include/nccl.h')
+ print('Assuming NCCL header path is ' + nccl_hdr_path)
+ if os.path.exists(nccl_lib_path) and os.path.exists(nccl_hdr_path):
+ # Set NCCL_INSTALL_PATH
+ environ_cp['NCCL_INSTALL_PATH'] = os.path.dirname(nccl_lib_path)
+ write_action_env_to_bazelrc('NCCL_INSTALL_PATH',
+ os.path.dirname(nccl_lib_path))
+
+ # Set NCCL_HDR_PATH
+ environ_cp['NCCL_HDR_PATH'] = os.path.dirname(nccl_hdr_path)
+ write_action_env_to_bazelrc('NCCL_HDR_PATH',
+ os.path.dirname(nccl_hdr_path))
+ break
+
+ # Reset and Retry
+ print(
+ 'Invalid path to NCCL %s toolkit, %s or %s not found. Please use the '
'O/S agnostic package of NCCL 2' % (tf_nccl_version, nccl_lib_path,
nccl_hdr_path))
- environ_cp['TF_NCCL_VERSION'] = ''
+ environ_cp['TF_NCCL_VERSION'] = ''
else:
raise UserInputError('Invalid TF_NCCL setting was provided %d '
'times in a row. Assuming to be a scripting mistake.' %
@@ -1406,7 +1469,7 @@ def set_other_mpi_vars(environ_cp):
def set_system_libs_flag(environ_cp):
syslibs = environ_cp.get('TF_SYSTEM_LIBS', '')
- if syslibs and syslibs != '':
+ if syslibs:
if ',' in syslibs:
syslibs = ','.join(sorted(syslibs.split(',')))
else:
@@ -1465,26 +1528,31 @@ def config_info_line(name, help_text):
def main():
+ global _TF_WORKSPACE_ROOT
+ global _TF_BAZELRC
+
parser = argparse.ArgumentParser()
parser.add_argument(
'--workspace',
type=str,
- default=_TF_WORKSPACE_ROOT,
+ default=os.path.abspath(os.path.dirname(__file__)),
help='The absolute path to your active Bazel workspace.')
args = parser.parse_args()
+ _TF_WORKSPACE_ROOT = args.workspace
+ _TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME)
+
# Make a copy of os.environ to be clear when functions and getting and setting
# environment variables.
environ_cp = dict(os.environ)
check_bazel_version('0.15.0')
- reset_tf_configure_bazelrc(args.workspace)
+ reset_tf_configure_bazelrc()
cleanup_makefile()
setup_python(environ_cp)
if is_windows():
- environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_OPENCL_SYCL'] = '0'
environ_cp['TF_NEED_COMPUTECPP'] = '0'
environ_cp['TF_NEED_OPENCL'] = '0'
@@ -1498,8 +1566,8 @@ def main():
environ_cp['TF_SET_ANDROID_WORKSPACE'] = '0'
if is_macos():
- environ_cp['TF_NEED_JEMALLOC'] = '0'
environ_cp['TF_NEED_TENSORRT'] = '0'
+ environ_cp['TF_ENABLE_XLA'] = '0'
# The numpy package on ppc64le uses OpenBLAS which has multi-threading
# issues that lead to incorrect answers. Set OMP_NUM_THREADS=1 at
@@ -1508,11 +1576,10 @@ def main():
if is_ppc64le():
write_action_env_to_bazelrc('OMP_NUM_THREADS', 1)
- set_build_var(environ_cp, 'TF_NEED_JEMALLOC', 'jemalloc as malloc',
- 'with_jemalloc', True)
+ set_build_var(environ_cp, 'TF_NEED_IGNITE', 'Apache Ignite',
+ 'with_ignite_support', True, 'ignite')
set_build_var(environ_cp, 'TF_ENABLE_XLA', 'XLA JIT', 'with_xla_support',
- False, 'xla')
-
+ True, 'xla')
set_action_env_var(environ_cp, 'TF_NEED_OPENCL_SYCL', 'OpenCL SYCL', False)
if environ_cp.get('TF_NEED_OPENCL_SYCL') == '1':
@@ -1620,4 +1687,3 @@ def main():
if __name__ == '__main__':
main()
-
diff --git a/tensorflow/BUILD b/tensorflow/BUILD
index 5f73da68a2..9b62a50452 100644
--- a/tensorflow/BUILD
+++ b/tensorflow/BUILD
@@ -203,24 +203,15 @@ config_setting(
visibility = ["//visibility:public"],
)
-# TODO(jhseu): Enable on other platforms other than Linux.
config_setting(
- name = "with_jemalloc_linux_x86_64",
- define_values = {"with_jemalloc": "true"},
- values = {"cpu": "k8"},
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_jemalloc_linux_ppc64le",
- define_values = {"with_jemalloc": "true"},
- values = {"cpu": "ppc"},
+ name = "with_default_optimizations",
+ define_values = {"with_default_optimizations": "true"},
visibility = ["//visibility:public"],
)
config_setting(
- name = "with_default_optimizations",
- define_values = {"with_default_optimizations": "true"},
+ name = "with_ignite_support",
+ define_values = {"with_ignite_support": "true"},
visibility = ["//visibility:public"],
)
@@ -260,30 +251,6 @@ config_setting(
)
config_setting(
- name = "with_jemalloc_linux_x86_64_dynamic",
- define_values = {
- "with_jemalloc": "true",
- "framework_shared_object": "true",
- },
- values = {
- "cpu": "k8",
- },
- visibility = ["//visibility:public"],
-)
-
-config_setting(
- name = "with_jemalloc_linux_ppc64le_dynamic",
- define_values = {
- "with_jemalloc": "true",
- "framework_shared_object": "true",
- },
- values = {
- "cpu": "ppc",
- },
- visibility = ["//visibility:public"],
-)
-
-config_setting(
name = "using_cuda_clang",
define_values = {
"using_cuda_clang": "true",
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 0bf3d9542b..3554ec0bf3 100755
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -578,6 +578,14 @@ void TFE_OpSetAttrFunction(TFE_Op* op, const char* attr_name,
op->operation.MutableAttrs()->Set(attr_name, attr_value);
}
+void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
+ const char* data, size_t length) {
+ tensorflow::AttrValue attr_value;
+ tensorflow::NameAttrList* func = attr_value.mutable_func();
+ func->set_name(data, length);
+ op->operation.MutableAttrs()->Set(attr_name, attr_value);
+}
+
void TFE_OpSetAttrTensor(TFE_Op* op, const char* attr_name, TF_Tensor* tensor,
TF_Status* status) {
tensorflow::Tensor t;
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index 6323f8a053..b2454d8722 100755
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -313,6 +313,9 @@ TF_CAPI_EXPORT extern void TFE_OpSetAttrFunction(TFE_Op* op,
const char* attr_name,
const TFE_Op* value);
+TF_CAPI_EXPORT void TFE_OpSetAttrFunctionName(TFE_Op* op, const char* attr_name,
+ const char* data, size_t length);
+
TF_CAPI_EXPORT extern void TFE_OpSetAttrTensor(TFE_Op* op,
const char* attr_name,
TF_Tensor* tensor,
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index 41b5b8ff36..5ba55a203f 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -130,7 +130,7 @@ class GradientTape {
const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
+ const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter);
void DeleteTrace(int64 tensor_id);
@@ -206,10 +206,9 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
const string& op_type, std::vector<TapeTensor>& output_tensors,
gtl::ArraySlice<int64> input_tensor_id,
gtl::ArraySlice<tensorflow::DataType> input_dtypes,
- BackwardFunction* backward_function,
+ const std::function<BackwardFunction*()>& backward_function_getter,
const std::function<void(BackwardFunction*)>& backward_function_deleter) {
if (!ShouldRecord(input_tensor_id, input_dtypes)) {
- backward_function_deleter(backward_function);
return;
}
std::vector<int64> ids;
@@ -229,7 +228,7 @@ void GradientTape<Gradient, BackwardFunction, TapeTensor>::RecordOperation(
tensors.push_back(o);
}
op_tape_[op_id] = OpTapeEntry<BackwardFunction, TapeTensor>{
- op_type, std::move(tensors), ids, backward_function,
+ op_type, std::move(tensors), std::move(ids), backward_function_getter(),
backward_function_deleter};
}
diff --git a/tensorflow/cc/framework/cc_op_gen.cc b/tensorflow/cc/framework/cc_op_gen.cc
index a32d1b1eb5..39593370d1 100644
--- a/tensorflow/cc/framework/cc_op_gen.cc
+++ b/tensorflow/cc/framework/cc_op_gen.cc
@@ -853,11 +853,7 @@ void OpInfo::WriteClassDecl(WritableFile* h) const {
}
}
- strings::StrAppend(&class_decl, "\n");
-
- if (output_types.empty()) {
- strings::StrAppend(&class_decl, " Operation operation;\n");
- }
+ strings::StrAppend(&class_decl, "\n Operation operation;\n");
for (int i = 0; i < output_types.size(); ++i) {
strings::StrAppend(&class_decl, " ", output_types[i], " ", output_names[i],
";\n");
@@ -878,9 +874,11 @@ void OpInfo::GetOutput(string* out) const {
string return_on_error =
strings::StrCat("if (!", scope_str, ".ok()) return;");
+ strings::StrAppend(out, " this->operation = Operation(ret);\n");
+
// No outputs.
if (graph_op_def.output_arg_size() == 0) {
- strings::StrAppend(out, " this->operation = Operation(ret);\n return;\n");
+ strings::StrAppend(out, " return;\n");
return;
}
if (graph_op_def.output_arg_size() == 1) {
diff --git a/tensorflow/cc/framework/scope.cc b/tensorflow/cc/framework/scope.cc
index 7f6ac4cae7..6abc9e268e 100644
--- a/tensorflow/cc/framework/scope.cc
+++ b/tensorflow/cc/framework/scope.cc
@@ -62,7 +62,7 @@ Scope::Impl::Impl(const std::shared_ptr<Graph>& graph,
refiner_(refiner),
scope_used_(nullptr),
colocation_constraints_(),
- disable_shape_inference_(false) {}
+ disable_shape_inference_(refiner_ == nullptr) {}
Scope Scope::NewRootScope() {
Graph* graph = new Graph(OpRegistry::Global());
@@ -94,6 +94,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ScopeName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -110,6 +111,7 @@ Scope::Impl::Impl(const Scope& other, Tags::OpName, const string& name,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -132,6 +134,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ControlDeps,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -163,6 +166,7 @@ Scope::Impl::Impl(const Scope& other, Tags::SingleUseScope,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -178,6 +182,7 @@ Scope::Impl::Impl(const Scope& other, Tags::ExitOnError)
exit_on_error_(true),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -194,6 +199,7 @@ Scope::Impl::Impl(const Scope& other, Tags::KernelLabel,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(kernel_label),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(other.impl()->colocation_constraints_),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
@@ -210,12 +216,30 @@ Scope::Impl::Impl(const Scope& other, Tags::Colocate,
exit_on_error_(other.impl()->exit_on_error_),
kernel_label_(other.impl()->kernel_label_),
device_(other.impl()->device_),
+ assigned_device_(other.impl()->assigned_device_),
colocation_constraints_(
clear_colocations
? std::unordered_set<string>()
: other.impl()->GetColocationConstraints(colocate_with_op)),
disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+Scope::Impl::Impl(const Scope& other, Tags::AssignedDevice,
+ const string& assigned_device)
+ : graph_(other.impl()->graph_),
+ status_(other.impl()->status_),
+ name_map_(other.impl()->name_map_),
+ refiner_(other.impl()->refiner_),
+ scope_used_(other.impl()->scope_used_),
+ control_deps_(other.impl()->control_deps_),
+ name_(other.impl()->name_),
+ op_name_(other.impl()->op_name_),
+ exit_on_error_(other.impl()->exit_on_error_),
+ kernel_label_(other.impl()->kernel_label_),
+ device_(other.impl()->device_),
+ assigned_device_(assigned_device),
+ colocation_constraints_(other.impl()->colocation_constraints_),
+ disable_shape_inference_(other.impl()->disable_shape_inference_) {}
+
std::unordered_set<string> Scope::Impl::GetColocationConstraints(
const Operation& colocate_with_op) const {
std::unordered_set<string> current_constraints(colocation_constraints_);
@@ -299,6 +323,9 @@ void Scope::UpdateBuilder(NodeBuilder* builder) const {
if (!impl()->device_.empty()) {
builder->Device(impl()->device_);
}
+ if (!impl()->assigned_device_.empty()) {
+ builder->AssignedDevice(impl()->assigned_device_);
+ }
}
string Scope::Impl::GetUniqueName(const string& prefix,
@@ -394,6 +421,10 @@ Scope Scope::WithDevice(const string& device) const {
return Scope(new Impl(*this, Impl::Tags::Device(), device));
}
+Scope Scope::WithAssignedDevice(const string& assigned_device) const {
+ return Scope(new Impl(*this, Impl::Tags::AssignedDevice(), assigned_device));
+}
+
Scope Scope::ColocateWith(const Operation& op) const {
return Scope(new Impl(*this, Impl::Tags::Colocate(), op,
/* clear_colocations */ false));
diff --git a/tensorflow/cc/framework/scope.h b/tensorflow/cc/framework/scope.h
index 30c32bd44b..e307d8989b 100644
--- a/tensorflow/cc/framework/scope.h
+++ b/tensorflow/cc/framework/scope.h
@@ -133,6 +133,10 @@ class Scope {
/// the device field set to 'device'.
Scope WithDevice(const string& device) const;
+ /// Returns a new scope. All ops created within the returned scope will have
+ /// their assigned device set to `assigned_device`.
+ Scope WithAssignedDevice(const string& assigned_device) const;
+
/// Return a new scope. All ops created within the returned scope will be
/// co-located on the device where op is placed.
/// NOTE: This function is intended to be use internal libraries only for
diff --git a/tensorflow/cc/framework/scope_internal.h b/tensorflow/cc/framework/scope_internal.h
index 58adaef2e9..514e02e841 100644
--- a/tensorflow/cc/framework/scope_internal.h
+++ b/tensorflow/cc/framework/scope_internal.h
@@ -26,6 +26,8 @@ class ShapeRefiner;
// graph, status, name_map, and refiner.
// This is intended to enable the C API (which are used by other language
// bindings) to create a Scope and access C++ functionality (i.e. gradients).
+//
+// Shape inference is disabled if `refiner` is nullptr.
Scope NewInternalScope(Graph* graph, Status* status, ShapeRefiner* refiner);
class Scope::Impl {
@@ -58,6 +60,7 @@ class Scope::Impl {
enum class ExitOnError;
enum class KernelLabel;
enum class Colocate;
+ enum class AssignedDevice;
};
Impl(Graph* graph, Status* status, NameMap* name_map, ShapeRefiner* refiner,
@@ -74,6 +77,7 @@ class Scope::Impl {
Impl(const Scope& other, Tags::KernelLabel, const string& kernel_label);
Impl(const Scope& other, Tags::Colocate, const Operation& colocate_with_op,
bool clear_colocations);
+ Impl(const Scope& other, Tags::AssignedDevice, const string& assigned_device);
std::unordered_set<string> GetColocationConstraints(
const Operation& colocate_with_op) const;
@@ -107,6 +111,7 @@ class Scope::Impl {
const bool exit_on_error_ = false;
const string kernel_label_ = "";
const string device_ = "";
+ const string assigned_device_ = "";
const std::unordered_set<string> colocation_constraints_;
// If true, Scope::DoShapeInference() always returns Status:OK().
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 5bf4af1014..661b444a42 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -258,6 +258,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -323,6 +324,8 @@ cc_library(
"//tensorflow/core:graph",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -383,12 +386,16 @@ cc_library(
":shape_inference_helpers",
":union_find",
":xla_cluster_util",
+ "//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:ops",
+ "//tensorflow/cc:scope_internal",
"//tensorflow/compiler/jit/graphcycles",
"//tensorflow/compiler/jit/legacy_flags:mark_for_compilation_pass_flags",
"//tensorflow/compiler/jit/ops:xla_ops",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla:resource_operation_table",
"//tensorflow/compiler/tf2xla:xla_compiler",
+ "//tensorflow/compiler/tf2xla/cc:xla_jit_ops",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:core_cpu",
@@ -400,6 +407,8 @@ cc_library(
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:bounds_check",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -471,6 +480,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -509,6 +519,7 @@ tf_cc_test(
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/grappler/optimizers/data:graph_utils",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass.cc b/tensorflow/compiler/jit/build_xla_ops_pass.cc
index 9e3fd93cda..5974696b77 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass.cc
@@ -14,8 +14,12 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/build_xla_ops_pass.h"
+#include "absl/algorithm/container.h"
+#include "tensorflow/cc/framework/ops.h"
+#include "tensorflow/cc/framework/scope_internal.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
@@ -31,132 +35,108 @@ limitations under the License.
#include "tensorflow/core/public/version.h"
namespace tensorflow {
-
-static Status BuildXlaCompileNode(
- const string& nodename, const string& function_name,
- const AttrValueMap& function_attr, const string& device_name,
- const DataTypeVector& constant_dtypes, int num_resources,
- const DataTypeVector& arg_dtypes, Graph* graph, Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaCompile");
- def.set_device(device_name);
- AddNodeAttr("Tconstants", constant_dtypes, &def);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Nresources", num_resources, &def);
- NameAttrList function;
- function.set_name(function_name);
- *function.mutable_attr() = function_attr;
- AddNodeAttr("function", function, &def);
-
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+namespace {
+void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
+ std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
+ old_node->out_edges().end());
+ for (const Edge* edge : out_edges) {
+ // TODO(sanjoy): This does not update NodeDef inputs. To be able to update
+ // NodeDef inputs we first need to fix encapsulate_subgraphs_pass to fix up
+ // the NodeDef inputs to the function call nodes.
+ g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
+ g->RemoveEdge(edge);
+ }
}
-static Status BuildXlaRunNode(const string& nodename, const string& device_name,
- const DataTypeVector& arg_dtypes,
- const DataTypeVector& result_dtypes, Graph* graph,
- Node** node) {
- NodeDef def;
- def.set_name(graph->NewName(nodename));
- def.set_op("_XlaRun");
- def.set_device(device_name);
- AddNodeAttr("Targs", arg_dtypes, &def);
- AddNodeAttr("Tresults", result_dtypes, &def);
+struct XlaClusterInfo {
+ std::vector<Output> constant_inputs;
+ std::vector<Output> non_constant_inputs;
+ std::vector<Output> resource_inputs;
+ NameAttrList function;
+};
- Status status;
- *node = graph->AddNode(def, &status);
- return status;
+Output IncomingEdgeAsOutput(const Edge* e) {
+ return Output(e->src(), e->src_output());
}
-static Status GetXlaAttrs(Node* node, int* num_constant_args,
- int* num_resource_args, DataTypeVector* const_dtypes,
- DataTypeVector* arg_dtypes) {
+Status GetXlaClusterInfo(Node* n, XlaClusterInfo* result) {
+ int num_constant_inputs, num_resource_inputs;
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumConstantArgsAttr, num_constant_args));
+ GetNodeAttr(n->attrs(), kXlaNumConstantArgsAttr, &num_constant_inputs));
TF_RETURN_IF_ERROR(
- GetNodeAttr(node->attrs(), kXlaNumResourceArgsAttr, num_resource_args));
+ GetNodeAttr(n->attrs(), kXlaNumResourceArgsAttr, &num_resource_inputs));
- if (*num_constant_args < 0 || *num_resource_args < 0 ||
- *num_constant_args + *num_resource_args > node->num_inputs()) {
+ if (num_constant_inputs < 0 || num_resource_inputs < 0 ||
+ num_constant_inputs + num_resource_inputs > n->num_inputs()) {
return errors::InvalidArgument(
"Invalid number of constant/resource arguments to XLA kernel.");
}
- const int num_nonconst_args =
- node->num_inputs() - *num_constant_args - *num_resource_args;
-
- const DataTypeVector& input_types = node->input_types();
- std::copy(input_types.begin(), input_types.begin() + *num_constant_args,
- std::back_inserter(*const_dtypes));
- std::copy(input_types.begin() + *num_constant_args,
- input_types.begin() + *num_constant_args + num_nonconst_args,
- std::back_inserter(*arg_dtypes));
- return Status::OK();
-}
-
-static void CopyIncomingEdges(Graph* g, Node* old_node, Node* new_node,
- int prefix_to_ignore) {
- for (const Edge* edge : old_node->in_edges()) {
- if (edge->IsControlEdge()) {
- g->AddControlEdge(edge->src(), new_node);
- } else if (edge->dst_input() >= prefix_to_ignore) {
- g->AddEdge(edge->src(), edge->src_output(), new_node,
- edge->dst_input() - prefix_to_ignore);
- }
- }
-}
+ int num_non_constant_inputs =
+ n->num_inputs() - num_constant_inputs - num_resource_inputs;
-static void MoveOutgoingEdges(Graph* g, Node* old_node, Node* new_node) {
- std::vector<const Edge*> out_edges(old_node->out_edges().begin(),
- old_node->out_edges().end());
- for (const Edge* edge : out_edges) {
- // TODO(sanjoy): This does not update NodeDef inputs.
- g->AddEdge(new_node, edge->src_output(), edge->dst(), edge->dst_input());
- g->RemoveEdge(edge);
- }
-}
+ std::vector<const Edge*> input_edges_vector;
+ TF_RETURN_IF_ERROR(n->input_edges(&input_edges_vector));
+ absl::Span<const Edge*> input_edges(input_edges_vector);
-static Status ReplaceNodeWithXlaCompileAndRun(Graph* g, Node* n) {
- int num_constant_args, num_resource_args;
- DataTypeVector const_dtypes;
- DataTypeVector arg_dtypes;
+ absl::c_transform(input_edges.subspan(0, num_constant_inputs),
+ std::back_inserter(result->constant_inputs),
+ IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(GetXlaAttrs(n, &num_constant_args, &num_resource_args,
- &const_dtypes, &arg_dtypes));
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs, num_non_constant_inputs),
+ std::back_inserter(result->non_constant_inputs), IncomingEdgeAsOutput);
- Node *compile_node, *run_node;
+ absl::c_transform(
+ input_edges.subspan(num_constant_inputs + num_non_constant_inputs,
+ num_resource_inputs),
+ std::back_inserter(result->resource_inputs), IncomingEdgeAsOutput);
- TF_RETURN_IF_ERROR(BuildXlaCompileNode(
- n->name(), n->type_string(), n->def().attr(), n->requested_device(),
- const_dtypes, num_resource_args, arg_dtypes, g, &compile_node));
+ result->function.set_name(n->type_string());
+ *result->function.mutable_attr() = n->def().attr();
+ return Status::OK();
+}
- DataTypeVector arg_dtypes_with_resources = arg_dtypes;
- for (int i = 0; i < num_resource_args; i++) {
- arg_dtypes_with_resources.push_back(DT_RESOURCE);
+Status CopyIncomingControlEdges(Graph* g, Node* from, Node* to) {
+ for (const Edge* e : from->in_edges()) {
+ if (e->IsControlEdge()) {
+ g->AddControlEdge(e->src(), to);
+ }
}
- TF_RETURN_IF_ERROR(BuildXlaRunNode(n->name(), n->requested_device(),
- arg_dtypes_with_resources,
- n->output_types(), g, &run_node));
-
- compile_node->set_assigned_device_name(n->assigned_device_name());
- run_node->set_assigned_device_name(n->assigned_device_name());
+ return Status::OK();
+}
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/compile_node,
- /*prefix_to_ignore=*/0);
- CopyIncomingEdges(g, /*old_node=*/n, /*new_node=*/run_node,
- /*prefix_to_ignore=*/num_constant_args);
+Status ReplaceNodeWithXlaCompileAndXlaRun(Graph* g, Node* n) {
+ Status status;
+ Scope root = NewInternalScope(g, &status, /*refiner=*/nullptr)
+ .NewSubScope(n->name())
+ .WithDevice(n->requested_device())
+ .WithAssignedDevice(n->assigned_device_name());
+
+ XlaClusterInfo cluster_info;
+ TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
+
+ ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
+ /*constants=*/cluster_info.constant_inputs,
+ /*args=*/cluster_info.non_constant_inputs,
+ /*resources=*/cluster_info.resource_inputs,
+ cluster_info.function);
+ TF_RETURN_IF_ERROR(
+ CopyIncomingControlEdges(g, /*from=*/n, /*to=*/xla_compile.key.node()));
- // The compilation_key output.
- g->AddEdge(compile_node, 0, run_node, n->num_inputs() - num_constant_args);
+ std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
+ absl::c_copy(cluster_info.resource_inputs, std::back_inserter(xla_run_args));
+ ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
+ xla_compile.key, n->output_types());
- MoveOutgoingEdges(g, /*old_node=*/n, /*new_node=*/run_node);
+ MoveOutgoingEdges(g, /*old_node=*/n,
+ /*new_node=*/xla_run.operation.node());
g->RemoveNode(n);
return Status::OK();
}
+} // namespace
Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
@@ -170,7 +150,7 @@ Status BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
// Only compile nodes that are marked for compilation by the
// compilation-marking pass (via 'attr_name').
if (IsXlaCompiledKernel(*n)) {
- TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndRun(graph, n));
+ TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(graph, n));
}
}
diff --git a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
index b7cb4506b9..9d56db7b6b 100644
--- a/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
+++ b/tensorflow/compiler/jit/build_xla_ops_pass_test.cc
@@ -56,18 +56,26 @@ Status BuildXlaOps(const Scope& s, std::unique_ptr<Graph>* result) {
}
Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
- const string& node_name, Node** result) {
+ const string& node_name, int num_constant_args,
+ int num_resource_args, Node** result) {
NodeDef call_node;
call_node.set_name(node_name);
call_node.set_op(callee_name);
AddNodeAttr(kXlaCompiledKernelAttr, true, &call_node);
- AddNodeAttr(kXlaNumConstantArgsAttr, 0, &call_node);
- AddNodeAttr(kXlaNumResourceArgsAttr, 0, &call_node);
+ AddNodeAttr(kXlaNumConstantArgsAttr, num_constant_args, &call_node);
+ AddNodeAttr(kXlaNumResourceArgsAttr, num_resource_args, &call_node);
Status s;
*result = graph->AddNode(call_node, &s);
return s;
}
+Status MakeXlaCompiledKernel(Graph* graph, const string& callee_name,
+ const string& node_name, Node** result) {
+ return MakeXlaCompiledKernel(graph, callee_name, node_name,
+ /*num_constant_args=*/0, /*num_resource_args=*/0,
+ result);
+}
+
Node* MakeWrite(const Scope& scope, const string& id) {
Output var_handle =
ops::VarHandleOp(scope.WithOpName("Var" + id), DT_FLOAT, TensorShape({}));
@@ -108,5 +116,23 @@ TEST(BuildXlaOps, ControlDepsPreserved) {
EXPECT_THAT(write_op_new, NodeWith(CtrlDeps(NodeWith(Op("_XlaRun")))));
}
+TEST(BuildXlaOps, CleanFailureOnBogusAttr) {
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ FunctionDefLibrary flib_def =
+ CreateFunctionDefLibWithConstFunction("cluster_0");
+ TF_ASSERT_OK(root.graph()->AddFunctionLibrary(flib_def));
+ Node* call;
+ TF_ASSERT_OK(
+ MakeXlaCompiledKernel(root.graph(), "cluster_0", "C", 100, 100, &call));
+ Node* write_op = MakeWrite(root, "write");
+ root.graph()->AddControlEdge(call, write_op);
+
+ std::unique_ptr<Graph> graph;
+ Status failure_status = BuildXlaOps(root, &graph);
+ ASSERT_FALSE(failure_status.ok());
+ EXPECT_EQ(failure_status.code(), error::INVALID_ARGUMENT);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/jit/deadness_analysis.cc b/tensorflow/compiler/jit/deadness_analysis.cc
index 9128b48da3..e0b9932d80 100644
--- a/tensorflow/compiler/jit/deadness_analysis.cc
+++ b/tensorflow/compiler/jit/deadness_analysis.cc
@@ -14,11 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/deadness_analysis.h"
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/jit/deadness_analysis_internal.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
// ALGORITHM OVERVIEW
@@ -296,7 +298,7 @@ class SymbolPredicate : public Predicate {
template <typename FunctionTy>
/*static*/ void Predicate::Visit(Predicate* p, const FunctionTy& func) {
- gtl::FlatSet<Predicate*> visited;
+ absl::flat_hash_set<Predicate*> visited;
std::vector<Predicate*> stack;
stack.push_back(p);
@@ -383,6 +385,8 @@ class PredicateFactory {
}
Predicate* MakeAndOrImpl(absl::Span<Predicate* const> operands, bool is_and);
+ Predicate* MakeInternedAndOr(std::vector<Predicate*> simplified_ops,
+ Predicate::Kind pred_kind);
// Predicate instances are interned, meaning that there is only a single
// instance of a Predicate object with a given content. This makes checking
@@ -417,24 +421,53 @@ class PredicateFactory {
}
};
- gtl::FlatMap<SignatureForAndOr, std::unique_ptr<Predicate>,
- HashSignatureForAndOr>
+ absl::flat_hash_map<SignatureForAndOr, std::unique_ptr<Predicate>,
+ HashSignatureForAndOr>
interned_and_or_instances_;
- gtl::FlatMap<SignatureForNot, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForNot, std::unique_ptr<Predicate>>
interned_not_instances_;
- gtl::FlatMap<SignatureForAndRec, std::unique_ptr<Predicate>>
+ absl::flat_hash_map<SignatureForAndRec, std::unique_ptr<Predicate>>
interned_and_rec_instances_;
- gtl::FlatMap<SignatureForSymbol, std::unique_ptr<Predicate>,
- HashSignatureForSymbol>
+ absl::flat_hash_map<SignatureForSymbol, std::unique_ptr<Predicate>,
+ HashSignatureForSymbol>
interned_symbol_instances_;
};
+Predicate* PredicateFactory::MakeInternedAndOr(
+ std::vector<Predicate*> simplified_ops, Predicate::Kind pred_kind) {
+ std::stable_sort(
+ simplified_ops.begin(), simplified_ops.end(),
+ [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+
+ auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
+ if (it != interned_and_or_instances_.end()) {
+ return it->second.get();
+ }
+
+ simplified_ops.shrink_to_fit();
+ // NB! Because we'll use a non-owning reference to simplified_ops in the
+ // key for interned_and_or_instances_ we need to be careful to std::move()
+ // it all the way through.
+ absl::Span<Predicate* const> operands_slice = simplified_ops;
+ std::unique_ptr<Predicate> new_pred =
+ pred_kind == Predicate::Kind::kAnd
+ ? Make<AndPredicate>(std::move(simplified_ops))
+ : Make<OrPredicate>(std::move(simplified_ops));
+
+ Predicate* new_pred_ptr = new_pred.get();
+ interned_and_or_instances_.emplace(
+ SignatureForAndOr(pred_kind, operands_slice), std::move(new_pred));
+ return new_pred_ptr;
+}
+
// Common code to create AndPredicate or OrPredicate instances.
Predicate* PredicateFactory::MakeAndOrImpl(
absl::Span<Predicate* const> operands, bool is_and) {
Predicate::Kind pred_kind =
is_and ? Predicate::Kind::kAnd : Predicate::Kind::kOr;
- gtl::FlatSet<Predicate*> simplified_ops_set;
+ Predicate::Kind other_pred_kind =
+ is_and ? Predicate::Kind::kOr : Predicate::Kind::kAnd;
+ absl::flat_hash_set<Predicate*> simplified_ops_set;
std::vector<Predicate*> simplified_ops;
for (Predicate* op : operands) {
// Simplify A&A => A and A|A => A.
@@ -459,7 +492,7 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
// Simplify "A&~A=>False" and "A|~A=>True".
- gtl::FlatSet<Predicate*> negated_ops;
+ absl::flat_hash_set<Predicate*> negated_ops;
for (Predicate* op : simplified_ops) {
if (op->kind() == Predicate::Kind::kNot) {
negated_ops.insert(dynamic_cast<NotPredicate&>(*op).operand());
@@ -472,30 +505,63 @@ Predicate* PredicateFactory::MakeAndOrImpl(
}
}
- std::stable_sort(
- simplified_ops.begin(), simplified_ops.end(),
- [](Predicate* a, Predicate* b) { return a->hash() < b->hash(); });
+ // If all ops contain the same subop, then factor it out thanks to the
+ // distributive property. Such as:
+ // - (A & B) | (A & C) | (A & D) => A & (B | C | D)
+ // - (A | B) & (A | C) & (A | D) => A | (B & C & D)
+ //
+ // First find any predicates contained in all subops.
+ std::vector<Predicate*> common_inner_operands;
+ absl::flat_hash_set<Predicate*> common_inner_operands_set;
+ for (Predicate* op : simplified_ops) {
+ if (op->kind() != other_pred_kind) {
+ common_inner_operands.clear();
+ break;
+ }
- auto it = interned_and_or_instances_.find({pred_kind, simplified_ops});
- if (it == interned_and_or_instances_.end()) {
- simplified_ops.shrink_to_fit();
- // NB! Because we'll use a non-owning reference to simplified_ops in the
- // key for interned_and_or_instances_ we need to be careful to std::move()
- // it all the way through.
- absl::Span<Predicate* const> operands_slice = simplified_ops;
- std::unique_ptr<Predicate> new_pred =
- is_and ? Make<AndPredicate>(std::move(simplified_ops))
- : Make<OrPredicate>(std::move(simplified_ops));
+ if (common_inner_operands.empty()) {
+ common_inner_operands.insert(common_inner_operands.end(),
+ op->GetOperands().begin(),
+ op->GetOperands().end());
+ } else {
+ std::vector<Predicate*> sub_ops_intersection;
+ common_inner_operands.clear();
+ absl::c_copy_if(op->GetOperands(),
+ std::back_inserter(common_inner_operands),
+ [&](Predicate* sub_op) {
+ return common_inner_operands_set.count(sub_op) == 1;
+ });
+ }
+ if (common_inner_operands.empty()) break;
+ common_inner_operands_set.clear();
+ common_inner_operands_set.insert(common_inner_operands.begin(),
+ common_inner_operands.end());
+ }
- Predicate* new_pred_ptr = new_pred.get();
- CHECK(interned_and_or_instances_
- .emplace(SignatureForAndOr(pred_kind, operands_slice),
- std::move(new_pred))
- .second);
- return new_pred_ptr;
- } else {
- return it->second.get();
+ if (common_inner_operands.empty()) {
+ return MakeInternedAndOr(std::move(simplified_ops), pred_kind);
}
+
+ // For all predicates that can be factored out, remove them and recreate the
+ // subops.
+ std::vector<Predicate*> factored_ops;
+ for (Predicate* op : simplified_ops) {
+ std::vector<Predicate*> new_sub_op_ops;
+ absl::c_copy_if(op->GetOperands(), std::back_inserter(new_sub_op_ops),
+ [&](Predicate* sub_op) {
+ return std::find(common_inner_operands.begin(),
+ common_inner_operands.end(),
+ sub_op) == common_inner_operands.end();
+ });
+ factored_ops.push_back(MakeAndOrImpl(new_sub_op_ops, !is_and));
+ }
+
+ Predicate* new_inner_op = MakeAndOrImpl(factored_ops, is_and);
+ std::vector<Predicate*> outer_ops;
+ outer_ops.push_back(new_inner_op);
+ outer_ops.insert(outer_ops.end(), common_inner_operands.begin(),
+ common_inner_operands.end());
+ return MakeAndOrImpl(outer_ops, !is_and);
}
class DeadnessAnalysisImpl : public DeadnessAnalysis {
@@ -507,7 +573,8 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status PopulateWithReversePostOrder(absl::Span<Node* const> rpo);
bool HasInputsWithMismatchingDeadness(const Node& node) override;
void Print() const override;
- gtl::FlatMap<TensorId, string, TensorId::Hasher> PredicateMapAsString() const;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> PredicateMapAsString()
+ const;
private:
enum class EdgeKind { kDataAndControl, kDataOnly, kControlOnly };
@@ -549,7 +616,7 @@ class DeadnessAnalysisImpl : public DeadnessAnalysis {
Status HandleNode(Node* n, std::vector<bool>* should_revisit);
const Graph& graph_;
- gtl::FlatMap<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
+ absl::flat_hash_map<TensorId, Predicate*, TensorId::Hasher> predicate_map_;
PredicateFactory predicate_factory_;
bool vlog_;
};
@@ -912,9 +979,9 @@ DeadnessAnalysis::~DeadnessAnalysis() {}
return Status::OK();
}
-gtl::FlatMap<TensorId, string, TensorId::Hasher>
+absl::flat_hash_map<TensorId, string, TensorId::Hasher>
DeadnessAnalysisImpl::PredicateMapAsString() const {
- gtl::FlatMap<TensorId, string, TensorId::Hasher> result;
+ absl::flat_hash_map<TensorId, string, TensorId::Hasher> result;
std::vector<TensorId> tensor_ids;
for (const auto& kv_pair : predicate_map_) {
CHECK(result.insert({kv_pair.first, kv_pair.second->ToString()}).second);
diff --git a/tensorflow/compiler/jit/deadness_analysis_internal.h b/tensorflow/compiler/jit/deadness_analysis_internal.h
index 3df2679c62..354782374a 100644
--- a/tensorflow/compiler/jit/deadness_analysis_internal.h
+++ b/tensorflow/compiler/jit/deadness_analysis_internal.h
@@ -16,15 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
#define TENSORFLOW_COMPILER_JIT_DEADNESS_ANALYSIS_INTERNAL_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace tensorflow {
namespace deadness_analysis_internal {
// Returns a map describing the predicate each Tensor was mapped to. For
// testing purposes only.
-using PredicateMapTy = gtl::FlatMap<TensorId, string, TensorId::Hasher>;
+using PredicateMapTy = absl::flat_hash_map<TensorId, string, TensorId::Hasher>;
Status ComputePredicates(const Graph& graph, PredicateMapTy* out_predicate_map);
// Returns a map describing the predicate each Tensor was mapped to. For
diff --git a/tensorflow/compiler/jit/deadness_analysis_test.cc b/tensorflow/compiler/jit/deadness_analysis_test.cc
index 28a56044d5..617e31488c 100644
--- a/tensorflow/compiler/jit/deadness_analysis_test.cc
+++ b/tensorflow/compiler/jit/deadness_analysis_test.cc
@@ -384,10 +384,31 @@ TEST(DeadnessAnalysisTest, OrOfAnd) {
EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add2.node()));
}
-TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
- // This demonstrates one of the weaknesses in the current approach -- since we
- // only do some basic simplifications we can't see that "(A|B)&C" ==
- // "(A&C)|(B&C)".
+TEST(DeadnessAnalysisTest, AndOrDistributiveSimplified) {
+ // (*A | (~*A & ((~*B & ~*A) | (~*A & *B)))) == #true
+ Scope root = Scope::NewRootScope().ExitOnError();
+
+ ops::Switch sw_0 = CreateSwitch(root, "A");
+ ops::Switch sw_1 = CreateSwitch(root, "B");
+ Output add0 =
+ ops::Add(root.WithOpName("and0"), sw_0.output_false, sw_1.output_true);
+ Output add1 =
+ ops::Add(root.WithOpName("and1"), sw_0.output_false, sw_1.output_false);
+ ops::Merge or2(root.WithOpName("or2"), {add0, add1});
+ Output add3 =
+ ops::Add(root.WithOpName("and3"), or2.output, sw_0.output_false);
+ ops::Merge or4(root.WithOpName("or4"), {add3, sw_0.output_true});
+
+ std::unique_ptr<DeadnessAnalysis> result;
+ TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
+
+ PredicateMapTy predicate_map;
+ TF_ASSERT_OK(ComputePredicates(*root.graph(), &predicate_map));
+ EXPECT_EQ(predicate_map[ControlOutputFor(or4.output)], "#true");
+}
+
+TEST(DeadnessAnalysisTest, AndOrDistributive) {
+ // (A|B)&C == (A&C)|(B&C)
Scope root = Scope::NewRootScope().ExitOnError();
ops::Switch sw_0 = CreateSwitch(root, "0");
@@ -408,7 +429,7 @@ TEST(DeadnessAnalysisTest, NEGATIVE_AndOrDistributive) {
std::unique_ptr<DeadnessAnalysis> result;
TF_ASSERT_OK(AnalyzeDeadness(root.graph(), &result));
- EXPECT_TRUE(result->HasInputsWithMismatchingDeadness(*add2.node()));
+ EXPECT_FALSE(result->HasInputsWithMismatchingDeadness(*add3.node()));
}
TEST(DeadnessAnalysisTest, Ternary) {
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index e0632ff7e4..da27f837e8 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -44,7 +45,6 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/public/session_options.h"
@@ -78,7 +78,8 @@ void SortControlInputs(GraphDef* gdef) {
namespace {
bool AreAllParentsGuaranteedConst(
- const Node& n, const gtl::FlatSet<const Node*>& runtime_const_nodes) {
+ const Node& n,
+ const absl::flat_hash_set<const Node*>& runtime_const_nodes) {
if (n.type_string() == "GuaranteeConst") {
// If the current node is itself a cast-to-const, no need
// to look at the incoming edges.
@@ -101,7 +102,7 @@ bool AreAllParentsGuaranteedConst(
void MarkGuaranteedConstants(
const Graph& graph,
const std::vector<std::pair<const Node*, Node*>>& src_arg_pairs) {
- gtl::FlatSet<const Node*> guaranteed_const_nodes;
+ absl::flat_hash_set<const Node*> guaranteed_const_nodes;
std::vector<const Node*> srcs;
srcs.reserve(src_arg_pairs.size());
for (const auto& src_arg : src_arg_pairs) {
@@ -748,6 +749,12 @@ Node* Encapsulator::Subgraph::MakeNodeImage(const Graph* graph_in, Node* node) {
graph_->set_versions(graph_in->versions());
}
+ // TODO(b/116981129): Enhance how the device for the encapsulated subgraph is
+ // determined. In case of hard placement, ensure all the encapsulated nodes
+ // have the same requested device, which in turn will be the requested device
+ // for the entire encapsulated subgraph. In case of soft placement, use a
+ // deterministic approach to fill in the requested device. Handle co-location
+ // constraints similarly if they exist.
if (device_.empty()) {
device_ = node->assigned_device_name().empty()
? node->requested_device()
@@ -1357,28 +1364,31 @@ void Encapsulator::Subgraph::GetOutsideCompilationSubgraphNames(
Status Encapsulator::GetFunctionNameAttr(
Node const* node, string* attr, string* outside_compilation_attr) const {
- Status s = GetNodeAttr(node->attrs(), group_attribute_, attr);
- if (s.code() == error::Code::NOT_FOUND) {
- // Return empty attr if there's no group_attribute.
- attr->clear();
- } else {
- TF_RETURN_IF_ERROR(s);
- }
- bool has_group_attr = s.ok();
- s = GetNodeAttr(node->attrs(), outside_compilation_attribute_,
- outside_compilation_attr);
- if (s.code() == error::Code::NOT_FOUND) {
- // Return empty attr if there's no outside_compilation attribute.
- outside_compilation_attr->clear();
- } else {
- TF_RETURN_IF_ERROR(s);
- if (!has_group_attr) {
- return errors::InvalidArgument(
- "Node ", node->name(), " has ", outside_compilation_attribute_,
- " attribute but no ", group_attribute_, " attribute.");
+ AttrSlice attrs = node->attrs();
+ attr->clear();
+ outside_compilation_attr->clear();
+ bool found_group_attribute = false;
+ bool found_outside_compilation_attribute = false;
+ for (const auto& node_attr : attrs) {
+ if (node_attr.first == group_attribute_) {
+ TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
+ *attr = node_attr.second.s();
+ found_group_attribute = true;
+ } else if (node_attr.first == outside_compilation_attribute_) {
+ TF_RETURN_IF_ERROR(AttrValueHasType(node_attr.second, "string"));
+ *outside_compilation_attr = node_attr.second.s();
+ found_outside_compilation_attribute = true;
}
+ if (found_group_attribute && found_outside_compilation_attribute) break;
+ }
+
+ if (found_outside_compilation_attribute && !found_group_attribute) {
+ return errors::InvalidArgument(
+ "Node ", node->name(), " has ", outside_compilation_attribute_,
+ " attribute but no ", group_attribute_, " attribute.");
+ } else {
+ return Status::OK();
}
- return Status::OK();
}
bool IsInSubgraph(const string& func_id, const string& outside_compilation_id) {
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
index 97ef8cd3cb..2ce6fa73fc 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include "tensorflow/compiler/jit/encapsulate_xla_computations_pass.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -62,7 +62,7 @@ DataType EdgeType(const Edge* edge) {
}
// Adds the control inputs of `node` to `*deps`.
-void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlInputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.in_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->src());
@@ -71,7 +71,7 @@ void AddControlInputs(const Node& node, gtl::FlatSet<Node*>* deps) {
}
// Adds the control outputs of `node` to `*deps`.
-void AddControlOutputs(const Node& node, gtl::FlatSet<Node*>* deps) {
+void AddControlOutputs(const Node& node, absl::flat_hash_set<Node*>* deps) {
for (const Edge* edge : node.out_edges()) {
if (edge->IsControlEdge()) {
deps->insert(edge->dst());
@@ -246,7 +246,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Data and control inputs to the new XlaLaunch node.
std::vector<std::pair<Node*, int>> data_inputs(num_inputs);
- gtl::FlatSet<Node*> control_inputs;
+ absl::flat_hash_set<Node*> control_inputs;
DataTypeVector arg_types(num_args);
AddControlInputs(*launch, &control_inputs);
@@ -266,7 +266,7 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Outputs.
const int num_outputs = launch->output_types().size();
- gtl::FlatSet<Node*> control_outputs;
+ absl::flat_hash_set<Node*> control_outputs;
std::vector<std::vector<std::pair<Node*, int>>> data_outputs(num_outputs);
DataTypeVector output_types(num_outputs);
@@ -297,7 +297,9 @@ Status RewriteSubgraph(const std::vector<OutputTensor>& arg_source_tensors,
// Target the XLA CPU/GPU backends.
VLOG(2) << "Replacing with XlaLaunch";
+ VLOG(2) << "Device is " << launch->requested_device();
def.set_op("XlaLaunch");
+ def.set_device(launch->requested_device());
AddNodeAttr("Tconstants", DataTypeVector{}, &def);
AddNodeAttr("Targs", arg_types, &def);
AddNodeAttr("Nresources", num_variables, &def);
diff --git a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
index f643fb0cfe..22531a4ace 100644
--- a/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_xla_computations_pass_test.cc
@@ -19,7 +19,7 @@ limitations under the License.
#include "tensorflow/cc/ops/resource_variable_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/compiler/jit/encapsulate_subgraphs_pass.h"
-#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_op.h"
+#include "tensorflow/compiler/tf2xla/cc/ops/xla_jit_ops.h"
#include "tensorflow/compiler/tf2xla/test_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
@@ -55,6 +55,7 @@ static std::unique_ptr<Graph> MakeOuterGraph(
.Input(u.node()->name(), 0, DT_RESOURCE)
.Input(v.node()->name(), 0, DT_RESOURCE)
.Input(w.node()->name(), 0, DT_RESOURCE)
+ .Device("/gpu:0")
.Attr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0")
.Attr("_variable_start_index", 4)
.Finalize(&def));
@@ -107,10 +108,11 @@ static std::unique_ptr<Graph> MakeBodyGraph() {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), arg1);
-
+ add_attrs(b_identity.node());
auto read_u = ops::ReadVariableOp(scope.WithOpName("ReadU"), arg4, DT_FLOAT);
add_attrs(read_u.node());
auto read_v = ops::ReadVariableOp(scope.WithOpName("ReadV"), arg5, DT_FLOAT);
@@ -215,6 +217,7 @@ TEST(EncapsulateXlaComputations, Encapsulate) {
auto add_attrs = [](Node* node) {
node->AddAttr(EncapsulateXlaComputationsPass::kXlaClusterAttr, "launch0");
+ node->set_requested_device("/gpu:0");
};
auto b_identity = ops::Identity(scope.WithOpName("B_identity"), b);
@@ -317,8 +320,8 @@ TEST(EncapsulateXlaComputations, BuildXlaLaunchOp) {
NameAttrList function;
function.set_name("launch0");
auto launch = ops::XlaLaunch(
- scope.WithOpName("launch0"), std::initializer_list<Input>{},
- std::initializer_list<Input>{a, b, c, d},
+ scope.WithOpName("launch0").WithDevice("/gpu:0"),
+ std::initializer_list<Input>{}, std::initializer_list<Input>{a, b, c, d},
std::initializer_list<Input>{u, v, w},
DataTypeVector{DT_FLOAT, DT_INT32, DT_FLOAT, DT_FLOAT}, function);
diff --git a/tensorflow/compiler/jit/kernels/BUILD b/tensorflow/compiler/jit/kernels/BUILD
index 0839f1cb3d..26cb3af9d6 100644
--- a/tensorflow/compiler/jit/kernels/BUILD
+++ b/tensorflow/compiler/jit/kernels/BUILD
@@ -26,6 +26,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
alwayslink = 1,
diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc
index a85006eb03..cfd27a6510 100644
--- a/tensorflow/compiler/jit/kernels/xla_ops.cc
+++ b/tensorflow/compiler/jit/kernels/xla_ops.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
@@ -163,7 +164,7 @@ class XlaExecutableClosureStore {
private:
mutex mutex_;
int64 key_counter_ GUARDED_BY(mutex_);
- gtl::FlatMap<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
+ absl::flat_hash_map<KeyT, XlaExecutableClosure> closures_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaExecutableClosureStore);
};
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 133d982360..4f0c370e65 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/jit/deadness_analysis.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
@@ -42,7 +43,6 @@ limitations under the License.
#include "tensorflow/core/graph/control_flow.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/public/version.h"
@@ -371,7 +371,7 @@ bool IsXlaFusable(const NodeDef& node) {
Status FindCompilationCandidates(
const Graph& graph, FunctionLibraryDefinition* flib_def, Env* env,
const std::function<bool(const Node*, const DeviceType&)>& is_compilable_fn,
- OrderedNodeSet* candidates, gtl::FlatSet<Node*>* isolated_nodes) {
+ OrderedNodeSet* candidates, absl::flat_hash_set<Node*>* isolated_nodes) {
OptimizerOptions opts;
std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
new ProcessFunctionLibraryRuntime(nullptr, env, TF_GRAPH_DEF_VERSION,
@@ -849,7 +849,7 @@ Status MarkForCompilationPass::RunImpl(
Graph* graph = options.graph->get();
OrderedNodeSet compilation_candidates;
- gtl::FlatSet<Node*> isolated_nodes;
+ absl::flat_hash_set<Node*> isolated_nodes;
TF_RETURN_IF_ERROR(FindCompilationCandidates(
*graph, options.flib_def,
(options.session_options != nullptr) ? options.session_options->env
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index 4f9145b479..2a80c745e3 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/mark_for_compilation_pass_test_helper.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "tensorflow/cc/framework/ops.h"
@@ -61,10 +62,10 @@ std::unordered_map<string, string> GetClusters(const Graph& graph) {
return ids;
}
-gtl::FlatMap<string, std::vector<string>> GetClusterSets(
+absl::flat_hash_map<string, std::vector<string>> GetClusterSets(
const Graph& g, std::vector<string>* cluster_names = nullptr) {
CHECK(cluster_names == nullptr || cluster_names->empty());
- gtl::FlatMap<string, std::vector<string>> cluster_sets;
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets;
for (const auto& p : GetClusters(g)) {
cluster_sets[p.second].push_back(p.first);
}
@@ -566,7 +567,7 @@ TEST(XlaCompilationTest, ResourcesClusteringAllowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW", "ReadR",
@@ -586,7 +587,7 @@ TEST(XlaCompilationTest, ResourcesClusteringDisallowed) {
std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
TF_EXPECT_OK(root.ToGraph(graph.get()));
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph);
ASSERT_EQ(cluster_sets.size(), 1);
std::vector<string> expected_clustered_nodes = {"AssignmentW",
@@ -616,7 +617,7 @@ TEST(XlaCompilationTest, ChainOfOps) {
TF_ASSERT_OK(MarkForCompilationPassTestHelper::MarkForCompilation(&graph));
std::vector<string> cluster_names;
- gtl::FlatMap<string, std::vector<string>> cluster_sets =
+ absl::flat_hash_map<string, std::vector<string>> cluster_sets =
GetClusterSets(*graph, &cluster_names);
ASSERT_EQ(cluster_sets.size(), 2);
diff --git a/tensorflow/compiler/jit/partially_decluster_pass.cc b/tensorflow/compiler/jit/partially_decluster_pass.cc
index 10fc9e85d9..b1f9e9088f 100644
--- a/tensorflow/compiler/jit/partially_decluster_pass.cc
+++ b/tensorflow/compiler/jit/partially_decluster_pass.cc
@@ -15,17 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/jit/partially_decluster_pass.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace tensorflow {
namespace {
-Status FindNodesToDecluster(const Graph& graph, gtl::FlatSet<Node*>* result,
+Status FindNodesToDecluster(const Graph& graph,
+ absl::flat_hash_set<Node*>* result,
absl::Span<Node* const> post_order) {
// Find nodes that have at least one user outside their cluster that expects
// hostmem output. These nodes should be cloned to outside the cluster to
@@ -171,7 +172,7 @@ Status PartiallyDeclusterToRemoveDeviceToHostCopies(Graph* graph) {
GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
/*edge_filter=*/NotBackedge);
- gtl::FlatSet<Node*> nodes_to_partially_decluster;
+ absl::flat_hash_set<Node*> nodes_to_partially_decluster;
TF_RETURN_IF_ERROR(
FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
diff --git a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
index 56e35c0059..e039d46ec8 100644
--- a/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
+++ b/tensorflow/compiler/jit/resource_operation_safety_analysis.cc
@@ -82,6 +82,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/resource_operation_safety_analysis.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
@@ -89,8 +90,6 @@ limitations under the License.
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/tensor_id.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/util/ptr_util.h"
@@ -177,7 +176,7 @@ string ResourceOpToString(const ResourceOp& resource_op) {
// point.
class ResourceOpSet {
private:
- using Impl = gtl::FlatSet<ResourceOp>;
+ using Impl = absl::flat_hash_set<ResourceOp>;
public:
ResourceOpSet() = default;
diff --git a/tensorflow/compiler/jit/xla_compilation_cache.h b/tensorflow/compiler/jit/xla_compilation_cache.h
index 10ad87e38c..17c0321c1e 100644
--- a/tensorflow/compiler/jit/xla_compilation_cache.h
+++ b/tensorflow/compiler/jit/xla_compilation_cache.h
@@ -16,6 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
#define TENSORFLOW_COMPILER_JIT_XLA_COMPILATION_CACHE_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -24,7 +25,6 @@ limitations under the License.
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
@@ -152,7 +152,7 @@ class XlaCompilationCache : public ResourceBase {
};
mutex compile_cache_mu_;
- gtl::FlatMap<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
+ absl::flat_hash_map<Signature, std::unique_ptr<Entry>, Signature::Hash> cache_
GUARDED_BY(compile_cache_mu_);
struct CompileStats {
@@ -165,7 +165,7 @@ class XlaCompilationCache : public ResourceBase {
mutex compile_stats_mu_;
// Maps cluster names to compilation statistics for said cluster.
- gtl::FlatMap<string, CompileStats> compile_stats_
+ absl::flat_hash_map<string, CompileStats> compile_stats_
GUARDED_BY(compile_stats_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(XlaCompilationCache);
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index af83c792e5..e083652978 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -75,8 +75,9 @@ XlaTransferManager::XlaTransferManager(
}
}
-Status XlaTransferManager::TransferLiteralToDevice(
- const Tensor& host_tensor, Tensor* device_tensor) const {
+Status XlaTransferManager::TransferLiteralToDevice(const Tensor& host_tensor,
+ Tensor* device_tensor,
+ bool buffer_is_fresh) const {
xla::Shape xla_shape;
TF_RETURN_IF_ERROR(TensorShapeToXLAShape(host_tensor.dtype(),
host_tensor.shape(), &xla_shape));
@@ -97,8 +98,11 @@ Status XlaTransferManager::TransferLiteralToDevice(
// synchronized.
host_to_device_stream_->ThenWaitFor(stream_.get());
}
+ xla::TransferManager::TransferToDeviceHint hint =
+ buffer_is_fresh ? xla::TransferManager::kBufferUndefined
+ : xla::TransferManager::kNoHint;
TF_RETURN_IF_ERROR(transfer_manager_->TransferLiteralToDeviceAsync(
- host_to_device_stream_.get(), *literal, shaped_buffer));
+ host_to_device_stream_.get(), *literal, shaped_buffer, hint));
if (UseMultipleStreams()) {
auto event = std::make_shared<se::Event>(stream_->parent());
TF_RET_CHECK(event->Init()) << "Event failed to initialize!";
@@ -165,6 +169,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
return;
}
TensorShape shape = shape_or_status.ValueOrDie();
+ bool buffer_is_fresh = false;
if (!xla_tensor->has_shaped_buffer()) {
Status s =
xla_tensor->AllocateShapedBuffer(device_tensor->dtype(), shape, client_,
@@ -173,6 +178,7 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
done(s);
return;
}
+ buffer_is_fresh = true;
}
Status status;
@@ -183,7 +189,8 @@ void XlaTransferManager::CopyCPUTensorToDevice(const Tensor* cpu_tensor,
"Tensor::CopyFrom failed when copying from CPU to XLA device"));
return;
}
- status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor);
+ status = TransferLiteralToDevice(reshaped_cpu_tensor, device_tensor,
+ buffer_is_fresh);
} else {
se::DeviceMemoryBase dev_dst_ptr =
XlaTensor::DeviceMemoryFromTensor(*device_tensor);
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index df82421294..a4c0c296fc 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -67,7 +67,8 @@ class XlaTransferManager {
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
- Tensor* device_tensor) const;
+ Tensor* device_tensor,
+ bool buffer_is_fresh) const;
void TransferLiteralFromDevice(Tensor* host_tensor,
const Tensor& device_tensor,
const StatusCallback& done) const;
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 3cf74fa788..822fedf121 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -1105,6 +1105,7 @@ cc_library(
"//tensorflow/core:test",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:ops_util",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tests/randomized_tests.cc b/tensorflow/compiler/tests/randomized_tests.cc
index bddda6f302..7a96f4c25c 100644
--- a/tensorflow/compiler/tests/randomized_tests.cc
+++ b/tensorflow/compiler/tests/randomized_tests.cc
@@ -45,6 +45,7 @@ limitations under the License.
#include <random>
#include <unordered_map>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/defs.h"
@@ -63,7 +64,6 @@ limitations under the License.
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
@@ -457,7 +457,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
Tensor tensor(dtype, TensorShape(shape));
switch (dtype) {
case DT_FLOAT: {
- gtl::FlatSet<float> already_generated;
+ absl::flat_hash_set<float> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<float>(&tensor, [&](int i) -> float {
float generated;
@@ -470,7 +470,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_DOUBLE: {
- gtl::FlatSet<double> already_generated;
+ absl::flat_hash_set<double> already_generated;
std::uniform_real_distribution<double> distribution(-1.0, 1.0);
test::FillFn<double>(&tensor, [&](int i) -> double {
double generated;
@@ -483,7 +483,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_COMPLEX64: {
- gtl::FlatSet<std::pair<float, float>> already_generated;
+ absl::flat_hash_set<std::pair<float, float>> already_generated;
std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
test::FillFn<complex64>(&tensor, [&](int i) {
complex64 generated;
@@ -500,7 +500,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT32: {
- gtl::FlatSet<int32> already_generated;
+ absl::flat_hash_set<int32> already_generated;
std::uniform_int_distribution<int32> distribution(-(1 << 20), 1 << 20);
test::FillFn<int32>(&tensor, [&](int i) -> int32 {
int32 generated;
@@ -513,7 +513,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_INT64: {
- gtl::FlatSet<int64> already_generated;
+ absl::flat_hash_set<int64> already_generated;
std::uniform_int_distribution<int64> distribution(-(1LL << 40),
1LL << 40);
test::FillFn<int64>(&tensor, [&](int i) -> int64 {
@@ -527,7 +527,7 @@ Tensor OpTest::RandomTensor(DataType dtype, bool needs_unique_values,
break;
}
case DT_BOOL: {
- gtl::FlatSet<bool> already_generated;
+ absl::flat_hash_set<bool> already_generated;
std::bernoulli_distribution distribution;
test::FillFn<bool>(&tensor, [&](int i) -> bool {
bool generated;
diff --git a/tensorflow/compiler/tests/stateless_random_ops_test.py b/tensorflow/compiler/tests/stateless_random_ops_test.py
index f3861043b2..e8741bc468 100644
--- a/tensorflow/compiler/tests/stateless_random_ops_test.py
+++ b/tensorflow/compiler/tests/stateless_random_ops_test.py
@@ -91,7 +91,7 @@ class StatelessRandomOpsTest(xla_test.XLATestCase):
with self.cached_session() as sess, self.test_scope():
for dtype in self._random_types():
seed_t = array_ops.placeholder(dtypes.int32, shape=[2])
- x = stateless.stateless_random_uniform(
+ x = stateless.stateless_random_normal(
shape=[10000], seed=seed_t, dtype=dtype)
y = sess.run(x, {seed_t: [0x12345678, 0xabcdef12]})
self.assertTrue(np.all(np.isfinite(y)))
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ba1e3b2b4f..3f631f91ec 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -635,6 +635,7 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ops",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
@@ -649,6 +650,7 @@ tf_cc_test(
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
diff --git a/tensorflow/compiler/tf2xla/cc/BUILD b/tensorflow/compiler/tf2xla/cc/BUILD
index ea8d1b3d14..adcdb6c8f7 100644
--- a/tensorflow/compiler/tf2xla/cc/BUILD
+++ b/tensorflow/compiler/tf2xla/cc/BUILD
@@ -30,14 +30,15 @@ cc_library(
tf_gen_op_wrapper_cc(
name = "xla_jit_op_gen",
- out_ops_file = "ops/xla_jit_op",
+ include_internal_ops = 1,
+ out_ops_file = "ops/xla_jit_ops",
deps = ["//tensorflow/compiler/jit/ops:xla_ops"],
)
cc_library(
name = "xla_jit_ops",
- srcs = ["ops/xla_jit_op.cc"],
- hdrs = ["ops/xla_jit_op.h"],
+ srcs = ["ops/xla_jit_ops.cc"],
+ hdrs = ["ops/xla_jit_ops.h"],
deps = [
"//tensorflow/cc:const_op",
"//tensorflow/cc:ops",
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 2d45507796..36c6f5d316 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -92,13 +92,51 @@ Status FunctionalizeControlFlowForFunction(
});
const FunctionBody* body = flr->GetFunctionBody(handle);
+ // Call graph optimizer. The most important optimization we need is constant
+ // folding, which will replace ops like Shape/BroadcastGradientArgs with
+ // constant shape input. Without this optimization, those ops might become
+ // dynamic input for then/else body function and XLA will complain that input
+ // is not compile time constant. We enable function inlining as well, because
+ // otherwise we won't be able to infer shape for any node depending on
+ // function call nodes.
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_before_opt_", func_name),
+ *body->graph, fld);
+ }
+ // Optimizer accepts std::unique_ptr<Graph>* as input and might change
+ // underlying pointer, thus we create a new Graph and copy from body->graph.
+ std::unique_ptr<Graph> optimized_graph(new Graph(fld));
+ CopyGraph(*body->graph, optimized_graph.get());
+ OptimizerOptions opts;
+ opts.set_opt_level(OptimizerOptions::L0);
+ opts.set_do_function_inlining(true);
+ opts.set_do_constant_folding(true);
+ GraphOptimizer optimizer(opts);
+ auto cf_consider_fn = [](const Node* n) {
+ // Skip SymbolicGradient op when doing constant folding.
+ // Enabling SymbolicGradient op in constant folding requires
+ // flr->device() to be non-null, and here we have not constructed
+ // proper Device object yet (it will be constructed in XlaCompiler).
+ return n->type_string() != FunctionLibraryDefinition::kGradientOp;
+ };
+ optimizer.Optimize(flr, flr->env(),
+ /*device=*/nullptr, &optimized_graph,
+ /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
+ cf_consider_fn);
+ if (VLOG_IS_ON(4)) {
+ dump_graph::DumpGraphToFile(
+ absl::StrCat("functionalize_control_flow_after_opt_", func_name),
+ *optimized_graph, fld);
+ }
+
// If any node has associated functions, functionalize them first.
// Gather nodes with associated functions first, because rewriting those nodes
// might involve node deletion/addition. Avoid modifying nodes while iterating
// it.
std::vector<std::pair<Node*, std::vector<AssociatedFunctionInfo>>>
nodes_to_associated_functions;
- for (auto* n : body->graph->nodes()) {
+ for (auto* n : optimized_graph->nodes()) {
auto associated_functions = GetAssociatedFunctions(*n, flr);
if (!associated_functions.empty()) {
nodes_to_associated_functions.push_back({n, associated_functions});
@@ -118,7 +156,14 @@ Status FunctionalizeControlFlowForFunction(
// but still rewrite the node.
new_name = iter->second;
} else {
- new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ if (associated_function.type() ==
+ AssociatedFunctionInfo::AssociatedFunctionType::kSymbolicGradient) {
+ // For SymbolicGradient, `name` is always "SymbolicGradient",
+ // which is not very informative. Use node name instead.
+ new_name = fld->UniqueFunctionName(absl::StrCat(n->name(), "_f15n_"));
+ } else {
+ new_name = fld->UniqueFunctionName(absl::StrCat(name, "_f15n_"));
+ }
TF_RETURN_IF_ERROR(FunctionalizeControlFlowForFunction(
name, new_name, associated_function.attrs(), fld, flr,
canonicalized_name_to_new_name));
@@ -129,43 +174,10 @@ Status FunctionalizeControlFlowForFunction(
// That's fine because in that case, associated_functions will only have
// one member and the loop will only run once.
TF_RETURN_IF_ERROR(RewriteAssociatedFunction(
- body->graph, n, fld, associated_function, new_name));
+ optimized_graph.get(), n, fld, associated_function, new_name));
}
}
- // Call graph optimizer. The most important optimization we need is constant
- // folding, which will replace ops like Shape/BroadcastGradientArgs with
- // constant shape input. Without this optimization, those ops might become
- // dynamic input for then/else body function and XLA will complain that input
- // is not compile time constant. We enable function inlining as well, because
- // otherwise we won't be able to infer shape for any node depending on
- // function call nodes.
- if (VLOG_IS_ON(4)) {
- dump_graph::DumpGraphToFile(
- absl::StrCat("functionalize_control_flow_before_opt_", func_name),
- *body->graph, fld);
- }
- // Optimizer accepts std::unique_ptr<Graph>* as input and might change
- // underlying pointer, thus we create a new Graph and copy from body->graph.
- std::unique_ptr<Graph> optimized_graph(new Graph(fld));
- CopyGraph(*body->graph, optimized_graph.get());
- OptimizerOptions opts;
- opts.set_opt_level(OptimizerOptions::L0);
- opts.set_do_function_inlining(true);
- opts.set_do_constant_folding(true);
- GraphOptimizer optimizer(opts);
- auto cf_consider_fn = [](const Node* n) {
- // Skip SymbolicGradient op when doing constant folding.
- // Enabling SymbolicGradient op in constant folding requires
- // flr->device() to be non-null, and here we have not constructed
- // proper Device object yet (it will be constructed in XlaCompiler).
- return n->type_string() != FunctionLibraryDefinition::kGradientOp;
- };
- optimizer.Optimize(flr, flr->env(),
- /*device=*/nullptr, &optimized_graph,
- /*shape_map=*/nullptr, /*cse_consider_fn=*/nullptr,
- cf_consider_fn);
-
// Functionalize the function body.
if (VLOG_IS_ON(4)) {
dump_graph::DumpGraphToFile(
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table.cc b/tensorflow/compiler/tf2xla/resource_operation_table.cc
index 20f2ce2919..72b240996f 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
+#include "absl/container/flat_hash_map.h"
namespace tensorflow {
/*static*/ absl::string_view XlaResourceOpInfo::XlaResourceOpKindToString(
@@ -30,9 +30,9 @@ namespace tensorflow {
}
}
-static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>*
+static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
CreateResourceOpInfoMap() {
- auto* result = new gtl::FlatMap<absl::string_view, XlaResourceOpInfo>;
+ auto* result = new absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>;
auto add = [&](absl::string_view op, XlaResourceOpKind op_kind,
XlaResourceKind resource_kind) {
@@ -103,15 +103,15 @@ CreateResourceOpInfoMap() {
return result;
}
-static const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>&
+static const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>&
GetStaticResourceOpInfoMap() {
- static gtl::FlatMap<absl::string_view, XlaResourceOpInfo>* op_info_map =
- CreateResourceOpInfoMap();
+ static absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>*
+ op_info_map = CreateResourceOpInfoMap();
return *op_info_map;
}
const XlaResourceOpInfo* GetResourceOpInfoForOp(absl::string_view op) {
- const gtl::FlatMap<absl::string_view, XlaResourceOpInfo>& op_infos =
+ const absl::flat_hash_map<absl::string_view, XlaResourceOpInfo>& op_infos =
GetStaticResourceOpInfoMap();
auto it = op_infos.find(op);
return it == op_infos.end() ? nullptr : &it->second;
diff --git a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
index a85ef040a7..956f597301 100644
--- a/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
+++ b/tensorflow/compiler/tf2xla/resource_operation_table_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/tf2xla/resource_operation_table.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/lib/core/status_test_util.h"
@@ -33,7 +34,7 @@ bool HasResourceInputOrOutput(const OpDef& op_def) {
}
TEST(ResourceOperationTableTest, HaveAllResourceOps) {
- gtl::FlatMap<string, bool> known_resource_ops;
+ absl::flat_hash_map<string, bool> known_resource_ops;
for (absl::string_view known_resource_op :
resource_op_table_internal::GetKnownResourceOps()) {
ASSERT_TRUE(
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.cc b/tensorflow/compiler/tf2xla/tf2xla_util.cc
index d6f42bac86..01dd3ba10f 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.cc
@@ -336,9 +336,9 @@ bool HasAssociatedFunction(const NodeDef& node_def,
}
if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
- return false;
+ // Gradient op has "f" attr, which is set to the function we are getting
+ // gradient for. We need to functionalize the gradient function.
+ return true;
}
for (const auto& iter : node_def.attr()) {
@@ -357,17 +357,18 @@ std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
if (flr->GetFunctionLibraryDefinition()->Contains(op)) {
// This is a function call node.
AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
- results.emplace_back(AssociatedFunctionInfo(op, attrs));
+ results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
} else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
- // Skip gradient op. Gradient op has "f" attr, which is set to the function
- // we are getting gradient for. That function is not associated with the op.
+ // This is a SymbolicGradient op.
+ AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
+ results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
} else {
// Collect all function attrs for the node.
for (auto& iter : node.attrs()) {
if (iter.second.has_func()) {
VLOG(2) << "Found function attr for node " << node.name() << ": "
<< iter.first << " = " << iter.second.func().name();
- results.emplace_back(AssociatedFunctionInfo(
+ results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
iter.second.func().name(), iter.second.func().attr(), iter.first));
}
}
@@ -410,6 +411,21 @@ Status RewriteAssociatedFunction(
graph->RemoveNode(node);
break;
}
+ case AssociatedFunctionInfo::kSymbolicGradient: {
+ NameAttrList func;
+ TF_RETURN_IF_ERROR(GetNodeAttr(
+ node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
+ GradientDef gradient_def;
+ gradient_def.set_function_name(func.name());
+ gradient_def.set_gradient_func(rewritten_function_name);
+ string original_grad_func = fld->FindGradient(func.name());
+ if (original_grad_func.empty()) {
+ TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
+ } else if (original_grad_func != rewritten_function_name) {
+ TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
+ }
+ break;
+ }
case AssociatedFunctionInfo::kFunctionAttr: {
// Change function attr to rewritten functions.
NameAttrList func;
diff --git a/tensorflow/compiler/tf2xla/tf2xla_util.h b/tensorflow/compiler/tf2xla/tf2xla_util.h
index 6065d0bb9a..53eab8b63e 100644
--- a/tensorflow/compiler/tf2xla/tf2xla_util.h
+++ b/tensorflow/compiler/tf2xla/tf2xla_util.h
@@ -65,21 +65,33 @@ uint32 GetXLARandomSeed();
class AssociatedFunctionInfo {
public:
enum AssociatedFunctionType {
- kFunctionCallNode = 0,
- kFunctionAttr = 1,
+ kFunctionAttr = 0,
+ kFunctionCallNode = 1,
+ kSymbolicGradient = 2,
};
- // The node is a function call.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs)
- : type_(kFunctionCallNode), func_name_(func_name), attrs_(attrs) {}
-
// The function is an attr of the node.
- AssociatedFunctionInfo(const string& func_name, const AttrValueMap& attrs,
- const string& attr_name)
- : type_(kFunctionAttr),
- func_name_(func_name),
- attrs_(attrs),
- attr_name_(attr_name) {}
+ static AssociatedFunctionInfo FunctionAttr(const string& func_name,
+ const AttrValueMap& attrs,
+ const string& attr_name) {
+ return AssociatedFunctionInfo(kFunctionAttr, func_name, attrs, attr_name);
+ }
+
+ // The node is a function call.
+ static AssociatedFunctionInfo FunctionCall(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kFunctionCallNode, func_name, attrs,
+ /*attr_name=*/"");
+ }
+
+ // The node is a SymbolicGradient op.
+ static AssociatedFunctionInfo SymbolicGradient(const string& func_name,
+ const AttrValueMap& attrs) {
+ // attr_name will not be used in this case.
+ return AssociatedFunctionInfo(kSymbolicGradient, func_name, attrs,
+ /*attr_name=*/"");
+ }
AssociatedFunctionType type() const { return type_; }
@@ -90,6 +102,13 @@ class AssociatedFunctionInfo {
const AttrValueMap& attrs() const { return attrs_; }
private:
+ AssociatedFunctionInfo(AssociatedFunctionType type, const string& func_name,
+ const AttrValueMap& attrs, const string& attr_name)
+ : type_(type),
+ func_name_(func_name),
+ attrs_(attrs),
+ attr_name_(attr_name) {}
+
// Available for all instances.
AssociatedFunctionType type_;
string func_name_;
@@ -105,14 +124,18 @@ bool HasAssociatedFunction(const NodeDef& node_def,
// Gets functions associated with the node. Current cases:
// 1. For function call node, its function name;
-// 2. For nodes like XlaWhile/XlaIf, all their function attributes.
+// 2. For SymbolicGradient op, returned func_name will be "SymbolicGradient",
+// and returned attrs will be this node's attributes;
+// 3. For nodes like XlaWhile/XlaIf, all their function attributes.
std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
const Node& node, FunctionLibraryRuntime* flr);
// Changes associated functions for the node. Current cases:
// 1. For function call node, creates a new node with the new function name and
// remove the old node;
-// 2. For nodes like XlaWhile/XlaIf, modify their function attributes.
+// 2. For SymbolicGradient op, add or replace GradientDef in
+// FunctionLibraryDefinition;
+// 3. For nodes like XlaWhile/XlaIf, modify their function attributes.
Status RewriteAssociatedFunction(
Graph* graph, Node* node, FunctionLibraryDefinition* fld,
const AssociatedFunctionInfo& associated_function,
diff --git a/tensorflow/compiler/tf2xla/type_util.h b/tensorflow/compiler/tf2xla/type_util.h
index bda667eb1f..6354216eee 100644
--- a/tensorflow/compiler/tf2xla/type_util.h
+++ b/tensorflow/compiler/tf2xla/type_util.h
@@ -25,6 +25,14 @@ namespace tensorflow {
// Converts a Tensorflow DataType to an XLA PrimitiveType.
Status DataTypeToPrimitiveType(DataType data_type, xla::PrimitiveType* type);
+// N.B.: there is intentionally no function to convert an XLA PrimitiveType to
+// a TensorFlow DataType. The mapping from TF types to XLA types is not
+// one-to-one: for example, both DT_INT8 and DT_QINT8 map to xla::S8. So the
+// inverse would not be a well-defined function. If you find that you want the
+// inverse mapping, then most likely you should be preserving the original
+// TensorFlow type, rather than trying to convert an XLA type into a TensorFlow
+// type.
+
} // namespace tensorflow
#endif // TENSORFLOW_COMPILER_TF2XLA_TYPE_UTIL_H_
diff --git a/tensorflow/compiler/xla/client/BUILD b/tensorflow/compiler/xla/client/BUILD
index f825f67b44..dc097f3696 100644
--- a/tensorflow/compiler/xla/client/BUILD
+++ b/tensorflow/compiler/xla/client/BUILD
@@ -220,6 +220,8 @@ cc_library(
"//tensorflow/compiler/xla/service:shape_inference",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
diff --git a/tensorflow/compiler/xla/client/xla_builder.cc b/tensorflow/compiler/xla/client/xla_builder.cc
index 5277de6a85..e0ec91dba1 100644
--- a/tensorflow/compiler/xla/client/xla_builder.cc
+++ b/tensorflow/compiler/xla/client/xla_builder.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
@@ -2290,7 +2290,7 @@ StatusOr<XlaComputation> XlaBuilder::BuildConstantSubGraph(
// also a valid dependency order). The related ops will be added to the
// subgraph in the same order.
std::set<int64> related_ops;
- tensorflow::gtl::FlatSet<int64> related_calls; // Related computations.
+ absl::flat_hash_set<int64> related_calls; // Related computations.
std::queue<int64> worklist;
worklist.push(root->id());
related_ops.insert(root->id());
diff --git a/tensorflow/compiler/xla/client/xla_builder.h b/tensorflow/compiler/xla/client/xla_builder.h
index 1da6ddd318..cd0d5ca5d3 100644
--- a/tensorflow/compiler/xla/client/xla_builder.h
+++ b/tensorflow/compiler/xla/client/xla_builder.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <type_traits>
#include <utility>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/client/padding.h"
@@ -34,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stacktrace.h"
#include "tensorflow/core/platform/types.h"
@@ -1027,7 +1027,7 @@ class XlaBuilder {
// A map from XlaOp::Handle to the index in the instructions_ vector where the
// instruction is held.
- tensorflow::gtl::FlatMap<int64, int64> handle_to_index_;
+ absl::flat_hash_map<int64, int64> handle_to_index_;
// The embedded computations used by this computation. Each computation was
// the entry computation of some XlaComputation, the key is the unique id of
@@ -1035,7 +1035,7 @@ class XlaBuilder {
std::map<int64, HloComputationProto> embedded_;
// The unique parameter numbers.
- tensorflow::gtl::FlatSet<int64> parameter_numbers_;
+ absl::flat_hash_set<int64> parameter_numbers_;
// The metadata to attach to each op. This is structured as a "modal"-like
// operation, in order to simplify client code (and not sprinkle this metadata
diff --git a/tensorflow/compiler/xla/literal.cc b/tensorflow/compiler/xla/literal.cc
index 5035f41988..deeb140b8f 100644
--- a/tensorflow/compiler/xla/literal.cc
+++ b/tensorflow/compiler/xla/literal.cc
@@ -287,6 +287,8 @@ Status MutableLiteralBase::CopyElementFrom(const LiteralSlice& src_literal,
return InvalidArgument("LiteralProto has no layout");
}
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
Literal literal(proto.shape());
TF_RETURN_IF_ERROR(literal.root_piece_->ForEachMutableSubpieceWithStatus(
@@ -1850,6 +1852,24 @@ Status LiteralBase::Piece::CopyFromProto(const LiteralProto& proto) {
TF_RET_CHECK(LayoutUtil::HasLayout(proto.shape()));
TF_RET_CHECK(ShapeUtil::Equal(proto.shape(), subshape()));
+ if (LayoutUtil::IsSparseArray(subshape())) {
+ // Compute the number of elements (indices) in the sparse shape and reserve
+ // the necessary space in spare_indices.
+ TF_RET_CHECK(ShapeUtil::Rank(subshape()) != 0)
+ << "Scalar shapes cannot be sparse";
+ TF_RET_CHECK(proto.sparse_indices_size() % ShapeUtil::Rank(subshape()) == 0)
+ << "Unexpected number of indices in proto ("
+ << proto.sparse_indices_size() << ") for shape of rank "
+ << ShapeUtil::Rank(subshape());
+ const int64 index_count =
+ proto.sparse_indices_size() / ShapeUtil::Rank(subshape());
+ sparse_indices()->Resize(index_count);
+
+ // Copy the indices from the proto into the SparseIndexArray object.
+ TF_RETURN_IF_ERROR(CopyFromRepeatedField(sparse_indices()->mutable_data(),
+ proto.sparse_indices()));
+ }
+
switch (subshape().element_type()) {
case PRED:
TF_RETURN_IF_ERROR(CopyFromRepeatedField(data<bool>(), proto.preds()));
diff --git a/tensorflow/compiler/xla/literal_test.cc b/tensorflow/compiler/xla/literal_test.cc
index 7ad287c897..dd5b54e4c9 100644
--- a/tensorflow/compiler/xla/literal_test.cc
+++ b/tensorflow/compiler/xla/literal_test.cc
@@ -224,6 +224,16 @@ TEST_F(LiteralUtilTest, CreateSparse) {
absl::Span<const int64>(expected_indices.data(),
expected_indices.num_elements()));
EXPECT_EQ(literal.data<int64>(), absl::Span<const int64>(expected_values));
+
+ // Serialize then deserialize and verify the resulting literal.
+ TF_ASSERT_OK_AND_ASSIGN(Literal literal_from_proto,
+ Literal::CreateFromProto(literal.ToProto()));
+
+ EXPECT_EQ(literal_from_proto.sparse_indices()->data(),
+ absl::Span<const int64>(expected_indices.data(),
+ expected_indices.num_elements()));
+ EXPECT_EQ(literal_from_proto.data<int64>(),
+ absl::Span<const int64>(expected_values));
}
TEST_F(LiteralUtilTest, LiteralR4F32ProjectedStringifies) {
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index e800cf470c..f329a27e14 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -146,6 +146,8 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -182,6 +184,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -251,6 +254,7 @@ cc_library(
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/container:inlined_vector",
+ "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -296,6 +300,7 @@ cc_library(
"hlo_opcode.cc",
"hlo_schedule.cc",
"hlo_sharding.cc",
+ "hlo_sharding_metadata.cc",
],
hdrs = [
"dfs_hlo_visitor.h",
@@ -309,6 +314,7 @@ cc_library(
"hlo_opcode.h",
"hlo_schedule.h",
"hlo_sharding.h",
+ "hlo_sharding_metadata.h",
],
deps = [
":hlo_casting_utils",
@@ -333,6 +339,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -395,6 +403,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:span",
],
)
@@ -485,6 +494,8 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -776,6 +787,7 @@ cc_library(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -903,6 +915,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -952,6 +965,8 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:types",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -987,6 +1002,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
@@ -1034,6 +1051,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -1087,6 +1106,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
],
@@ -1125,6 +1145,8 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
@@ -1146,6 +1168,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -1196,6 +1219,7 @@ cc_library(
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/types:optional",
],
@@ -1216,6 +1240,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -1260,6 +1286,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -1280,6 +1308,7 @@ tf_cc_test(
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -1304,6 +1333,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -1330,6 +1360,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -1385,6 +1417,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",
],
@@ -1640,6 +1673,8 @@ cc_library(
":while_loop_analysis",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
],
@@ -1671,6 +1706,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -2043,6 +2079,7 @@ cc_library(
":logical_buffer",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -2078,6 +2115,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@@ -2099,6 +2137,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2182,6 +2221,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2203,6 +2243,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
@@ -2263,6 +2305,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
@@ -2319,6 +2363,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2345,6 +2391,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
],
)
@@ -2416,6 +2464,7 @@ cc_library(
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
],
@@ -2460,6 +2509,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2588,6 +2639,8 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -2627,6 +2680,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -2701,27 +2755,13 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
],
)
cc_library(
- name = "hlo_sharding_metadata",
- srcs = ["hlo_sharding_metadata.cc"],
- hdrs = [
- "hlo_sharding_metadata.h",
- ],
- deps = [
- ":hlo",
- "//tensorflow/compiler/xla:shape_tree",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/core:lib",
- "@com_google_absl//absl/memory",
- "@com_google_absl//absl/types:span",
- ],
-)
-
-cc_library(
name = "hlo_domain_verifier",
srcs = ["hlo_domain_verifier.cc"],
hdrs = ["hlo_domain_verifier.h"],
@@ -2771,7 +2811,6 @@ tf_cc_test(
":hlo_domain_isolator",
":hlo_domain_remover",
":hlo_parser",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/legacy_flags:debug_options_flags",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -3147,6 +3186,7 @@ cc_library(
":hlo_pass_pipeline",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
@@ -3269,6 +3309,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3298,6 +3340,7 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:inlined_vector",
],
)
@@ -3354,6 +3397,8 @@ cc_library(
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/container:inlined_vector",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
@@ -3381,7 +3426,6 @@ cc_library(
deps = [
":hlo",
":hlo_lexer",
- ":hlo_sharding_metadata",
"//tensorflow/compiler/xla:literal",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index a7d8927cf7..43feccee3c 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/backend.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
@@ -110,7 +111,7 @@ class AllocationTracker {
// A map from device memory opaque value to allocation. One such map is
// maintained per device ordinal.
- using AllocationMap = tensorflow::gtl::FlatMap<const void*, Allocation>;
+ using AllocationMap = absl::flat_hash_map<const void*, Allocation>;
tensorflow::mutex mutex_;
@@ -123,10 +124,7 @@ class AllocationTracker {
int64 next_handle_ GUARDED_BY(mutex_);
// A map from device ordinal to AllocationMap.
- //
- // This is not a TF FlatMap because (currently) FlatMap (and therefore
- // AllocationMap) is not movable.
- std::unordered_map<int, AllocationMap> opaque_to_allocation_map_
+ absl::flat_hash_map<int, AllocationMap> opaque_to_allocation_map_
GUARDED_BY(mutex_);
// A map from data handle to a vector of shaped buffers that represent the
@@ -146,7 +144,7 @@ class AllocationTracker {
// non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
// free'd when both the view *and* the original tuple are Unregistered. This
// refcounting is managed in opaque_to_allocation_map_.
- tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
+ absl::flat_hash_map<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
handle_to_shaped_buffers_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(AllocationTracker);
diff --git a/tensorflow/compiler/xla/service/batchnorm_expander.cc b/tensorflow/compiler/xla/service/batchnorm_expander.cc
index 30d33e0d35..f70f6ddfec 100644
--- a/tensorflow/compiler/xla/service/batchnorm_expander.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_expander.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.cc b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
index 58f78f8e24..002be9c970 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/bfloat16_propagation.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -81,7 +82,7 @@ void BFloat16Propagation::RevertIfFusionInternalBF16Changes(
};
auto root = fusion->fused_instructions_computation()->root_instruction();
- tensorflow::gtl::FlatSet<const HloValue*> changed_root_buffers;
+ absl::flat_hash_set<const HloValue*> changed_root_buffers;
auto root_changes_it = changes_to_bf16_.find(root);
if (root_changes_it != changes_to_bf16_.end()) {
@@ -500,7 +501,7 @@ void BFloat16Propagation::AdjustCalledComputationRoot(HloInstruction* hlo) {
bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations) {
+ absl::flat_hash_set<const HloComputation*>* visited_computations) {
bool parameter_changed = false;
auto insts = computation->MakeInstructionPostOrder();
// Do the adjustment on each instruction in the computation in reverse
@@ -560,7 +561,7 @@ bool BFloat16Propagation::ResolveInconsistencyOfAliasingBuffersHelper(
// another input parameter. A fixed point will be reached because the
// parameters can only be changed from BF16 to F32, not the other way
// around.
- tensorflow::gtl::FlatSet<const HloComputation*> visited_in_while;
+ absl::flat_hash_set<const HloComputation*> visited_in_while;
while (ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_condition(),
&visited_in_while) ||
ResolveInconsistencyOfAliasingBuffersHelper(hlo->while_body(),
@@ -587,7 +588,7 @@ void BFloat16Propagation::ResolveInconsistencyOfAliasingBuffers(
HloModule* module) {
const auto& computations_topological_order =
module->MakeComputationPostOrder();
- tensorflow::gtl::FlatSet<const HloComputation*> resolved;
+ absl::flat_hash_set<const HloComputation*> resolved;
for (auto comp_it = computations_topological_order.rbegin();
comp_it != computations_topological_order.rend(); ++comp_it) {
if (ContainsKey(resolved, *comp_it)) {
diff --git a/tensorflow/compiler/xla/service/bfloat16_propagation.h b/tensorflow/compiler/xla/service/bfloat16_propagation.h
index 6a62439f88..5fcaa15c83 100644
--- a/tensorflow/compiler/xla/service/bfloat16_propagation.h
+++ b/tensorflow/compiler/xla/service/bfloat16_propagation.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/bfloat16_support.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -81,7 +83,7 @@ class BFloat16Propagation : public HloModulePass {
// The set of instructions to consider using bfloat16, computed in the forward
// pass.
- tensorflow::gtl::FlatSet<const HloInstruction*> consider_using_bfloat16_;
+ absl::flat_hash_set<const HloInstruction*> consider_using_bfloat16_;
// ***************************
// Functions called and state produced by the backward pass (from root to
@@ -110,12 +112,12 @@ class BFloat16Propagation : public HloModulePass {
// The set of HloInstructions that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloInstruction*>
+ absl::flat_hash_set<const HloInstruction*>
instructions_visited_in_backward_pass_;
// The set of HloComputations that have been visited in the
// opportunity-finding pass.
- tensorflow::gtl::FlatSet<const HloComputation*>
+ absl::flat_hash_set<const HloComputation*>
computations_visited_in_backward_pass_;
// ***************************
@@ -131,7 +133,7 @@ class BFloat16Propagation : public HloModulePass {
// point is reached.
bool ResolveInconsistencyOfAliasingBuffersHelper(
HloComputation* computation,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited_computations);
+ absl::flat_hash_set<const HloComputation*>* visited_computations);
// Makes the parameters of called computations match how they are called by
// the given HLO.
@@ -182,11 +184,11 @@ class BFloat16Propagation : public HloModulePass {
PrimitiveType target_type);
// The set of F32 HLO values that must be kept in F32.
- tensorflow::gtl::FlatSet<const HloValue*> values_that_must_be_kept_as_f32_;
+ absl::flat_hash_set<const HloValue*> values_that_must_be_kept_as_f32_;
// Mapping from each HloComputation to the number of callers to it in the
// module. Populated at the beginning of this pass.
- tensorflow::gtl::FlatMap<const HloComputation*, int64> caller_counts_;
+ absl::flat_hash_map<const HloComputation*, int64> caller_counts_;
// We first store the potential F32-to-BF16 changes to changes_to_bf16_, which
// are subject to further adjustment, then finally applied to the HLOs. This
@@ -195,8 +197,7 @@ class BFloat16Propagation : public HloModulePass {
//
// For each HloInstruction, changes_to_bf16_ stores the affected buffers in
// the output as a map from in-place pointers to subshapes to shape indices.
- tensorflow::gtl::FlatMap<HloInstruction*,
- tensorflow::gtl::FlatMap<Shape*, ShapeIndex>>
+ absl::flat_hash_map<HloInstruction*, absl::flat_hash_map<Shape*, ShapeIndex>>
changes_to_bf16_;
// Whether the last processed HLO module has been changed by this pass.
diff --git a/tensorflow/compiler/xla/service/bfloat16_support.cc b/tensorflow/compiler/xla/service/bfloat16_support.cc
index 23645346e6..5b48f10505 100644
--- a/tensorflow/compiler/xla/service/bfloat16_support.cc
+++ b/tensorflow/compiler/xla/service/bfloat16_support.cc
@@ -78,8 +78,10 @@ bool BFloat16Support::EffectiveOperandPrecisionIsOutputPrecision(
const HloInstruction& hlo, int64 operand_index) {
switch (hlo.opcode()) {
case HloOpcode::kAbs:
+ case HloOpcode::kAllToAll:
case HloOpcode::kBroadcast:
case HloOpcode::kClamp:
+ case HloOpcode::kCollectivePermute:
case HloOpcode::kConcatenate:
case HloOpcode::kConvert:
case HloOpcode::kCopy:
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.cc b/tensorflow/compiler/xla/service/buffer_assignment.cc
index 34a7be0e9c..2c2d1626c2 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <ostream>
#include <utility>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -41,10 +43,10 @@ limitations under the License.
namespace xla {
namespace {
+using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::StrAppend;
using absl::StrAppendFormat;
-using ::tensorflow::gtl::FlatMap;
-using ::tensorflow::gtl::FlatSet;
using ::tensorflow::strings::HumanReadableNumBytes;
template <typename T>
@@ -128,8 +130,8 @@ Status GatherComputationsByAllocationType(
// Sets for quickly checking membership. Computations are returned in vectors
// for stable iteration.
- FlatSet<const HloComputation*> thread_local_set;
- FlatSet<const HloComputation*> global_set;
+ flat_hash_set<const HloComputation*> thread_local_set;
+ flat_hash_set<const HloComputation*> global_set;
while (!worklist.empty()) {
auto worklist_front = worklist.front();
@@ -444,7 +446,7 @@ bool BufferAssignment::SharesSliceAtIndex(
bool BufferAssignment::HaveDisjointSlices(const HloInstruction* hlo_a,
const HloInstruction* hlo_b) const {
using SliceSet =
- FlatSet<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
+ flat_hash_set<BufferAllocation::Slice, BufferAllocation::Slice::Hasher>;
// Gets the slices all of instr's subshapes. If any subshape doesn't have an
// assigned slice, returns the empty set.
auto collect_slices = [&](const HloInstruction* instr) -> SliceSet {
@@ -519,7 +521,8 @@ void BufferAssignment::AddAssignment(BufferAllocation* allocation,
// BufferAllocation.
void BufferAssignment::CombineTempAllocations() {
VLOG(1) << "CombineTempAllocations()";
- FlatMap<LogicalBuffer::Color, BufferAllocation, LogicalBuffer::Color::Hasher>
+ flat_hash_map<LogicalBuffer::Color, BufferAllocation,
+ LogicalBuffer::Color::Hasher>
combined_allocation_map;
// Move all temp allocations into a single run at the end of the allocations
@@ -582,7 +585,8 @@ void BufferAssignment::CombineTempAllocations() {
}
// Update allocation indices to their new positions.
- allocation_index_for_buffer_.clear_no_resize();
+ allocation_index_for_buffer_.erase(allocation_index_for_buffer_.begin(),
+ allocation_index_for_buffer_.end());
for (size_t index = 0; index < allocations_.size(); ++index) {
BufferAllocation* allocation = &allocations_[index];
allocation->set_index(index);
@@ -812,9 +816,9 @@ bool BufferAssigner::MaybeAssignBuffer(BufferAllocation* allocation,
Status BufferAssigner::AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const FlatSet<const LogicalBuffer*>& colocated_buffers,
- const FlatSet<BufferAllocation::Index>& colocated_allocations,
- FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>*
+ const flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment) {
// Buffers are sorted and assigned to BufferAllocations in decreasing order of
@@ -833,7 +837,7 @@ Status BufferAssigner::AssignBuffersForComputation(
// Generate a post order sort of instructions for sorting of the
// LogicalBuffers.
- FlatMap<const HloInstruction*, int> post_order_position;
+ flat_hash_map<const HloInstruction*, int> post_order_position;
int position = 0;
for (auto* instruction : computation->MakeInstructionPostOrder()) {
post_order_position.emplace(instruction, position);
@@ -850,8 +854,8 @@ Status BufferAssigner::AssignBuffersForComputation(
// buffers_to_assign_sequentially map, even if we end up with an empty set
// of buffers. This ensures we can correctly determine whether to run
// whole-module heap simulation.
- buffers_to_assign_sequentially->emplace(computation,
- FlatSet<const LogicalBuffer*>());
+ buffers_to_assign_sequentially->emplace(
+ computation, flat_hash_set<const LogicalBuffer*>());
}
// Sort the LogicalBuffers first by size. We assign the larger LogicalBuffers
@@ -1043,12 +1047,12 @@ Status BufferAssigner::AssignBuffersForComputation(
return Status::OK();
}
-FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
+flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
BufferAssigner::SplitBuffersByColor(
- const FlatSet<const LogicalBuffer*>& buffers) {
- FlatMap<LogicalBuffer::Color, FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
+ const flat_hash_set<const LogicalBuffer*>& buffers) {
+ flat_hash_map<LogicalBuffer::Color, flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
color_map;
for (auto buffer : buffers) {
color_map[buffer->color()].insert(buffer);
@@ -1057,7 +1061,8 @@ BufferAssigner::SplitBuffersByColor(
}
Status BufferAssigner::AssignBuffersWithSequentialOrdering(
- const FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>&
+ const flat_hash_map<const HloComputation*,
+ flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment) {
// Run the sequence of instructions through the heap simulator. The heuristic
@@ -1083,10 +1088,11 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
// only live for the duration of their calling instructions.
VLOG(1) << "Running whole-module heap simulation";
HloSchedule schedule(&assignment->module());
- FlatSet<const LogicalBuffer*> all_buffers_to_assign;
+ flat_hash_set<const LogicalBuffer*> all_buffers_to_assign;
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1120,7 +1126,8 @@ Status BufferAssigner::AssignBuffersWithSequentialOrdering(
VLOG(1) << "Running per-computation heap simulation";
for (const auto& pair : buffers_to_assign_sequentially) {
const HloComputation* computation = pair.first;
- const FlatSet<const LogicalBuffer*>& buffers_to_assign = pair.second;
+ const flat_hash_set<const LogicalBuffer*>& buffers_to_assign =
+ pair.second;
const std::vector<const HloInstruction*>* instruction_sequence =
hlo_ordering.SequentialOrder(*computation);
CHECK(instruction_sequence != nullptr) << computation->name();
@@ -1155,9 +1162,8 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
const BufferAllocation& allocation, const HeapSimulatorTrace& heap_trace) {
// Create a map from LogicalBuffer::Id to LogicalBuffer* for the logical
// buffers in this allocation.
- tensorflow::gtl::FlatMap<LogicalBuffer::Id, const LogicalBuffer*>
- id_to_buffer;
- tensorflow::gtl::FlatMap<const LogicalBuffer*, int64> buffer_sizes;
+ absl::flat_hash_map<LogicalBuffer::Id, const LogicalBuffer*> id_to_buffer;
+ absl::flat_hash_map<const LogicalBuffer*, int64> buffer_sizes;
for (const auto& pair : allocation.assigned_buffers()) {
const LogicalBuffer* buffer = pair.first;
const BufferAllocation::OffsetSize& offset_size = pair.second;
@@ -1196,7 +1202,7 @@ std::vector<const LogicalBuffer*> ComputePeakMemoryLogicalBuffers(
// Next gather the set of logical buffers live at the earliest point of
// maximal live set size.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> live_buffers;
+ absl::flat_hash_set<const LogicalBuffer*> live_buffers;
live_size = 0;
for (const auto& event : heap_trace.events()) {
const LogicalBuffer* buffer = id_to_buffer.at(event.buffer_id());
@@ -1586,8 +1592,8 @@ void BufferAssigner::BuildColocatedBufferSets(
void BufferAssigner::AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- FlatSet<const LogicalBuffer*>* colocated_buffers,
- FlatSet<BufferAllocation::Index>* colocated_allocations) {
+ flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ flat_hash_set<BufferAllocation::Index>* colocated_allocations) {
for (const ColocatedBufferSet& colocated_buffer_set : colocated_buffer_sets) {
BufferAllocation* allocation = nullptr;
// Set 'entry_parameter_number' and 'entry_parameter_shape_idx' if entry
@@ -1660,8 +1666,8 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// Once b/32491382 enables module-level liveness analysis, we may be able
// to assign colocated buffers (or at least reuse their allocation for
// buffers outside of the set) in AssignBuffersForComputation.
- FlatSet<const LogicalBuffer*> colocated_buffers;
- FlatSet<BufferAllocation::Index> colocated_allocations;
+ flat_hash_set<const LogicalBuffer*> colocated_buffers;
+ flat_hash_set<BufferAllocation::Index> colocated_allocations;
std::vector<ColocatedBufferSet> colocated_buffer_sets;
BuildColocatedBufferSets(module, assignment->liveness(),
assignment->buffer_size_, &colocated_buffer_sets);
@@ -1679,7 +1685,7 @@ StatusOr<std::unique_ptr<BufferAssignment>> BufferAssigner::CreateAssignment(
// First assign buffers for global computatations. Temporary buffers for
// sequential computations are collected in 'buffers_to_assign_sequentially'.
- FlatMap<const HloComputation*, FlatSet<const LogicalBuffer*>>
+ flat_hash_map<const HloComputation*, flat_hash_set<const LogicalBuffer*>>
buffers_to_assign_sequentially;
for (auto* computation : global_computations) {
TF_RETURN_IF_ERROR(AssignBuffersForComputation(
diff --git a/tensorflow/compiler/xla/service/buffer_assignment.h b/tensorflow/compiler/xla/service/buffer_assignment.h
index 24ba7c16f5..899cd36e1f 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment.h
+++ b/tensorflow/compiler/xla/service/buffer_assignment.h
@@ -22,6 +22,8 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
@@ -33,8 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -148,7 +148,7 @@ class BufferAllocation {
// Access to the logical buffers assigned to this allocation, and their
// associated logical offsets and sizes.
- const tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize>&
+ const absl::flat_hash_map<const LogicalBuffer*, OffsetSize>&
assigned_buffers() const {
return assigned_buffers_;
}
@@ -323,7 +323,7 @@ class BufferAllocation {
// Mapping from the set of buffers assigned to this allocation to their
// logical offsets and sizes.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, OffsetSize> assigned_buffers_;
+ absl::flat_hash_map<const LogicalBuffer*, OffsetSize> assigned_buffers_;
int64 fragmentation_bytes_ = 0;
std::vector<HeapSimulatorTrace> heap_traces_;
@@ -500,7 +500,7 @@ class BufferAssignment {
int64 temp_allocation_total_size_ = 0;
// Maps Buffers to the index of the BufferAllocation which holds the buffer.
- tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferAllocation::Index>
+ absl::flat_hash_map<const LogicalBuffer*, BufferAllocation::Index>
allocation_index_for_buffer_;
const HloModule* module_;
@@ -554,11 +554,10 @@ class BufferAssigner {
// true.
Status AssignBuffersForComputation(
const HloComputation* computation, bool is_thread_local,
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& colocated_buffers,
- const tensorflow::gtl::FlatSet<BufferAllocation::Index>&
- colocated_allocations,
- tensorflow::gtl::FlatMap<const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>*
+ const absl::flat_hash_set<const LogicalBuffer*>& colocated_buffers,
+ const absl::flat_hash_set<BufferAllocation::Index>& colocated_allocations,
+ absl::flat_hash_map<const HloComputation*,
+ absl::flat_hash_set<const LogicalBuffer*>>*
buffers_to_assign_sequentially,
BufferAssignment* assignment);
@@ -568,9 +567,8 @@ class BufferAssigner {
// 'run_whole_module_heap_simulation' is true, the heap simulation will be run
// assuming all global computations are sequentially ordered.
Status AssignBuffersWithSequentialOrdering(
- const tensorflow::gtl::FlatMap<
- const HloComputation*,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>>&
+ const absl::flat_hash_map<const HloComputation*,
+ absl::flat_hash_set<const LogicalBuffer*>>&
buffers_to_assign_sequentially,
bool run_whole_module_heap_simulation, BufferAssignment* assignment);
@@ -590,7 +588,7 @@ class BufferAssigner {
// alias. Explicitly handling these colocated buffers is necessary because
// points-to analysis is computation level scope and does not recognize
// aliasing across computations (b/32491382).
- using ColocatedBufferSet = tensorflow::gtl::FlatSet<const LogicalBuffer*>;
+ using ColocatedBufferSet = absl::flat_hash_set<const LogicalBuffer*>;
// Returns a vector of ColocatedBufferSet objects, where each
// ColocatedBufferSet aggregates a set of related LogicalBuffers from 'module'
@@ -605,8 +603,8 @@ class BufferAssigner {
void AssignColocatedBufferSets(
const std::vector<ColocatedBufferSet>& colocated_buffer_sets,
BufferAssignment* assignment,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>* colocated_buffers,
- tensorflow::gtl::FlatSet<BufferAllocation::Index>* colocated_allocations);
+ absl::flat_hash_set<const LogicalBuffer*>* colocated_buffers,
+ absl::flat_hash_set<BufferAllocation::Index>* colocated_allocations);
// Adds the 'colocated_set' of buffers to 'colocated_buffer_sets', maintaining
// the invariant that all sets in 'colocated_buffer_sets' are disjoint.
@@ -624,11 +622,10 @@ class BufferAssigner {
// Split a set of buffers into several sets, each of which contains buffers
// colored with the same color.
- tensorflow::gtl::FlatMap<LogicalBuffer::Color,
- tensorflow::gtl::FlatSet<const LogicalBuffer*>,
- LogicalBuffer::Color::Hasher>
- SplitBuffersByColor(
- const tensorflow::gtl::FlatSet<const LogicalBuffer*>& buffers);
+ absl::flat_hash_map<LogicalBuffer::Color,
+ absl::flat_hash_set<const LogicalBuffer*>,
+ LogicalBuffer::Color::Hasher>
+ SplitBuffersByColor(const absl::flat_hash_set<const LogicalBuffer*>& buffers);
// If true, buffer assignments assumes that input parameter buffers and output
// buffers can be shared if their sizes match.
diff --git a/tensorflow/compiler/xla/service/buffer_liveness.h b/tensorflow/compiler/xla/service/buffer_liveness.h
index cdd3cf4032..f939a426ea 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness.h
+++ b/tensorflow/compiler/xla/service/buffer_liveness.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -27,8 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -102,7 +101,7 @@ class BufferLiveness {
// Set of LogicalBuffers which are aliased in the output of other
// instructions. For example, a LogicalBuffer which is inserted into a tuple
// is considered to be aliased and will be in this set.
- tensorflow::gtl::FlatSet<const LogicalBuffer*> aliased_buffers_;
+ absl::flat_hash_set<const LogicalBuffer*> aliased_buffers_;
// LogicalBuffers that may be live out of the entry computation.
PointsToSet::BufferSet maybe_live_out_buffers_;
diff --git a/tensorflow/compiler/xla/service/buffer_value_containers.h b/tensorflow/compiler/xla/service/buffer_value_containers.h
index 305914fca8..cc46af5eee 100644
--- a/tensorflow/compiler/xla/service/buffer_value_containers.h
+++ b/tensorflow/compiler/xla/service/buffer_value_containers.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_BUFFER_VALUE_CONTAINERS_H_
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -38,7 +38,7 @@ BufferValueCompactPointerSet ToBufferValueCompactPointerSet(
return output;
}
-using BufferValueFlatSet = tensorflow::gtl::FlatSet<const BufferValue*>;
+using BufferValueFlatSet = absl::flat_hash_set<const BufferValue*>;
template <class LogicalBufferContainerT>
BufferValueFlatSet ToBufferValueFlatSet(
const LogicalBufferContainerT& logical_buffer_container) {
diff --git a/tensorflow/compiler/xla/service/call_graph.cc b/tensorflow/compiler/xla/service/call_graph.cc
index 23b2a32709..bdd5069632 100644
--- a/tensorflow/compiler/xla/service/call_graph.cc
+++ b/tensorflow/compiler/xla/service/call_graph.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <queue>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -138,7 +139,7 @@ CallGraphNode& CallGraph::GetNode(const HloComputation* computation) {
bool CallGraph::DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const {
+ absl::flat_hash_set<const HloComputation*>* visited) const {
if (a == b || ContainsKey(*visited, b)) {
// The call graph is guaranteed to be acyclic so any previously visited node
// we encounter was already determined to be dominated.
@@ -163,7 +164,7 @@ bool CallGraph::DominatesHelper(
bool CallGraph::Dominates(const HloComputation* a,
const HloComputation* b) const {
- tensorflow::gtl::FlatSet<const HloComputation*> visited;
+ absl::flat_hash_set<const HloComputation*> visited;
return DominatesHelper(a, b, &visited);
}
@@ -277,7 +278,7 @@ std::unique_ptr<CallGraph> CallGraph::Build(const HloModule* module) {
Status CallGraph::VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const {
+ absl::flat_hash_set<const CallGraphNode*>* visited) const {
auto pair = visited->insert(&node);
if (!pair.second) {
// Node was not inserted. Node has already been visited.
@@ -294,7 +295,7 @@ Status CallGraph::VisitNodesInternal(
Status CallGraph::VisitNodes(const VisitorFunction& visitor_func,
bool visit_unreachable_nodes) const {
- tensorflow::gtl::FlatSet<const CallGraphNode*> visited;
+ absl::flat_hash_set<const CallGraphNode*> visited;
if (visit_unreachable_nodes) {
// Traverse from all roots in the call graph.
for (const CallGraphNode& node : nodes()) {
diff --git a/tensorflow/compiler/xla/service/call_graph.h b/tensorflow/compiler/xla/service/call_graph.h
index 3af2ab5edf..cb56f4789d 100644
--- a/tensorflow/compiler/xla/service/call_graph.h
+++ b/tensorflow/compiler/xla/service/call_graph.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <ostream>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -145,19 +145,19 @@ class CallGraphNode {
// The computations called by this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callees_;
- tensorflow::gtl::FlatSet<HloComputation*> callee_set_;
+ absl::flat_hash_set<HloComputation*> callee_set_;
// The computations which call this computation. The vector is used for a
// stable ordering and the set enables fast membership testing.
std::vector<HloComputation*> callers_;
- tensorflow::gtl::FlatSet<HloComputation*> caller_set_;
+ absl::flat_hash_set<HloComputation*> caller_set_;
// The call sites in this computation
std::vector<CallSite> callsites_;
// The map from instruction to index in callsites_ for looking up the callsite
// (if any) associated with a particular instruction in this computation.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> callsite_instructions_;
+ absl::flat_hash_map<const HloInstruction*, int64> callsite_instructions_;
// The call sites in other computations which call this computation.
std::vector<CallSite> caller_callsites_;
@@ -250,14 +250,14 @@ class CallGraph {
// 'visited'.
Status VisitNodesInternal(
const VisitorFunction& visitor_func, const CallGraphNode& node,
- tensorflow::gtl::FlatSet<const CallGraphNode*>* visited) const;
+ absl::flat_hash_set<const CallGraphNode*>* visited) const;
// Recursive helper for computing whether 'a' dominates 'b' in the call
// graph. 'b_ancestor' is the currently visited node (which starts at 'b'),
// and 'visited' is the set of computations which have been visited.
bool DominatesHelper(
const HloComputation* a, const HloComputation* b,
- tensorflow::gtl::FlatSet<const HloComputation*>* visited) const;
+ absl::flat_hash_set<const HloComputation*>* visited) const;
// The HLO module represented by this call graph.
const HloModule* module_ = nullptr;
@@ -267,7 +267,7 @@ class CallGraph {
// Map from HLO computation to the index of the corresponding call graph node
// in nodes_.
- tensorflow::gtl::FlatMap<const HloComputation*, int64> node_indices_;
+ absl::flat_hash_map<const HloComputation*, int64> node_indices_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/copy_insertion.cc b/tensorflow/compiler/xla/service/copy_insertion.cc
index b65dfef9c9..f35324aa35 100644
--- a/tensorflow/compiler/xla/service/copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/copy_insertion.cc
@@ -15,6 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/copy_insertion.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_alias_analysis.h"
@@ -31,8 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -432,7 +432,7 @@ class CopyRemover {
// Construct a list for each HLO buffer in the alias analysis. Maintain a
// map from HloValue to the respective list element representing that
// value. The map is used to construct the copy info map below.
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*> value_to_node;
+ absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
@@ -480,7 +480,7 @@ class CopyRemover {
// respective ValueNode representing that value.
void AddValueList(
absl::Span<const HloValue* const> values,
- tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>* value_to_node) {
+ absl::flat_hash_map<const HloValue*, ValueNode*>* value_to_node) {
ValueNode* tail = nullptr;
ValueNode* head = nullptr;
for (const HloValue* value : values) {
@@ -516,8 +516,7 @@ class CopyRemover {
// respective ValueNode.
void CreateCopyMap(
const HloModule& module,
- const tensorflow::gtl::FlatMap<const HloValue*, ValueNode*>&
- value_to_node) {
+ const absl::flat_hash_map<const HloValue*, ValueNode*>& value_to_node) {
for (HloComputation* computation : module.computations()) {
for (HloInstruction* instruction : computation->instructions()) {
// Add copies with unambiguous source values to the map. Copies with
@@ -905,7 +904,7 @@ class CopyRemover {
// The heads of all the value lists. Each value list represents the HLO
// values contained in a particular HLO buffer. The values in the list are
// in dependency order.
- tensorflow::gtl::FlatSet<const ValueNode*> value_lists_;
+ absl::flat_hash_set<const ValueNode*> value_lists_;
// Copy removal requires fast access to the value list elements
// corresponding to the source and destination values of the kCopy
@@ -916,7 +915,7 @@ class CopyRemover {
ValueNode* src = nullptr;
ValueNode* dest = nullptr;
};
- tensorflow::gtl::FlatMap<const HloInstruction*, CopyNodes> copy_map_;
+ absl::flat_hash_map<const HloInstruction*, CopyNodes> copy_map_;
};
HloModule* module_;
@@ -1010,7 +1009,7 @@ Status CopyInsertion::AddSpecialCaseCopies(const CallGraph& call_graph,
HloInstruction* root = computation->root_instruction();
// Mark nondistinct/ambiguous indices.
- tensorflow::gtl::FlatSet<const HloBuffer*> seen;
+ absl::flat_hash_set<const HloBuffer*> seen;
ShapeUtil::ForEachSubshape(
root->shape(), [&](const Shape& /*subshape*/, const ShapeIndex& index) {
std::vector<const HloBuffer*> buffers_at_index =
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index b7103118ac..ae4c6e962d 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -290,6 +290,8 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
"//tensorflow/compiler/xla/service/llvm_ir:tuple_ops",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/types:span",
@@ -309,6 +311,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@llvm//:analysis",
"@llvm//:target",
],
@@ -471,6 +474,7 @@ cc_library(
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
],
@@ -762,6 +766,7 @@ cc_library(
"//tensorflow/compiler/xla/service:computation_layout",
"//tensorflow/compiler/xla/service:layout_assignment",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
],
)
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
index bfecbd6e01..c291bf2d1b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <numeric>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
@@ -38,7 +39,7 @@ using absl::nullopt;
using absl::optional;
using ShouldMakeOperandColMajorCache =
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>;
+ absl::flat_hash_map<const HloInstruction*, bool>;
} // namespace
static bool ShouldMakeAllUsersColMajor(const HloInstruction* instruction) {
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
index 20cf855735..a9febe891b 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_runtime.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <functional>
+#include "absl/container/flat_hash_map.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/core/platform/dynamic_annotations.h"
@@ -30,8 +31,7 @@ namespace cpu {
namespace runtime {
XfeedManager* GetXfeedManager(int device_ordinal) {
- static tensorflow::gtl::FlatMap<int, XfeedManager*>* managers =
- new tensorflow::gtl::FlatMap<int, XfeedManager*>();
+ static auto* managers = new absl::flat_hash_map<int, XfeedManager*>();
static absl::Mutex* mutex = new absl::Mutex();
absl::MutexLock lock(mutex);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index c3e8020783..a70abb117a 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
@@ -67,8 +69,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -1398,10 +1398,10 @@ static bool ReductionPreservesLayout(const HloInstruction& reduce) {
//
// So if we reduce f32[A,B,C,D] on dimensions 1 and 2, this map contains
// [0->0, 3->1].
- gtl::FlatMap<int64, int64> unreduced_dim_map;
+ absl::flat_hash_map<int64, int64> unreduced_dim_map;
- gtl::FlatSet<int64> reduced_dims(reduce.dimensions().begin(),
- reduce.dimensions().end());
+ absl::flat_hash_set<int64> reduced_dims(reduce.dimensions().begin(),
+ reduce.dimensions().end());
const Shape& operand_shape = reduce.operand(0)->shape();
const Shape& result_shape = reduce.shape();
@@ -1977,7 +1977,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
//
// * Implement the memcpy within the innermost loop.
- gtl::FlatSet<int64> inner_dims;
+ absl::flat_hash_set<int64> inner_dims;
for (int64 dim : LayoutUtil::MinorToMajor(layout)) {
if (operand->shape().dimensions(dim) != slice->shape().dimensions(dim)) {
break;
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index daafef4eb3..586f27b104 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -23,6 +23,7 @@ limitations under the License.
#include <unordered_map>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/ADT/Triple.h"
@@ -47,7 +48,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -427,7 +427,7 @@ class IrEmitter : public DfsHloVisitorWithDefault,
// Maps the buffer allocation slices for the parameters to the computation
// being compiled to their parameter numbers. Only relevant for thread local
// computations.
- tensorflow::gtl::FlatMap<BufferAllocation::Index, int64>
+ absl::flat_hash_map<BufferAllocation::Index, int64>
computation_parameter_allocations_;
// Maps HLO instructions to their index into the profile counter array.
@@ -567,11 +567,11 @@ class IrEmitter : public DfsHloVisitorWithDefault,
}
};
- tensorflow::gtl::FlatMap<const Literal*, llvm::Constant*,
- LiteralPtrHashFunctor, LiteralPtrEqualityFunctor>
+ absl::flat_hash_map<const Literal*, llvm::Constant*, LiteralPtrHashFunctor,
+ LiteralPtrEqualityFunctor>
emitted_literals_;
- tensorflow::gtl::FlatMap<BufferAllocation::Index, llvm::Constant*>
+ absl::flat_hash_map<BufferAllocation::Index, llvm::Constant*>
constant_buffer_to_global_;
std::vector<const HloComputation*> thread_local_computations_;
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
index a0cd8ee2d2..5cdac203af 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.cc
@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/target_machine_features.h"
+#include "tensorflow/core/platform/logging.h"
namespace xla {
namespace cpu {
diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
index 8b00ae9e47..a383b4a4a0 100644
--- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h
+++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h
@@ -16,10 +16,10 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_TARGET_MACHINE_FEATURES_H_
+#include "absl/container/flat_hash_map.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/primitive_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace cpu {
@@ -97,8 +97,7 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures {
// This is mutated from within `GetTargetTransformInfoFor` which is
// semantically a getter (and thus `const`); and is therefore declared
// mutable. Making this mutable is okay because it has cache semantics.
- mutable tensorflow::gtl::FlatMap<const llvm::Function*,
- llvm::TargetTransformInfo>
+ mutable absl::flat_hash_map<const llvm::Function*, llvm::TargetTransformInfo>
target_transform_info_cache_;
llvm::TargetMachine* target_machine_;
};
diff --git a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
index 7af51db55a..b35fd9dad8 100644
--- a/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/tests/cpu_noalias_test.cc
@@ -121,7 +121,7 @@ TEST_F(CpuNoAliasTest, Concat) {
CHECK: %read_concat2_array = load {{.*}} !alias.scope [[concat1_noalias]], !noalias [[concat1_scope]]
CHECK-DAG: [[buf_size32:![0-9]+]] = !{!"buffer:{{.*}} size:32
CHECK-DAG: [[buf_size48:![0-9]+]] = !{!"buffer:{{.*}} size:48
- CHECK-DAG: [[param_x_noalias]] = !{[[buf_size32]], [[buf_size48]]}
+ CHECK-DAG: [[param_x_noalias]] = !{[[buf_size48]], [[buf_size32]]}
CHECK-DAG: [[concat1_scope]] = !{[[buf_size32]]}
CHECK-DAG: [[concat1_noalias]] = !{[[buf_size48]]}
)";
diff --git a/tensorflow/compiler/xla/service/defuser.cc b/tensorflow/compiler/xla/service/defuser.cc
index d124f74d19..661539cccb 100644
--- a/tensorflow/compiler/xla/service/defuser.cc
+++ b/tensorflow/compiler/xla/service/defuser.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -48,7 +49,7 @@ Status Defuse(HloInstruction* fusion_instruction) {
fusion_instruction->fused_instructions_computation();
// A map from fused instruction to its defused clone.
- tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*>
defused_instructions;
// Initialize map to contain the fusion instruction parameters mapping
// to the operands of the fusion instruction.
diff --git a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
index 5761573791..68d01d75a2 100644
--- a/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
+++ b/tensorflow/compiler/xla/service/dfs_hlo_visitor.h
@@ -27,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.cc b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
index bec02e14f9..f92fde7f46 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.cc
@@ -98,7 +98,7 @@ Status GenericTransferManager::TransferLiteralFromDeviceInternal(
Status GenericTransferManager::TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) {
+ const ShapedBuffer& device_buffer, TransferToDeviceHint /*hint*/) {
const Shape& shape = literal.shape();
VLOG(2) << "transferring literal shape to device: "
<< ShapeUtil::HumanString(shape)
diff --git a/tensorflow/compiler/xla/service/generic_transfer_manager.h b/tensorflow/compiler/xla/service/generic_transfer_manager.h
index 86c8b1c145..b1cba82b9f 100644
--- a/tensorflow/compiler/xla/service/generic_transfer_manager.h
+++ b/tensorflow/compiler/xla/service/generic_transfer_manager.h
@@ -45,9 +45,10 @@ class GenericTransferManager : public TransferManager {
MutableBorrowingLiteral literal,
std::function<void(Status)> done) override;
- Status TransferLiteralToDeviceAsync(
- se::Stream* stream, const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) override;
+ Status TransferLiteralToDeviceAsync(se::Stream* stream,
+ const LiteralSlice& literal,
+ const ShapedBuffer& device_buffer,
+ TransferToDeviceHint hint) override;
Status TransferLiteralToInfeed(se::StreamExecutor* executor,
const LiteralSlice& literal) override;
diff --git a/tensorflow/compiler/xla/service/gpu/BUILD b/tensorflow/compiler/xla/service/gpu/BUILD
index 51968d13d4..522e9f5948 100644
--- a/tensorflow/compiler/xla/service/gpu/BUILD
+++ b/tensorflow/compiler/xla/service/gpu/BUILD
@@ -91,6 +91,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_reachability",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
],
)
@@ -357,6 +358,7 @@ cc_library(
"//tensorflow/core/platform/default/build_config:cufft_plugin",
"//tensorflow/core/platform/default/build_config:stream_executor_cuda", # build_cleaner: keep
"//tensorflow/stream_executor",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
@@ -474,6 +476,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:instruction_fusion",
"//tensorflow/compiler/xla/service:pattern_matcher",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -506,6 +509,7 @@ cc_library(
"//tensorflow/compiler/xla/service:multi_output_fusion",
"//tensorflow/core:lib",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -539,6 +543,7 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo_dataflow_analysis",
"//tensorflow/compiler/xla/service:hlo_pass",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
@@ -713,6 +718,7 @@ cc_library(
"//tensorflow/core:lib_internal",
"//tensorflow/core:regexp_internal",
"//tensorflow/core:stream_executor_no_cuda",
+ "@com_google_absl//absl/container:node_hash_map",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:optional",
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
index 79c74e7e8b..e2ab00ce41 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_copy_insertion.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <set>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/copy_insertion.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
@@ -27,7 +28,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 31a9f9b1be..5742632782 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
@@ -197,7 +198,7 @@ GpuExecutable::ResolveConstantGlobals(se::StreamExecutor* executor) {
}
module_spec.AddCudaPtxInMemory(ptx().c_str());
- tensorflow::gtl::FlatMap<int64, se::DeviceMemoryBase> globals;
+ absl::flat_hash_map<int64, se::DeviceMemoryBase> globals;
se::ModuleHandle module_handle;
executor->LoadModule(module_spec, &module_handle);
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 38b0f8f15b..0e276282e4 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
@@ -35,7 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
@@ -101,7 +101,7 @@ class GpuExecutable : public Executable {
const PointsToSet& GetRootPointsToSet() const;
using BufferAllocToDeviceMemoryMap =
- tensorflow::gtl::FlatMap<BufferAllocation::Index, se::DeviceMemoryBase>;
+ absl::flat_hash_map<BufferAllocation::Index, se::DeviceMemoryBase>;
// Loads the PTX or CUBIN for this executable into `executor` and resolves the
// globals corresponding to constant buffers. Returns a map mapping buffer
diff --git a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
index 4d5d8e99f8..b61f038739 100644
--- a/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/instruction_fusion.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -125,8 +126,8 @@ bool IsIEEEFloatingPointScalarConstant(const HloInstruction* constant) {
}
// Compute the precise number of operands to the new fusion.
- tensorflow::gtl::FlatSet<const HloInstruction*> operands(
- a->operands().begin(), a->operands().end());
+ absl::flat_hash_set<const HloInstruction*> operands(a->operands().begin(),
+ a->operands().end());
operands.insert(b->operands().begin(), b->operands().end());
// If there's an edge between `a` and `b`, don't count it: We're fusing that
// producer -> consumer relationship.
diff --git a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
index c21f76f6eb..835924024b 100644
--- a/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/gpu/multi_output_fusion.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/service/gpu/gpu_fusible.h"
#include "tensorflow/compiler/xla/service/gpu/instruction_fusion.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -101,7 +101,7 @@ bool GpuMultiOutputFusion::IsFusible(HloInstruction* instr) {
int64 GpuMultiOutputFusion::GetProfit(HloInstruction* instr1,
HloInstruction* instr2) {
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
for (auto instr : instr1->operands()) {
if (!IsProfitableOperand(instr)) {
continue;
@@ -148,7 +148,7 @@ bool GpuMultiOutputFusion::DoProducerConsumerMultiOutputFusion() {
bool changed = false;
RecomputeReachability();
- tensorflow::gtl::FlatSet<HloInstruction*> to_fuse;
+ absl::flat_hash_set<HloInstruction*> to_fuse;
// Keep a list of the instructions to fuse after making all the fusion
// decisions. We first aggressively add instructions to potential_fusion_list,
// then filter out instructions that will be no longer fusible because of
diff --git a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
index 8e97774750..c4a0b727cd 100644
--- a/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
+++ b/tensorflow/compiler/xla/service/gpu/nvptx_compiler.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/node_hash_map.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/executable.h"
@@ -140,10 +141,10 @@ class NVPTXCompiler : public LLVMCompiler {
tensorflow::condition_variable compilation_done_cv_;
};
- // Don't even think about switching this to FlatMap; iterator stability is
- // critical here.
- std::unordered_map<CompilationCacheKey, CompilationCacheValue,
- CompilationCacheHash, CompilationCacheEq>
+ // Don't even think about switching this to flat_hash_map; iterator stability
+ // is critical here.
+ absl::node_hash_map<CompilationCacheKey, CompilationCacheValue,
+ CompilationCacheHash, CompilationCacheEq>
compilation_cache_ GUARDED_BY(mutex_);
TF_DISALLOW_COPY_AND_ASSIGN(NVPTXCompiler);
diff --git a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
index cf9f102d31..375f68a159 100644
--- a/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
+++ b/tensorflow/compiler/xla/service/gpu/partition_assignment.cc
@@ -62,13 +62,8 @@ LaunchDimensions CalculateLaunchDimensions(
//
// <num threads per block> * <max blocks per core> = <max threads per core>
- auto threads_per_core = device_desc.threads_per_core_limit();
- auto blocks_per_core = device_desc.blocks_per_core_limit();
- int64 threads_per_block;
- if (threads_per_core != 0 && blocks_per_core != 0) {
- threads_per_block = device_desc.threads_per_core_limit() /
- device_desc.blocks_per_core_limit();
- } else {
+ int64 threads_per_block = device_desc.threads_per_block_limit();
+ if (threads_per_block == 0) {
static std::atomic<int64> log_count{0};
if (log_count.fetch_add(1) < 8) {
LOG(WARNING) << "Attempting to calculate launch dimensions for GPU "
diff --git a/tensorflow/compiler/xla/service/gpu/stream_assignment.h b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
index c2df83aaa4..52d38b6f20 100644
--- a/tensorflow/compiler/xla/service/gpu/stream_assignment.h
+++ b/tensorflow/compiler/xla/service/gpu/stream_assignment.h
@@ -16,9 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_STREAM_ASSIGNMENT_H_
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace gpu {
@@ -34,7 +34,7 @@ class StreamAssignment {
private:
int stream_count_ = 1; // At least the main stream.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> hlo_to_stream_number_;
+ absl::flat_hash_map<const HloInstruction*, int> hlo_to_stream_number_;
};
// Assigns GPU streams to instructions in `module`.
diff --git a/tensorflow/compiler/xla/service/heap_simulator.cc b/tensorflow/compiler/xla/service/heap_simulator.cc
index 2bd04259c0..b343305554 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator.cc
@@ -18,14 +18,16 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/util.h"
namespace xla {
-using tensorflow::gtl::FlatMap;
-using tensorflow::gtl::FlatSet;
+using absl::flat_hash_map;
+using absl::flat_hash_set;
/*static*/
StatusOr<int64> HeapSimulator::MinimumMemoryForModule(
@@ -56,7 +58,7 @@ StatusOr<int64> HeapSimulator::MinimumMemoryForComputation(
const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation) {
TF_ASSIGN_OR_RETURN(
HeapSimulator::Result result,
@@ -88,7 +90,7 @@ StatusOr<HeapSimulator::Result> HeapSimulator::Run(
const HloInstructionSequence& instruction_sequence,
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn, const Options& options,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation) {
HeapSimulator heap(std::move(algorithm), size_fn, options,
/*schedule=*/nullptr, memory_by_computation);
@@ -115,8 +117,10 @@ Status HeapSimulator::RunComputation(
// 'used_buffers' is the reverse map - it tracks which buffers were used by an
// instruction, so that we can remove the instructions from a buffer's live
// set after they are visited.
- FlatMap<const BufferValue*, FlatSet<const HloInstruction*>> live_buffers;
- FlatMap<const HloInstruction*, FlatSet<const BufferValue*>> used_buffers;
+ flat_hash_map<const BufferValue*, flat_hash_set<const HloInstruction*>>
+ live_buffers;
+ flat_hash_map<const HloInstruction*, flat_hash_set<const BufferValue*>>
+ used_buffers;
auto add_user_to_buffer = [this, &live_buffers, &used_buffers](
const HloInstruction* user,
const BufferValue* buffer) {
@@ -213,7 +217,7 @@ Status HeapSimulator::RunComputation(
VLOG(4) << " Removing user " << instruction->name() << " from buffer "
<< operand_buffer->ToString();
auto it = live_buffers.find(operand_buffer);
- FlatSet<const HloInstruction*>* live_set = &it->second;
+ flat_hash_set<const HloInstruction*>* live_set = &it->second;
live_set->erase(instruction);
if (live_set->empty()) {
live_buffers.erase(it);
@@ -235,7 +239,7 @@ Status HeapSimulator::RunComputation(
// that we should assign.
// Make sure each buffer get reused at most once.
- FlatSet<const BufferValue*> reused_buffers;
+ flat_hash_set<const BufferValue*> reused_buffers;
for (const BufferValue* buffer : buffers_defined_by_instruction) {
if (IgnoreBuffer(buffer)) {
continue;
@@ -323,7 +327,7 @@ Status HeapSimulator::RunComputation(
to_free.reserve(live_buffers.size());
for (const auto& buffer_pending : live_buffers) {
const BufferValue* buffer = buffer_pending.first;
- const FlatSet<const HloInstruction*>& pending = buffer_pending.second;
+ const flat_hash_set<const HloInstruction*>& pending = buffer_pending.second;
CHECK_EQ(pending.size(), 1) << *buffer;
CHECK(*pending.begin() == nullptr) << *buffer;
to_free.push_back(buffer);
@@ -345,7 +349,7 @@ HeapSimulator::HeapSimulator(
std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn, const Options& options,
const HloSchedule* schedule,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation)
: no_fragmentation_stats_(absl::make_unique<NoFragmentationStatsHeap>()),
algorithm_(std::move(algorithm)),
@@ -536,7 +540,7 @@ void NoFragmentationStatsHeap::Alloc(const BufferValue* buffer, int64 size,
void NoFragmentationStatsHeap::AccountForSubcomputationMemory(
const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// We only count the memory usage of the largest subcomputation, instead of
// adding them all, because subcomputations won't execute in parallel.
diff --git a/tensorflow/compiler/xla/service/heap_simulator.h b/tensorflow/compiler/xla/service/heap_simulator.h
index 7d6dcc0dc9..b0295a6163 100644
--- a/tensorflow/compiler/xla/service/heap_simulator.h
+++ b/tensorflow/compiler/xla/service/heap_simulator.h
@@ -21,6 +21,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
#include "tensorflow/compiler/xla/service/buffer_value_containers.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
@@ -30,8 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -58,7 +58,7 @@ class HeapSimulator {
// Result represents the result of the heap simulation.
struct Result {
// The assignment of buffers to chunks.
- tensorflow::gtl::FlatMap<const BufferValue*, Chunk> chunk_map;
+ absl::flat_hash_map<const BufferValue*, Chunk> chunk_map;
// The total size in bytes of the heap, containing all assigned chunks.
int64 heap_size = 0;
@@ -100,7 +100,7 @@ class HeapSimulator {
const HloComputation& computation, const HloInstructionSequence& sequence,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
// Run the heap simulation with the given algorithm, assuming the given
@@ -130,7 +130,7 @@ class HeapSimulator {
const TuplePointsToAnalysis& points_to_analysis,
const BufferValue::SizeFunction& size_fn,
const Options& options = Options(),
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
private:
@@ -140,7 +140,7 @@ class HeapSimulator {
HeapSimulator(std::unique_ptr<HeapAlgorithm> algorithm,
const BufferValue::SizeFunction& size_fn,
const Options& options, const HloSchedule* schedule = nullptr,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation = nullptr);
~HeapSimulator();
@@ -172,7 +172,7 @@ class HeapSimulator {
// handle subcomputations. It would be good to unify the handling of
// subcomputations, but it's not clear how.
const HloSchedule* schedule_;
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>*
+ const absl::flat_hash_map<const HloComputation*, int64>*
memory_by_computation_;
// In addition to Alloc and Free, the heap simulator exposes a concept of
@@ -193,12 +193,12 @@ class HeapSimulator {
const BufferValue* canonical = nullptr;
int64 refcount = 0;
};
- tensorflow::gtl::FlatMap<const BufferValue*, std::shared_ptr<SharedGroup>>
+ absl::flat_hash_map<const BufferValue*, std::shared_ptr<SharedGroup>>
shared_buffers_;
// Hold some sets for error-checking the sequence of Alloc and Free calls.
- tensorflow::gtl::FlatSet<const BufferValue*> allocated_buffers_;
- tensorflow::gtl::FlatSet<const BufferValue*> freed_buffers_;
+ absl::flat_hash_set<const BufferValue*> allocated_buffers_;
+ absl::flat_hash_set<const BufferValue*> freed_buffers_;
// Debugging information filled in while the heap simulator runs.
HeapSimulatorTrace debug_trace_;
@@ -235,7 +235,7 @@ class HeapAlgorithm {
// analysis, it's not worth making major changes to HeapSimulator now.
virtual void AccountForSubcomputationMemory(
const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {}
// Free de-allocates a previously allocated buffer.
@@ -262,7 +262,7 @@ class NoFragmentationStatsHeap : public HeapAlgorithm {
void AccountForSubcomputationMemory(
const HloInstruction* instruction,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) override;
void Free(const BufferValue* buffer, int64 size) override;
@@ -382,8 +382,7 @@ class GlobalDecreasingSizeBestFitHeap : public HeapAlgorithm {
// Free time of the buffer.
int64 end;
};
- tensorflow::gtl::FlatMap<const BufferValue*, BufferInterval>
- buffer_intervals_;
+ absl::flat_hash_map<const BufferValue*, BufferInterval> buffer_intervals_;
};
// A heap algorithm that chooses the best results from other algorithms added to
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 191fbf8194..ea0bced923 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/service/buffer_value.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/tests/hlo_verified_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -174,7 +174,7 @@ class HeapSimulatorTracker {
// Construct the module sequence grouped by computation.
HloSchedule schedule(module_.get());
- tensorflow::gtl::FlatMap<const HloInstruction*, int> reverse_position;
+ absl::flat_hash_map<const HloInstruction*, int> reverse_position;
for (int i = 0; i < full_module_sequence.size(); ++i) {
const HloInstruction* instruction = full_module_sequence[i];
schedule.GetOrCreateSequence(instruction->parent())
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index caaca16f71..1ea26ddd5b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -34,7 +34,7 @@ import "tensorflow/compiler/xla/xla_data.proto";
option cc_enable_arenas = true;
// Serialization of HloInstruction.
-// Next ID: 54
+// Next ID: 56
message HloInstructionProto {
reserved 10;
reserved "parameter_name";
@@ -180,6 +180,10 @@ message HloInstructionProto {
// Collective permute field.
repeated SourceTarget source_target_pairs = 52;
+
+ // Sharding for kDomain instructions.
+ xla.OpSharding domain_entry_sharding = 54;
+ xla.OpSharding domain_exit_sharding = 55;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
index 0986da65cb..c3da12e273 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -119,7 +121,7 @@ class BufferValueMap {
}
// Return a set of all the values in the given buffer.
- const tensorflow::gtl::FlatSet<const HloValue*>& GetValuesInBuffer(
+ const absl::flat_hash_set<const HloValue*>& GetValuesInBuffer(
BufferNumber buffer_number) const {
return buffers_.at(buffer_number);
}
@@ -142,7 +144,7 @@ class BufferValueMap {
// Move the given value into the given buffer.
void MoveValueToBuffer(const HloValue& value, BufferNumber buffer_number) {
BufferNumber old_buffer_number = value_to_buffer_number_.at(&value);
- tensorflow::gtl::FlatSet<const HloValue*>& old_value_set =
+ absl::flat_hash_set<const HloValue*>& old_value_set =
buffers_.at(old_buffer_number);
old_value_set.erase(&value);
if (old_value_set.empty()) {
@@ -290,13 +292,11 @@ class BufferValueMap {
const HloDataflowAnalysis& dataflow_;
// A map containing the set of values contained in each buffer.
- tensorflow::gtl::FlatMap<BufferNumber,
- tensorflow::gtl::FlatSet<const HloValue*>>
+ absl::flat_hash_map<BufferNumber, absl::flat_hash_set<const HloValue*>>
buffers_;
// A map indicating which buffer each value is contained in.
- tensorflow::gtl::FlatMap<const HloValue*, BufferNumber>
- value_to_buffer_number_;
+ absl::flat_hash_map<const HloValue*, BufferNumber> value_to_buffer_number_;
// The buffer number of the next buffer to be created.
BufferNumber next_buffer_number_ = 0;
@@ -352,7 +352,7 @@ bool HloAliasAnalysis::InstructionBuffersAreAmbiguous(
bool HloAliasAnalysis::InstructionBuffersAreDistinct(
const HloInstruction* instruction) const {
- tensorflow::gtl::FlatSet<const HloBuffer*> buffers_seen;
+ absl::flat_hash_set<const HloBuffer*> buffers_seen;
for (const auto& pair :
dataflow_analysis_->GetInstructionValueSet(instruction)) {
const HloValueSet& value_set = pair.second;
diff --git a/tensorflow/compiler/xla/service/hlo_alias_analysis.h b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
index e345804537..372f99ff01 100644
--- a/tensorflow/compiler/xla/service/hlo_alias_analysis.h
+++ b/tensorflow/compiler/xla/service/hlo_alias_analysis.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_buffer.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
@@ -110,7 +111,7 @@ class HloAliasAnalysis {
std::unique_ptr<HloDataflowAnalysis> dataflow_analysis_;
// A map indicating which buffer a value is contained in.
- tensorflow::gtl::FlatMap<const HloValue*, HloBuffer*> value_to_buffer_;
+ absl::flat_hash_map<const HloValue*, HloBuffer*> value_to_buffer_;
// A lazily constructed vector containing all HloBuffers sorted by
// HloBuffer::Id.
diff --git a/tensorflow/compiler/xla/service/hlo_buffer.cc b/tensorflow/compiler/xla/service/hlo_buffer.cc
index 6c11a073b7..9c3aa0e64d 100644
--- a/tensorflow/compiler/xla/service/hlo_buffer.cc
+++ b/tensorflow/compiler/xla/service/hlo_buffer.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/hlo_clone_context.h b/tensorflow/compiler/xla/service/hlo_clone_context.h
index 658643b427..24910ca07b 100644
--- a/tensorflow/compiler/xla/service/hlo_clone_context.h
+++ b/tensorflow/compiler/xla/service/hlo_clone_context.h
@@ -18,8 +18,8 @@ limitations under the License.
#include <string>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -73,12 +73,12 @@ class HloCloneContext {
return FindOrDie(computations_, old_computation);
}
- const tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>&
+ const absl::flat_hash_map<const HloInstruction*, HloInstruction*>&
cloned_instructions() const {
return instructions_;
}
- const tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>&
+ const absl::flat_hash_map<const HloComputation*, HloComputation*>&
cloned_computations() const {
return computations_;
}
@@ -86,10 +86,8 @@ class HloCloneContext {
private:
HloModule* module_;
string suffix_;
- tensorflow::gtl::FlatMap<const HloInstruction*, HloInstruction*>
- instructions_;
- tensorflow::gtl::FlatMap<const HloComputation*, HloComputation*>
- computations_;
+ absl::flat_hash_map<const HloInstruction*, HloInstruction*> instructions_;
+ absl::flat_hash_map<const HloComputation*, HloComputation*> computations_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index 0e5920af7a..c2041c4667 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -24,6 +24,8 @@ limitations under the License.
#include <sstream>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
@@ -39,7 +41,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -122,30 +123,6 @@ HloInstruction* HloComputation::AddParameter(
return instructions_.back().get();
}
-namespace {
-
-// Returns the new name for a fusion parameter when we change its number.
-//
-// Fusion parameters are named foo.param_1, bar.param_2, etc. We are
-// renumbering the parameters, so replace the final number in the name with
-// the updated value.
-string RenameFusionParameter(const string& original_name, int64 new_param_no) {
- const string param_underscore = ".param_";
- size_t index = original_name.rfind(param_underscore);
- if (index == string::npos) {
- return original_name;
- }
- string after_param = original_name.substr(index + param_underscore.size());
- int64 numeric_suffix;
- if (absl::SimpleAtoi(after_param, &numeric_suffix)) {
- return StrCat(original_name.substr(0, index + param_underscore.size()),
- new_param_no);
- }
- return original_name;
-}
-
-} // namespace
-
Status HloComputation::RemoveParameter(int64 param_no) {
CHECK_GE(param_no, 0);
CHECK_LT(param_no, param_instructions_.size());
@@ -158,11 +135,9 @@ Status HloComputation::RemoveParameter(int64 param_no) {
while (param_no < param_instructions_.size()) {
param_instruction = param_instructions_[param_no];
- string param_name =
- RenameFusionParameter(param_instruction->name(), param_no);
HloInstruction* new_instr =
AddInstructionInternal(HloInstruction::CreateParameter(
- param_no, param_instruction->shape(), param_name));
+ param_no, param_instruction->shape(), StrCat("param_", param_no)));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
@@ -186,11 +161,9 @@ Status HloComputation::RemoveUnusedParameters() {
if (removed > 0) {
const int64 param_no = i - removed;
- string param_name =
- RenameFusionParameter(param_instruction->name(), param_no);
- HloInstruction* new_instr =
- AddInstructionInternal(HloInstruction::CreateParameter(
- param_no, param_instruction->shape(), param_name));
+ HloInstruction* new_instr = AddInstructionInternal(
+ HloInstruction::CreateParameter(param_no, param_instruction->shape(),
+ StrCat("param_", param_no)));
TF_RETURN_IF_ERROR(param_instruction->ReplaceAllUsesWith(new_instr));
param_instructions_[param_no] = new_instr;
TF_RETURN_IF_ERROR(RemoveInstruction(param_instruction));
@@ -305,10 +278,9 @@ void HloComputation::set_root_instruction(HloInstruction* new_root_instruction,
namespace {
// Helper which builds a post order of the HLO call graph.
-void ComputeComputationPostOrder(
- HloComputation* computation,
- tensorflow::gtl::FlatSet<HloComputation*>* visited,
- std::vector<HloComputation*>* post_order) {
+void ComputeComputationPostOrder(HloComputation* computation,
+ absl::flat_hash_set<HloComputation*>* visited,
+ std::vector<HloComputation*>* post_order) {
if (visited->insert(computation).second) {
for (auto* instruction : computation->instructions()) {
for (HloComputation* called_computation :
@@ -325,7 +297,7 @@ void ComputeComputationPostOrder(
void HloComputation::ComputeInstructionPostOrder(
const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const {
+ absl::flat_hash_map<HloInstruction*, VisitState>* visited) const {
std::vector<HloInstruction*> dfs_stack;
dfs_stack.push_back(root);
while (!dfs_stack.empty()) {
@@ -422,7 +394,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloInstruction*> post_order;
post_order.reserve(instruction_count());
std::vector<HloInstruction*> trace_instructions;
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState> visited;
+ absl::flat_hash_map<HloInstruction*, VisitState> visited;
for (auto& instruction : instructions_) {
if (instruction->opcode() == HloOpcode::kTrace) {
// Trace instructions aren't handled by the DFS visitor. Add trace
@@ -443,7 +415,7 @@ std::vector<HloInstruction*> HloComputation::MakeInstructionPostOrder() const {
std::vector<HloComputation*> HloComputation::MakeEmbeddedComputationsList()
const {
- tensorflow::gtl::FlatSet<HloComputation*> visited;
+ absl::flat_hash_set<HloComputation*> visited;
std::vector<HloComputation*> post_order;
// To avoid special handling of this computation, cast away const of
@@ -533,9 +505,9 @@ HloComputationProto HloComputation::ToProto() const {
/* static */ StatusOr<std::unique_ptr<HloComputation>>
HloComputation::CreateFromProto(
const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
- tensorflow::gtl::FlatMap<int64, HloInstruction*> instruction_map;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> to_proto_id;
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
+ absl::flat_hash_map<int64, HloInstruction*> instruction_map;
+ absl::flat_hash_map<HloInstruction*, int64> to_proto_id;
std::vector<std::unique_ptr<HloInstruction>> instructions;
int64 parameter_count = 0;
for (const HloInstructionProto& instruction_proto : proto.instructions()) {
@@ -563,6 +535,28 @@ HloComputation::CreateFromProto(
return to_proto_id[a.get()] < to_proto_id[b.get()];
});
+ TF_RETURN_IF_ERROR([&]() -> Status {
+ std::vector<bool> parameters_seen(parameter_count);
+ int parameters_seen_count = 0;
+ for (auto& instruction : instructions) {
+ if (instruction->opcode() == HloOpcode::kParameter) {
+ int64 param_no = instruction->parameter_number();
+ TF_RET_CHECK(param_no >= 0 && param_no < parameter_count)
+ << "Invalid parameter number. Expected [0, " << parameter_count
+ << "), got " << param_no;
+ TF_RET_CHECK(!parameters_seen[param_no])
+ << "Parameter number " << param_no
+ << " already allocated in this computation";
+ parameters_seen[param_no] = true;
+ parameters_seen_count++;
+ }
+ }
+ TF_RET_CHECK(parameters_seen_count == parameter_count)
+ << "Not all parameters in range [0, " << parameter_count
+ << ") were referenced";
+ return Status::OK();
+ }());
+
auto computation = absl::WrapUnique(
new HloComputation(proto.name(), parameter_count, &instructions, root,
/*fusion_instruction=*/nullptr));
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 936a53bd7e..d87ab4bda1 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -25,6 +25,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -40,8 +42,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
@@ -188,7 +188,7 @@ class HloComputation {
// calls.
static StatusOr<std::unique_ptr<HloComputation>> CreateFromProto(
const HloComputationProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map);
// Gets the instructions in this computation.
//
@@ -414,14 +414,14 @@ class HloComputation {
// cross-replica-sum the union of the dependencies for all participating
// instructions.
using ChannelDependencyMap =
- tensorflow::gtl::FlatMap<int64, absl::InlinedVector<HloInstruction*, 1>>;
+ absl::flat_hash_map<int64, absl::InlinedVector<HloInstruction*, 1>>;
ChannelDependencyMap ComputeChannelDependencies() const;
enum VisitState { kVisiting, kVisited };
void ComputeInstructionPostOrder(
const HloComputation::ChannelDependencyMap& channel_dependency_map,
std::vector<HloInstruction*>* post_order, HloInstruction* root,
- tensorflow::gtl::FlatMap<HloInstruction*, VisitState>* visited) const;
+ absl::flat_hash_map<HloInstruction*, VisitState>* visited) const;
string name_;
int64 unique_id_;
@@ -439,7 +439,7 @@ class HloComputation {
// instruction pointer to location in the list for fast lookup.
using InstructionList = std::list<std::unique_ptr<HloInstruction>>;
InstructionList instructions_;
- tensorflow::gtl::FlatMap<const HloInstruction*, InstructionList::iterator>
+ absl::flat_hash_map<const HloInstruction*, InstructionList::iterator>
instruction_iterators_;
std::vector<HloInstruction*> param_instructions_;
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index b59c9ba3ed..e602107cbe 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal.h"
@@ -34,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/hash/hash.h"
namespace xla {
@@ -137,8 +137,8 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
// HLO instructions are grouped into equivalency classes by using the
// cse_equal predicate defined above. This set holds a representative
// instruction for each class.
- tensorflow::gtl::FlatSet<HloInstruction*, decltype(&CseHash),
- decltype(cse_equal)>
+ absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
+ decltype(cse_equal)>
representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
cse_equal);
for (auto instruction : computation->MakeInstructionPostOrder()) {
diff --git a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
index 6a63681996..44cde4a3d2 100644
--- a/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_dataflow_analysis.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -91,7 +92,7 @@ HloDataflowAnalysis::HloDataflowAnalysis(
bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
const HloInstruction* inst) {
- tensorflow::gtl::FlatSet<const HloInstruction*> visited;
+ absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(inst);
while (!stack.empty()) {
@@ -159,8 +160,8 @@ void HloDataflowAnalysis::MarkValueForDeletion(HloValue::Id value_id) {
void HloDataflowAnalysis::DeleteMarkedValues() {
#ifndef NDEBUG
// Verify that no marked-for-deletion values are in any of the value sets.
- tensorflow::gtl::FlatSet<HloValue::Id> id_set(value_ids_to_delete_.begin(),
- value_ids_to_delete_.end());
+ absl::flat_hash_set<HloValue::Id> id_set(value_ids_to_delete_.begin(),
+ value_ids_to_delete_.end());
for (const auto& pair : value_sets_) {
const HloInstruction* instruction = pair.first;
const InstructionValueSet& instruction_value_set = pair.second;
@@ -673,7 +674,7 @@ bool HloDataflowAnalysis::UpdateInstructionValueSet(
void HloDataflowAnalysis::Propagate() {
std::queue<HloInstruction*> worklist;
- tensorflow::gtl::FlatSet<HloInstruction*> workset;
+ absl::flat_hash_set<HloInstruction*> workset;
auto add_to_worklist = [&worklist, &workset](HloInstruction* instruction) {
if (workset.insert(instruction).second) {
worklist.push(instruction);
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.cc b/tensorflow/compiler/xla/service/hlo_domain_map.cc
index 113fd18eae..6ca1255ede 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.cc
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <algorithm>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
@@ -106,8 +108,8 @@ Status HloDomainMap::PopulateDomainMetadataMap() {
auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
return a->Matches(*b);
};
- tensorflow::gtl::FlatMap<const DomainMetadata*, int64, decltype(hash),
- decltype(equal)>
+ absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash),
+ decltype(equal)>
domain_metadata(1024, hash, equal);
for (auto& domain : instruction_domains_) {
@@ -216,7 +218,7 @@ bool HloDomainMap::IsDomainInstruction(HloInstruction* instruction) const {
/* static */ std::vector<HloInstruction*>
HloDomainMap::MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order) {
std::vector<HloInstruction*> instructions;
instructions.reserve(instruction_set.size());
diff --git a/tensorflow/compiler/xla/service/hlo_domain_map.h b/tensorflow/compiler/xla/service/hlo_domain_map.h
index 56b557d7ce..c8d581b746 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_map.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_map.h
@@ -19,14 +19,14 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -77,8 +77,7 @@ class HloDomainMap {
private:
// Map used for representing instruction ordering, i.e.
// order_map[a] < order_map[b] means a must be ordered before b.
- using InstructionOrderMap =
- tensorflow::gtl::FlatMap<const HloInstruction*, int64>;
+ using InstructionOrderMap = absl::flat_hash_map<const HloInstruction*, int64>;
HloDomainMap(string domain_kind) : domain_kind_(std::move(domain_kind)) {}
@@ -111,7 +110,7 @@ class HloDomainMap {
// Out of an instruction set, returns a vector of all the ones which are not
// a kDomain kind.
static std::vector<HloInstruction*> MakeNonDomainInstructions(
- const tensorflow::gtl::FlatSet<HloInstruction*>& instruction_set,
+ const absl::flat_hash_set<HloInstruction*>& instruction_set,
const InstructionOrderMap& instructions_order);
// Populates domain_metadata_id_ that maps each HloInstruction to the unique
@@ -120,8 +119,8 @@ class HloDomainMap {
string domain_kind_;
std::vector<std::unique_ptr<DomainMetadata::Domain>> instruction_domains_;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> instruction_to_domain_;
- tensorflow::gtl::FlatMap<HloInstruction*, int64> domain_metadata_id_;
+ absl::flat_hash_map<HloInstruction*, int64> instruction_to_domain_;
+ absl::flat_hash_map<HloInstruction*, int64> domain_metadata_id_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_domain_metadata.h b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
index 302807f816..d3c83c15ae 100644
--- a/tensorflow/compiler/xla/service/hlo_domain_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_domain_metadata.h
@@ -20,11 +20,11 @@ limitations under the License.
#include <string>
#include <vector>
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
@@ -42,7 +42,7 @@ class DomainMetadata {
// operand/user pathways, without crossing a kDomain instruction of a given
// kind. The reach_set can contain kDomain instructions of other kinds, if
// two domains of different kind intersect each other.
- tensorflow::gtl::FlatSet<HloInstruction*> reach_set;
+ absl::flat_hash_set<HloInstruction*> reach_set;
// The same instructions in reach_set, but purged from kDomain instructions
// and ordered according to their computation graph post-order, i.e.
@@ -55,8 +55,8 @@ class DomainMetadata {
// whose dataflow enters the reach set (domain), while the exit_domains
// contains the set of kDomain instructions whose dataflow exit the reach
// set.
- tensorflow::gtl::FlatSet<HloInstruction*> enter_domains;
- tensorflow::gtl::FlatSet<HloInstruction*> exit_domains;
+ absl::flat_hash_set<HloInstruction*> enter_domains;
+ absl::flat_hash_set<HloInstruction*> exit_domains;
};
virtual ~DomainMetadata() = default;
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.cc b/tensorflow/compiler/xla/service/hlo_evaluator.cc
index d7c39b2778..eec8d242fa 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.cc
@@ -1378,7 +1378,7 @@ Status HloEvaluator::HandleReduce(HloInstruction* reduce) {
"unsupported");
}
}
- return reduce->Visit(typed_visitors_.at(first_element_type).get());
+ return reduce->Visit(typed_visitors_[first_element_type].get());
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator.h b/tensorflow/compiler/xla/service/hlo_evaluator.h
index 6c2662ebae..07f8d0aad4 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator.h
+++ b/tensorflow/compiler/xla/service/hlo_evaluator.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
+#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -134,7 +134,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// Wraps around instruction handling to infer types before dispatching to
// the corresponding typed Visitor.
Status DefaultAction(HloInstruction* hlo) override {
- return hlo->Visit(typed_visitors_.at(hlo->shape().element_type()).get());
+ return hlo->Visit(typed_visitors_[hlo->shape().element_type()].get());
}
Status Preprocess(HloInstruction* hlo) override;
@@ -210,8 +210,8 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
// post-orderring.
// Must be cleared for each evaluation.
// Storing Literal in place require the container to have pointer stability so
- // we cannot use FlatMap any more.
- std::unordered_map<const HloInstruction*, Literal> evaluated_;
+ // we cannot use flat_hash_map any more.
+ absl::node_hash_map<const HloInstruction*, Literal> evaluated_;
private:
template <typename ReturnT, typename NativeT>
@@ -241,12 +241,7 @@ class HloEvaluator : public DfsHloVisitorWithDefault {
}
// Map from a primitive type to its associated (templated) DfsHloVisitor.
- // Note: the hash function here is only needed because current gcc std::hash
- // does not specialize for enum types. This should however be fixed in the
- // future: https://gcc.gnu.org/bugzilla/show_bug.cgi?id=60970#c5
- tensorflow::gtl::FlatMap<PrimitiveType, std::unique_ptr<DfsHloVisitor>,
- std::hash<int>>
- typed_visitors_;
+ std::unique_ptr<DfsHloVisitor> typed_visitors_[PrimitiveType_ARRAYSIZE];
// Caches pointers to input literals, assuming they are in post-order.
// Literals are not owned by this class, and they must outlive the lifetime of
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index de3d7a1677..ce4cad4235 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -90,8 +90,9 @@ std::unique_ptr<HloProfilePrinterData> CreateHloProfilePrinterData(
HloInstructionInfo* instruction_info =
computation_info->add_instruction_infos();
instruction_info->set_long_name(hlo->ToString());
- instruction_info->set_short_name(
- hlo->ToString(HloPrintOptions().set_compact_operands(true)));
+ instruction_info->set_short_name(hlo->ToString(
+ HloPrintOptions().set_compact_operands(true).set_print_operand_names(
+ false)));
instruction_info->set_category(hlo->ToCategory());
instruction_info->set_flop_count(cost_analysis.flop_count(*hlo));
instruction_info->set_transcendental_count(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index 23787dbc8a..8bddaa8c96 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -22,6 +22,8 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/ascii.h"
@@ -37,14 +39,13 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/human_readable_json.h"
#include "tensorflow/core/platform/logging.h"
@@ -59,8 +60,8 @@ using absl::StrJoin;
/* static */
StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map) {
+ const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map) {
TF_RET_CHECK(!proto.opcode().empty());
TF_ASSIGN_OR_RETURN(HloOpcode opcode, StringToHloOpcode(proto.opcode()));
TF_RET_CHECK(proto.has_shape());
@@ -80,6 +81,20 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
const auto computations = [&computation_map, &proto](int index) {
return computation_map.at(proto.called_computation_ids(index));
};
+
+ TF_RET_CHECK(std::all_of(
+ proto.operand_ids().begin(), proto.operand_ids().end(),
+ [&instruction_map](int64 id) { return instruction_map.contains(id); }))
+ << proto.name() << " instruction contains invalid operand id(s)";
+
+ TF_RET_CHECK(std::all_of(
+ proto.called_computation_ids().begin(),
+ proto.called_computation_ids().end(),
+ [&computation_map](int64 id) { return computation_map.contains(id); }))
+ << proto.name() << " instruction references invalid computation id(s)";
+
+ TF_RETURN_IF_ERROR(ShapeUtil::ValidateShapeWithOptionalLayout(proto.shape()));
+
switch (opcode) {
// Ops migrated to subclasses.
case HloOpcode::kBatchNormTraining:
@@ -266,7 +281,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
<< "Expect 1 called computation for fusion instruction but sees "
<< proto.called_computation_ids_size();
const int64 fusion_id = proto.called_computation_ids(0);
- auto* fused_computation = FindPtrOrNull(computation_map, fusion_id);
+ auto* fused_computation =
+ tensorflow::gtl::FindPtrOrNull(computation_map, fusion_id);
TF_RET_CHECK(fused_computation != nullptr)
<< "No fusion computation with id " << fusion_id;
instruction = CreateFusion(proto.shape(), fusion_kind, all_operands(),
@@ -302,6 +318,8 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
} break;
case HloOpcode::kOutfeed:
TF_RET_CHECK(proto.operand_ids_size() == 2);
+ TF_RETURN_IF_ERROR(
+ ShapeUtil::ValidateShapeWithOptionalLayout(proto.outfeed_shape()));
instruction = CreateOutfeed(proto.outfeed_shape(), operands(0),
operands(1), proto.outfeed_config());
break;
@@ -466,31 +484,34 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
proto.dot_dimension_numbers(), precision_config);
break;
}
- case HloOpcode::kDomain:
+ case HloOpcode::kDomain: {
TF_RET_CHECK(proto.operand_ids_size() == 1)
<< "Domain instruction should have 1 operands but sees "
<< proto.operand_ids_size();
+ TF_RET_CHECK(proto.has_domain_entry_sharding())
+ << "Domain instruction must domain_entry_sharding";
+ TF_RET_CHECK(proto.has_domain_exit_sharding())
+ << "Domain instruction must domain_exit_sharding";
+ TF_ASSIGN_OR_RETURN(
+ HloSharding entry_hlo_sharding,
+ HloSharding::FromProto(proto.domain_entry_sharding()));
+ TF_ASSIGN_OR_RETURN(HloSharding exit_hlo_sharding,
+ HloSharding::FromProto(proto.domain_exit_sharding()));
instruction = absl::make_unique<HloDomainInstruction>(
- proto.shape(), operands(0), /*operand_side_metadata=*/nullptr,
- /*user_side_metadata=*/nullptr);
+ proto.shape(), operands(0),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(entry_hlo_sharding)),
+ absl::make_unique<ShardingMetadata>(
+ std::make_shared<const HloSharding>(exit_hlo_sharding)));
break;
+ }
default: {
instruction = absl::WrapUnique(new HloInstruction(opcode, proto.shape()));
for (const int64 operand_id : proto.operand_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, operand_id))
- << "No instruction with id " << operand_id;
instruction->AppendOperand(instruction_map.at(operand_id));
}
- for (const int64 predecessor_id : proto.control_predecessor_ids()) {
- TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
- << "No instruction with id " << predecessor_id;
- TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
- ->AddControlDependencyTo(instruction.get()));
- }
if (instruction->opcode() != HloOpcode::kFusion) {
for (const int64 computation_id : proto.called_computation_ids()) {
- TF_RET_CHECK(ContainsKey(computation_map, computation_id))
- << "No computation with id " << computation_id;
instruction->called_computations_.push_back(
computation_map.at(computation_id));
}
@@ -502,6 +523,13 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
}
}
+ for (const int64 predecessor_id : proto.control_predecessor_ids()) {
+ TF_RET_CHECK(ContainsKey(instruction_map, predecessor_id))
+ << "No instruction with id " << predecessor_id;
+ TF_RETURN_IF_ERROR(instruction_map.at(predecessor_id)
+ ->AddControlDependencyTo(instruction.get()));
+ }
+
TF_RET_CHECK(!proto.name().empty());
instruction->SetAndSanitizeName(proto.name());
instruction->metadata_ = proto.metadata();
@@ -1432,7 +1460,7 @@ int64 HloInstruction::operand_index(const HloInstruction* target) const {
HloInstruction::InstructionVector HloInstruction::unique_operands() const {
InstructionVector unique;
- tensorflow::gtl::FlatSet<const HloInstruction*> seen;
+ absl::flat_hash_set<const HloInstruction*> seen;
for (HloInstruction* operand : operands()) {
if (seen.insert(operand).second) {
unique.push_back(operand);
@@ -2006,7 +2034,7 @@ string HloInstruction::OperandsToStringWithCanonicalNameMap(
options.is_in_nested_computation()) {
str.push_back(PrintName(
canonical_name_map->LookupOrInsert(operand->name()), options));
- } else if (!options.compact_operands()) {
+ } else if (options.print_operand_names()) {
str.push_back(PrintName(operand->name(), options));
}
StrAppend(out, StrJoin(str, " "));
@@ -2661,14 +2689,14 @@ class HloInstruction::FusionReusesParamElements {
// the value of this parameter, which would save stack space but not allow us
// to finish early if we find a reuse.
static UseKind Compute(int64 i, const HloInstruction& hlo) {
- tensorflow::gtl::FlatMap<const HloInstruction*, UseKind> memoization_cache;
+ absl::flat_hash_map<const HloInstruction*, UseKind> memoization_cache;
return ComputeInternal(i, hlo, &memoization_cache);
}
private:
static UseKind ComputeInternal(
int64 i, const HloInstruction& hlo,
- tensorflow::gtl::FlatMap<const HloInstruction*, UseKind>* cache) {
+ absl::flat_hash_map<const HloInstruction*, UseKind>* cache) {
if (auto hlo_param = DynCast<HloParameterInstruction>(&hlo)) {
if (hlo_param->parameter_number() == i) {
return UseKind::kUse;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index 009bd3bab3..9deed20e5d 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -32,6 +32,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/container/inlined_vector.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
@@ -50,7 +51,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -80,6 +80,7 @@ class HloPrintOptions {
print_backend_config_(true),
compact_operands_(false),
print_operand_shape_(true),
+ print_operand_names_(true),
print_program_shape_(true),
print_percent_(true),
print_control_dependencies_(true),
@@ -107,6 +108,7 @@ class HloPrintOptions {
.set_print_metadata(false)
.set_print_backend_config(false)
.set_compact_operands(true)
+ .set_print_operand_names(false)
.set_print_operand_shape(true)
.set_print_program_shape(false)
.set_print_percent(false)
@@ -144,6 +146,12 @@ class HloPrintOptions {
return *this;
}
+ // If true, the operand names will be printed.
+ HloPrintOptions& set_print_operand_names(bool value) {
+ print_operand_names_ = value;
+ return *this;
+ }
+
// If true, program shape of hlo computations will be printed.
HloPrintOptions& set_print_program_shape(bool value) {
print_program_shape_ = value;
@@ -162,8 +170,8 @@ class HloPrintOptions {
return *this;
}
- // If true, only a part of operands will be printed out, and their names will
- // be omitted (note that in this case the text will not be parsable).
+ // If true, only a part of operands will be printed out (note that in this
+ // case the text will not be parsable).
HloPrintOptions& set_compact_operands(bool value) {
compact_operands_ = value;
return *this;
@@ -197,6 +205,7 @@ class HloPrintOptions {
bool print_backend_config() const { return print_backend_config_; }
bool compact_operands() const { return compact_operands_; }
bool print_operand_shape() const { return print_operand_shape_; }
+ bool print_operand_names() const { return print_operand_names_; }
bool print_program_shape() const { return print_program_shape_; }
bool print_percent() const { return print_percent_; }
bool print_control_dependencies() const {
@@ -215,6 +224,7 @@ class HloPrintOptions {
bool print_backend_config_;
bool compact_operands_;
bool print_operand_shape_;
+ bool print_operand_names_;
bool print_program_shape_;
bool print_percent_;
bool print_control_dependencies_;
@@ -247,7 +257,7 @@ class CanonicalNameMap {
private:
int64 index;
- tensorflow::gtl::FlatMap<string, string> canonical_name_map;
+ absl::flat_hash_map<string, string> canonical_name_map;
};
// HLO instructions are the atomic unit of the high-level compiler's IR.
@@ -350,8 +360,8 @@ class HloInstruction {
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto,
- const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
- const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
+ const absl::flat_hash_map<int64, HloInstruction*>& instruction_map,
+ const absl::flat_hash_map<int64, HloComputation*>& computation_map);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.cc b/tensorflow/compiler/xla/service/hlo_instructions.cc
index cd71bc3323..68d0979f5c 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.cc
+++ b/tensorflow/compiler/xla/service/hlo_instructions.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <deque>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
@@ -27,8 +28,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
+#include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
#include "tensorflow/compiler/xla/window_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
namespace {
@@ -213,6 +214,7 @@ HloSendRecvInstruction::HloSendRecvInstruction(HloOpcode opcode,
HloInstructionProto HloSendRecvInstruction::ToProto() const {
HloInstructionProto proto = HloInstruction::ToProto();
proto.set_channel_id(channel_id_);
+ proto.set_is_host_transfer(is_host_transfer_);
return proto;
}
@@ -1042,7 +1044,8 @@ HloInstruction* HloFusionInstruction::AddFusionOperand(
const int64 param_no = operand_count();
// Name the parameter after the instruction it represents in the outer
// (non-fusion) computation.
- string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ // string param_name = StrCat(new_operand->name(), ".param_", param_no);
+ string param_name = StrCat("param_", param_no);
HloInstruction* fused_parameter =
fused_instructions_computation()->AddParameter(
HloInstruction::CreateParameter(param_no, new_operand->shape(),
@@ -1098,7 +1101,7 @@ void HloFusionInstruction::MergeFusionInstructionIntoMultiOutput(
// Note that we add the unfused instructions to this->parent_ computation.
// This is necessary because the unique_id needs for an instruction and
// it's only added when inserting to the computation.
- tensorflow::gtl::FlatMap<HloInstruction*, HloInstruction*> old_to_new;
+ absl::flat_hash_map<HloInstruction*, HloInstruction*> old_to_new;
std::vector<HloInstruction*> unfused_instructions;
auto computation_to_merge =
instruction_to_merge->fused_instructions_computation();
@@ -1391,7 +1394,7 @@ std::unique_ptr<HloInstruction> HloFusionInstruction::CloneWithNewOperandsImpl(
}
Status HloFusionInstruction::DeduplicateFusionOperands() {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> operand_indices;
+ absl::flat_hash_map<const HloInstruction*, int> operand_indices;
std::vector<int> operands_to_remove;
for (int i = 0; i < operand_count(); ++i) {
auto emplace_result = operand_indices.emplace(operand(i), i);
@@ -2309,4 +2312,23 @@ std::unique_ptr<HloInstruction> HloDomainInstruction::CloneWithNewOperandsImpl(
shape, new_operands[0], operand_side_metadata_->Clone(),
user_side_metadata_->Clone());
}
+
+HloInstructionProto HloDomainInstruction::ToProto() const {
+ HloInstructionProto proto = HloInstruction::ToProto();
+ auto operand_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(operand_side_metadata_.get());
+ if (operand_side_sharding) {
+ *proto.mutable_domain_entry_sharding() =
+ operand_side_sharding->sharding()->ToProto();
+ }
+
+ auto user_side_sharding =
+ dynamic_cast<const ShardingMetadata*>(user_side_metadata_.get());
+ if (user_side_sharding) {
+ *proto.mutable_domain_exit_sharding() =
+ user_side_sharding->sharding()->ToProto();
+ }
+
+ return proto;
+}
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_instructions.h b/tensorflow/compiler/xla/service/hlo_instructions.h
index 9c22f5db7e..c929867bb9 100644
--- a/tensorflow/compiler/xla/service/hlo_instructions.h
+++ b/tensorflow/compiler/xla/service/hlo_instructions.h
@@ -1341,6 +1341,9 @@ class HloDomainInstruction : public HloInstruction {
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata);
+ // Returns a serialized representation of this instruction.
+ HloInstructionProto ToProto() const override;
+
// Retrieves the operand side metadata of a kDomain instruction.
const DomainMetadata& operand_side_metadata() const {
return *operand_side_metadata_;
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
index 6a4e766788..55314d0ae9 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/tuple_points_to_analysis.h"
@@ -74,7 +76,7 @@ class ListScheduler {
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
ListScheduler scheduler(computation, points_to_analysis, size_function,
memory_by_computation);
@@ -99,7 +101,7 @@ class ListScheduler {
ListScheduler(const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation)
: computation_(computation),
points_to_analysis_(points_to_analysis),
@@ -110,7 +112,7 @@ class ListScheduler {
// LogicalBuffer is in an operand of the instruction as indicated by
// points-to analysis.
for (auto* instruction : computation.instructions()) {
- tensorflow::gtl::FlatSet<const LogicalBuffer*> instr_uses;
+ absl::flat_hash_set<const LogicalBuffer*> instr_uses;
for (auto* operand : instruction->operands()) {
points_to_analysis.GetPointsToSet(operand).ForEachElement(
[&](const ShapeIndex& /*index*/,
@@ -234,8 +236,7 @@ class ListScheduler {
// Populate the ready list with instructions which have no operands or
// control predecessors.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64>
- unscheduled_pred_count;
+ absl::flat_hash_map<const HloInstruction*, int64> unscheduled_pred_count;
for (auto* instruction : computation_.instructions()) {
// TODO(b/34466113): Replace this and above with successors() or
// predecessors() when these methods are added to HloInstruction.
@@ -251,8 +252,8 @@ class ListScheduler {
std::multimap<Priority, ReadyListEntry> ready_queue;
// Map of ready instructions to their iterators in ready_queue.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::multimap<Priority, ReadyListEntry>::iterator>
+ absl::flat_hash_map<const HloInstruction*,
+ std::multimap<Priority, ReadyListEntry>::iterator>
ready_instructions;
auto add_to_ready_queue = [&](HloInstruction* inst) {
@@ -347,12 +348,11 @@ class ListScheduler {
// Computations are analyzed in post-order. When scheduling an instruction
// that includes subcomputations, such as a while loop, we use this map to
// look up the memory needed by subcomputations.
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation_;
// A map containing the LogicalBuffers that each instruction uses.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const LogicalBuffer*>>
+ absl::flat_hash_map<const HloInstruction*, std::vector<const LogicalBuffer*>>
buffer_uses_;
// A map containing the count of unscheduled HLOs which using a particular
@@ -361,7 +361,7 @@ class ListScheduler {
std::unordered_map<const LogicalBuffer*, int64> unscheduled_use_count_;
// Set of instructions which have been scheduled.
- tensorflow::gtl::FlatSet<const HloInstruction*> scheduled_instructions_;
+ absl::flat_hash_set<const HloInstruction*> scheduled_instructions_;
};
int64 SumLogicalBufferSizes(
@@ -379,7 +379,7 @@ StatusOr<HloInstructionSequence> ScheduleComputationHelper(
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
const MemorySchedulerAlgorithm& algorithm,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
VLOG(2) << "Computation: " << computation.name();
if (algorithm) {
@@ -396,13 +396,13 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// These variables are a hack to prevent overflows.
int64 cumulative_total_size = 0;
int64 total_hlos = computation.parent()->instruction_count();
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> extra_users;
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> total_sizes;
+ absl::flat_hash_map<const HloInstruction*, int64> extra_users;
+ absl::flat_hash_map<const HloInstruction*, int64> total_sizes;
for (const HloInstruction* hlo : computation.MakeInstructionPostOrder()) {
if (ListScheduler::IgnoreInstruction(*hlo)) {
extra_users[hlo] = 0;
@@ -419,7 +419,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
points_to_analysis.GetBuffersDefinedByInstruction(hlo), size_function);
total_sizes[hlo] = logical_buffer_size;
cumulative_total_size += logical_buffer_size;
- tensorflow::gtl::FlatSet<const HloInstruction*> unique_operands(
+ absl::flat_hash_set<const HloInstruction*> unique_operands(
hlo->operands().begin(), hlo->operands().end());
for (const HloInstruction* operand : unique_operands) {
extra_users[hlo] += extra_users[operand];
@@ -467,7 +467,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
return ListScheduler::Run(computation, points_to_analysis, size_function,
memory_by_computation);
@@ -477,7 +477,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
return HloInstructionSequence(computation.MakeInstructionPostOrder());
}
@@ -486,7 +486,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation) {
// We try a few schedulers and choose whichever returns a lower min-memory,
// not accounting for fragmentation.
@@ -549,7 +549,7 @@ StatusOr<HloSchedule> ScheduleModule(
HloSchedule schedule(&module);
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(&module));
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
for (const auto* computation : module.MakeComputationPostOrder()) {
if (!computation->IsFusionComputation()) {
TF_ASSIGN_OR_RETURN(HloInstructionSequence computation_sequence,
@@ -577,7 +577,7 @@ StatusOr<HloInstructionSequence> ScheduleComputation(
CHECK(!computation.IsFusionComputation());
TF_ASSIGN_OR_RETURN(std::unique_ptr<TuplePointsToAnalysis> points_to_analysis,
TuplePointsToAnalysis::Run(computation.parent()));
- tensorflow::gtl::FlatMap<const HloComputation*, int64> empty_map;
+ absl::flat_hash_map<const HloComputation*, int64> empty_map;
return ScheduleComputationHelper(computation, *points_to_analysis,
size_function, nullptr, empty_map);
}
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
index 9964c6fdd7..a4c1d3db81 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_ordering.h"
@@ -37,7 +38,7 @@ namespace xla {
typedef std::function<StatusOr<HloInstructionSequence>(
const HloComputation&, const TuplePointsToAnalysis&,
const LogicalBuffer::SizeFunction&,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&)>
+ const absl::flat_hash_map<const HloComputation*, int64>&)>
MemorySchedulerAlgorithm;
// List scheduler
@@ -45,7 +46,7 @@ StatusOr<HloInstructionSequence> ListMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// DFS-order scheduler
@@ -53,7 +54,7 @@ StatusOr<HloInstructionSequence> DFSMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// Naive Post Order scheduler
@@ -61,7 +62,7 @@ StatusOr<HloInstructionSequence> PostOrderMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// The default scheduling algorithm. Runs both the list scheduler
@@ -71,7 +72,7 @@ StatusOr<HloInstructionSequence> DefaultMemoryScheduler(
const HloComputation& computation,
const TuplePointsToAnalysis& points_to_analysis,
const LogicalBuffer::SizeFunction& size_function,
- const tensorflow::gtl::FlatMap<const HloComputation*, int64>&
+ const absl::flat_hash_map<const HloComputation*, int64>&
memory_by_computation);
// Returns an HloSchedule which seeks to minimize the memory required for
diff --git a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
index 1b9e9bfc77..5a9fccc7dd 100644
--- a/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_memory_scheduler_test.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/heap_simulator.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
@@ -247,7 +248,7 @@ TEST_F(HloSchedulingTest, ListAccountsForSubcomputations) {
EXPECT_TRUE(ordering.ExecutesBefore(bcast, add));
EXPECT_TRUE(ordering.ExecutesBefore(transpose, add));
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
memory_by_computation[body_computation] = 16;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
@@ -409,7 +410,7 @@ TEST_F(HloSchedulingTest, HeapSimulatorAccountsForSubcomputations) {
EXPECT_EQ(module->entry_computation()->instruction_count(),
schedule.sequence(module->entry_computation()).size());
- tensorflow::gtl::FlatMap<const HloComputation*, int64> memory_by_computation;
+ absl::flat_hash_map<const HloComputation*, int64> memory_by_computation;
memory_by_computation[cond_computation] = 17;
memory_by_computation[body_computation] = 16;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis =
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index b3949f3a6d..7527e35c95 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -23,6 +23,8 @@ limitations under the License.
#include <utility>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -285,8 +287,8 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
- tensorflow::gtl::FlatMap<int64, HloComputation*> computation_map;
- tensorflow::gtl::FlatMap<HloComputation*, int64> to_proto_id;
+ absl::flat_hash_map<int64, HloComputation*> computation_map;
+ absl::flat_hash_map<HloComputation*, int64> to_proto_id;
std::vector<std::unique_ptr<HloComputation>> computations;
HloComputation* entry = nullptr;
for (const HloComputationProto& computation_proto : proto.computations()) {
@@ -327,10 +329,10 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
// Because we didn't uniquify the names or the ids, double-check that the
// instruction and computation names and ids are unique from the proto.
- tensorflow::gtl::FlatSet<string> computation_names;
- tensorflow::gtl::FlatSet<string> instruction_names;
- tensorflow::gtl::FlatSet<int> computation_ids;
- tensorflow::gtl::FlatSet<int> instruction_ids;
+ absl::flat_hash_set<string> computation_names;
+ absl::flat_hash_set<string> instruction_names;
+ absl::flat_hash_set<int> computation_ids;
+ absl::flat_hash_set<int> instruction_ids;
for (HloComputation* computation : module->computations()) {
TF_RET_CHECK(!ContainsKey(computation_names, computation->name()))
<< "Computation name is not unique: " << computation->name();
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
index 278d94cdd3..0311b73207 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_metadata.h
@@ -22,6 +22,7 @@ limitations under the License.
#include <unordered_set>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -30,7 +31,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -250,25 +250,25 @@ class HloModuleGroupMetadata {
std::vector<std::unique_ptr<std::vector<HloInstruction*>>> companion_sets_;
// Map from each companion while instruction to the index into companion_set_.
- tensorflow::gtl::FlatMap<const HloInstruction*, int64> companion_set_index_;
+ absl::flat_hash_map<const HloInstruction*, int64> companion_set_index_;
// Map from computation to the instruction using it (a kWhile, kConditional).
- tensorflow::gtl::FlatMap<const HloComputation*, TrackedInstruction>
+ absl::flat_hash_map<const HloComputation*, TrackedInstruction>
tracked_instructions_;
// Maps tracked instructions (kWhile, kConditional, kCall, ...) to the set of
// communicating instructions within the proper called computation(s).
- tensorflow::gtl::FlatMap<HloInstruction*, std::vector<HloInstruction*>>
+ absl::flat_hash_map<HloInstruction*, std::vector<HloInstruction*>>
tracked_instructions_comms_;
// All channels in the module.
std::vector<Channel> channels_;
// Map from channel ids to the index in channels_.
- tensorflow::gtl::FlatMap<int64, int64> channel_id_map_;
+ absl::flat_hash_map<int64, int64> channel_id_map_;
// Map from all-reduce ids to the all reduce instructions.
- tensorflow::gtl::FlatMap<int64, std::vector<HloInstruction*>> all_reduce_map_;
+ absl::flat_hash_map<int64, std::vector<HloInstruction*>> all_reduce_map_;
// The maximum channel id used in the module group.
int64 max_channel_id_ = -1;
@@ -276,7 +276,7 @@ class HloModuleGroupMetadata {
// The modules that this metadata was built from.
const std::vector<HloModule*>& modules_;
- tensorflow::gtl::FlatMap<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
+ absl::flat_hash_map<HloModule*, std::unique_ptr<TuplePointsToAnalysis>>
points_to_analyses_;
};
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.cc b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
index d83ee71490..fddeb5f0a2 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.cc
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -42,7 +42,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalPredecessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
predecessors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique predecessors list; if the predecessors is a companion
// instruction, also add companion instructions; if the predecessors is a
@@ -119,7 +119,7 @@ std::vector<HloInstruction*> HloModuleGroupUtil::GlobalSuccessors(
HloInstruction* instruction) {
std::vector<HloInstruction*>
successors; // Use a vector to avoid non-determinism.
- tensorflow::gtl::FlatSet<HloInstruction*> unique;
+ absl::flat_hash_set<HloInstruction*> unique;
// Adds to the unique successors list; if the successor is a companion
// instruction, also add companion instructions; if the successor is a
diff --git a/tensorflow/compiler/xla/service/hlo_module_group_util.h b/tensorflow/compiler/xla/service/hlo_module_group_util.h
index 309c23045d..f21b44bcd9 100644
--- a/tensorflow/compiler/xla/service/hlo_module_group_util.h
+++ b/tensorflow/compiler/xla/service/hlo_module_group_util.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <memory>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -87,7 +87,7 @@ class HloModuleGroupUtil {
// * visit_state: map from each instruction to its visit state.
// * visit_function: function called when each instruction group.
// * root: the root instruction of the traversal.
- using VisitStates = tensorflow::gtl::FlatMap<HloInstruction*, VisitState>;
+ using VisitStates = absl::flat_hash_map<HloInstruction*, VisitState>;
Status VisitTopologicalOrder(VisitStates* visit_state,
const VisitFunction& visit_function,
HloInstruction* root);
diff --git a/tensorflow/compiler/xla/service/hlo_opcode.cc b/tensorflow/compiler/xla/service/hlo_opcode.cc
index 2d4e38589f..4551a1c2e2 100644
--- a/tensorflow/compiler/xla/service/hlo_opcode.cc
+++ b/tensorflow/compiler/xla/service/hlo_opcode.cc
@@ -14,9 +14,9 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -31,7 +31,7 @@ string HloOpcodeString(HloOpcode opcode) {
}
StatusOr<HloOpcode> StringToHloOpcode(const string& opcode_name) {
- static auto* opcode_map = new tensorflow::gtl::FlatMap<string, HloOpcode>({
+ static auto* opcode_map = new absl::flat_hash_map<string, HloOpcode>({
#define STRING_TO_OPCODE_ENTRY(enum_name, opcode_name, ...) \
{opcode_name, HloOpcode::enum_name},
HLO_OPCODE_LIST(STRING_TO_OPCODE_ENTRY)
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.cc b/tensorflow/compiler/xla/service/hlo_ordering.cc
index f1dc08bafa..23d41d91d6 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering.cc
@@ -92,14 +92,18 @@ bool HloOrdering::ExecutesBefore(const HloInstruction* a,
}
bool HloOrdering::IsDefinedBefore(const HloValue& a, const HloValue& b) const {
- // If 'b' is an entry param then 'a' cannot be defined before 'b' because 'b'
- // is live into the module.
+ // Entry parameter should always be defined before other instructions.
const HloModule* module = b.defining_instruction()->parent()->parent();
if (b.defining_instruction()->parent() == module->entry_computation() &&
b.defining_instruction()->opcode() == HloOpcode::kParameter) {
return false;
}
+ if (a.defining_instruction()->parent() == module->entry_computation() &&
+ a.defining_instruction()->opcode() == HloOpcode::kParameter) {
+ return true;
+ }
+
// Phi values require special handling. Because XLA does not have a phi
// instruction, the definition instruction of the phis values are
// placeholders: either the subcomputation parameter (body or condition) or
@@ -316,7 +320,7 @@ string PredecessorHloOrdering::ToStringHelper(const string& name) const {
for (auto predecessor : all) {
if (predecessors_.at(computation)
->IsReachable(predecessor, instruction)) {
- pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
+ pieces.push_back(absl::StrFormat(" %s", predecessor->name()));
}
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_ordering.h b/tensorflow/compiler/xla/service/hlo_ordering.h
index b0361c3f02..66313492eb 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering.h
+++ b/tensorflow/compiler/xla/service/hlo_ordering.h
@@ -20,6 +20,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_dataflow_analysis.h"
@@ -28,7 +29,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_schedule.h"
#include "tensorflow/compiler/xla/service/hlo_value.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -120,8 +120,8 @@ class PredecessorHloOrdering : public HloOrdering {
// predecessors. An instruction is an element of its own predecessor set.
//
// Subclasses should fill this in to define the desired ordering.
- tensorflow::gtl::FlatMap<const HloComputation*,
- std::unique_ptr<HloReachabilityMap>>
+ absl::flat_hash_map<const HloComputation*,
+ std::unique_ptr<HloReachabilityMap>>
predecessors_;
};
@@ -204,7 +204,7 @@ class SequentialHloOrdering : public HloOrdering {
// this map so more than one instruction may have the same position
// value. This is not a problem because ExecutesBefore also verifies
// instructions are in the same computation.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> order_position_;
+ absl::flat_hash_map<const HloInstruction*, int> order_position_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_ordering_test.cc b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
index 00970bcda3..b045adc964 100644
--- a/tensorflow/compiler/xla/service/hlo_ordering_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_ordering_test.cc
@@ -174,6 +174,26 @@ TEST_F(HloOrderingTest, InstructionsInWhileComputations) {
EXPECT_FALSE(ordering.ExecutesBefore(body_param, cond_param));
}
+TEST_F(HloOrderingTest, ParametersDefinedBeforeOthers) {
+ // Entry parameter should always be defined before other instruction.
+ auto module = CreateNewModule();
+ const Shape scalar_shape = ShapeUtil::MakeShape(xla::F32, {});
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(LiteralUtil::CreateR0<float>(1.0)));
+ auto param = builder.AddInstruction(
+ HloInstruction::CreateParameter(0, scalar_shape, "param"));
+ module->AddEntryComputation(builder.Build());
+ TF_ASSERT_OK_AND_ASSIGN(auto dataflow,
+ HloDataflowAnalysis::Run(*module, /*ssa_form=*/true));
+
+ DependencyHloOrdering ordering(module.get());
+ EXPECT_TRUE(ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(param),
+ dataflow->GetValueDefinedAt(constant)));
+ EXPECT_TRUE(!ordering.IsDefinedBefore(dataflow->GetValueDefinedAt(constant),
+ dataflow->GetValueDefinedAt(param)));
+}
+
TEST_F(HloOrderingTest, ValuesInWhileComputations) {
// Tests the ordering of values (defined by dataflow analysis) in the body and
// condition of a while instruction. HLO code:
diff --git a/tensorflow/compiler/xla/service/hlo_parser.cc b/tensorflow/compiler/xla/service/hlo_parser.cc
index 25b70740e3..5a125b4c08 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser.cc
@@ -80,17 +80,23 @@ class HloParser {
StatusOr<PaddingConfig> ParsePaddingConfigOnly();
// Stand-alone parsing utility for a single instruction worth of text.
- Status ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name);
+ Status ParseSingleInstruction(HloModule* module);
private:
- // Locates an instruction with the given name in the instruction_pool_ or
+ using InstrNameTable =
+ std::unordered_map<string, std::pair<HloInstruction*, LocTy>>;
+
+ // Returns the map from the instruction name to the instruction itself and its
+ // location in the current scope.
+ InstrNameTable& current_name_table() { return scoped_name_tables_.back(); }
+
+ // Locates an instruction with the given name in the current_name_table() or
// returns nullptr.
//
- // If the missing_instruction_hook_ is registered and a "shape" is provided,
- // the hook will be called and may satisfy the request for the given
- // instruction. This is useful when we reify parameters as they're resolved;
- // i.e. for ParseSingleInstruction.
+ // When the name is not found or name is empty, if create_missing_instruction_
+ // hook is registered and a "shape" is provided, the hook will be called to
+ // create an instruction. This is useful when we reify parameters as they're
+ // resolved; i.e. for ParseSingleInstruction.
std::pair<HloInstruction*, LocTy>* FindInstruction(
const string& name, const optional<Shape>& shape = nullopt);
@@ -98,9 +104,11 @@ class HloParser {
bool ParseHloModule(HloModule* module);
bool ParseComputations(HloModule* module);
bool ParseComputation(HloComputation** entry_computation);
- bool ParseInstructionList(HloComputation::Builder* builder,
- string* root_name);
+ bool ParseInstructionList(HloComputation** computation,
+ const string& computation_name);
bool ParseInstruction(HloComputation::Builder* builder, string* root_name);
+ bool ParseInstruciontRhs(HloComputation::Builder* builder, const string& name,
+ LocTy name_loc);
bool ParseControlPredecessors(HloInstruction* instruction);
bool ParseLiteral(Literal* literal, const Shape& shape);
bool ParseTupleLiteral(Literal* literal, const Shape& shape);
@@ -281,23 +289,47 @@ class HloParser {
bool AddComputation(const string& name, HloComputation* computation,
LocTy name_loc);
- // The map from the instruction/computation name to the
- // instruction/computation itself and it's location. This does not own the
- // pointers.
- std::unordered_map<string, std::pair<HloInstruction*, LocTy>>
- instruction_pool_;
+ HloLexer lexer_;
+
+ // A stack for the instruction names. The top of the stack stores the
+ // instruction name table for the current scope.
+ //
+ // A instruction's name is unique among its scope (i.e. its parent
+ // computation), but it's not necessarily unique among all computations in the
+ // module. When there are multiple levels of nested computations, the same
+ // name could appear in both an outer computation and an inner computation. So
+ // we need a stack to make sure a name is only visible within its scope,
+ std::vector<InstrNameTable> scoped_name_tables_;
+
+ // A helper class which pushes and pops to an InstrNameTable stack via RAII.
+ class Scope {
+ public:
+ explicit Scope(std::vector<InstrNameTable>* scoped_name_tables)
+ : scoped_name_tables_(scoped_name_tables) {
+ scoped_name_tables_->emplace_back();
+ }
+ ~Scope() { scoped_name_tables_->pop_back(); }
+
+ private:
+ std::vector<InstrNameTable>* scoped_name_tables_;
+ };
+
+ // Map from the computation name to the computation itself and its location.
std::unordered_map<string, std::pair<HloComputation*, LocTy>>
computation_pool_;
- HloLexer lexer_;
std::vector<std::unique_ptr<HloComputation>> computations_;
std::vector<string> error_;
- // Function that gets invoked when we try to resolve an instruction
- // instruction_pool_ but fail to do so.
- std::function<std::pair<HloInstruction*, LocTy>*(string,
- const optional<Shape>&)>
- missing_instruction_hook_;
+ // When an operand name cannot be resolved, this function is called to create
+ // a parameter instruction with the given name and shape. It registers the
+ // name, instruction, and a placeholder location in the name table. It returns
+ // the newly-created instruction and the placeholder location. If `name` is
+ // empty, this should create the parameter with a generated name. This is
+ // supposed to be set and used only in ParseSingleInstruction.
+ std::function<std::pair<HloInstruction*, LocTy>*(const string& name,
+ const Shape& shape)>
+ create_missing_instruction_;
};
bool SplitToInt64s(absl::string_view s, char delim, std::vector<int64>* out) {
@@ -351,11 +383,21 @@ bool HloParser::Run(HloModule* module) {
std::pair<HloInstruction*, HloParser::LocTy>* HloParser::FindInstruction(
const string& name, const optional<Shape>& shape) {
- std::pair<HloInstruction*, LocTy>* instr =
- tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ std::pair<HloInstruction*, LocTy>* instr = nullptr;
+ if (!name.empty()) {
+ instr = tensorflow::gtl::FindOrNull(current_name_table(), name);
+ }
+
// Potentially call the missing instruction hook.
- if (instr == nullptr && missing_instruction_hook_ != nullptr) {
- return missing_instruction_hook_(name, shape);
+ if (instr == nullptr && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ if (!shape.has_value()) {
+ Error(lexer_.GetLoc(),
+ "Operand had no shape in HLO text; cannot create parameter for "
+ "single-instruction module.");
+ return nullptr;
+ }
+ return create_missing_instruction_(name, *shape);
}
return instr;
}
@@ -439,7 +481,6 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
if (!ParseName(&name)) {
return false;
}
- auto builder = absl::make_unique<HloComputation::Builder>(name);
LocTy shape_loc = nullptr;
Shape shape;
@@ -447,40 +488,21 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
return false;
}
- string root_name;
- if (!ParseInstructionList(builder.get(), &root_name)) {
+ HloComputation* computation = nullptr;
+ if (!ParseInstructionList(&computation, name)) {
return false;
}
- std::pair<HloInstruction*, LocTy>* root_node = FindInstruction(root_name);
- // This means some instruction was marked as ROOT but we didn't find it in the
- // pool, which should not happen.
- if (!root_name.empty() && root_node == nullptr) {
- LOG(FATAL) << "instruction " << root_name
- << " was marked as ROOT but the parser has not seen it before";
- }
-
- HloInstruction* root = root_node == nullptr ? nullptr : root_node->first;
- // Now root can be either an existing instruction or a nullptr. If it's a
- // nullptr, the implementation of Builder will set the last instruction as
- // root instruction.
- computations_.emplace_back(builder->Build(root));
- HloComputation* computation = computations_.back().get();
-
- if (!root) {
- root = computation->root_instruction();
- } else {
- CHECK_EQ(root, computation->root_instruction());
- }
-
// If param_list_to_shape was present, check compatibility.
- if (shape_loc != nullptr && !ShapeUtil::Compatible(root->shape(), shape)) {
+ if (shape_loc != nullptr &&
+ !ShapeUtil::Compatible(computation->root_instruction()->shape(), shape)) {
return Error(
shape_loc,
- StrCat("Shape of computation ", name, ", ",
- ShapeUtil::HumanString(shape),
- ", is not compatible with that of its root instruction ",
- root_name, ", ", ShapeUtil::HumanString(root->shape())));
+ StrCat(
+ "Shape of computation ", name, ", ", ShapeUtil::HumanString(shape),
+ ", is not compatible with that of its root instruction ",
+ computation->root_instruction()->name(), ", ",
+ ShapeUtil::HumanString(computation->root_instruction()->shape())));
}
if (is_entry_computation) {
@@ -489,43 +511,62 @@ bool HloParser::ParseComputation(HloComputation** entry_computation) {
}
*entry_computation = computation;
}
- instruction_pool_.clear();
return AddComputation(name, computation, name_loc);
}
// instruction_list ::= '{' instruction_list1 '}'
// instruction_list1 ::= (instruction)+
-bool HloParser::ParseInstructionList(HloComputation::Builder* builder,
- string* root_name) {
+bool HloParser::ParseInstructionList(HloComputation** computation,
+ const string& computation_name) {
+ Scope scope(&scoped_name_tables_);
+ HloComputation::Builder builder(computation_name);
if (!ParseToken(TokKind::kLbrace,
"expects '{' at the beginning of instruction list.")) {
return false;
}
+ string root_name;
do {
- if (!ParseInstruction(builder, root_name)) {
+ if (!ParseInstruction(&builder, &root_name)) {
return false;
}
} while (lexer_.GetKind() != TokKind::kRbrace);
- return ParseToken(TokKind::kRbrace,
- "expects '}' at the end of instruction list.");
+ if (!ParseToken(TokKind::kRbrace,
+ "expects '}' at the end of instruction list.")) {
+ return false;
+ }
+ HloInstruction* root = nullptr;
+ if (!root_name.empty()) {
+ std::pair<HloInstruction*, LocTy>* root_node =
+ tensorflow::gtl::FindOrNull(current_name_table(), root_name);
+
+ // This means some instruction was marked as ROOT but we didn't find it in
+ // the pool, which should not happen.
+ if (root_node == nullptr) {
+ LOG(FATAL) << "instruction " << root_name
+ << " was marked as ROOT but the parser has not seen it before";
+ }
+ root = root_node->first;
+ }
+
+ // Now root can be either an existing instruction or a nullptr. If it's a
+ // nullptr, the implementation of Builder will set the last instruction as
+ // the root instruction.
+ computations_.emplace_back(builder.Build(root));
+ *computation = computations_.back().get();
+ return true;
}
// instruction ::= ('ROOT')? name '=' shape opcode operands (attribute)*
bool HloParser::ParseInstruction(HloComputation::Builder* builder,
string* root_name) {
string name;
- Shape shape;
- HloOpcode opcode;
- std::vector<HloInstruction*> operands;
-
LocTy maybe_root_loc = lexer_.GetLoc();
bool is_root = EatIfPresent(TokKind::kw_ROOT);
const LocTy name_loc = lexer_.GetLoc();
if (!ParseName(&name) ||
- !ParseToken(TokKind::kEqual, "expects '=' in instruction") ||
- !ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ !ParseToken(TokKind::kEqual, "expects '=' in instruction")) {
return false;
}
@@ -536,6 +577,19 @@ bool HloParser::ParseInstruction(HloComputation::Builder* builder,
*root_name = name;
}
+ return ParseInstruciontRhs(builder, name, name_loc);
+}
+
+bool HloParser::ParseInstruciontRhs(HloComputation::Builder* builder,
+ const string& name, LocTy name_loc) {
+ Shape shape;
+ HloOpcode opcode;
+ std::vector<HloInstruction*> operands;
+
+ if (!ParseShape(&shape) || !ParseOpcode(&opcode)) {
+ return false;
+ }
+
// Add optional attributes.
std::unordered_map<string, AttrConfig> attrs;
optional<OpSharding> sharding;
@@ -2146,7 +2200,20 @@ bool HloParser::ParseOperands(std::vector<HloInstruction*>* operands) {
}
}
if (!ParseName(&name)) {
- return false;
+ // When parsing a single instruction (as opposed to a whole module), an
+ // HLO may have one or more operands with a shape but no name:
+ //
+ // foo = add(f32[10], f32[10])
+ //
+ // create_missing_instruction_ is always non-null when parsing a single
+ // instruction, and is responsible for creating kParameter instructions
+ // for these operands.
+ if (shape.has_value() && create_missing_instruction_ != nullptr &&
+ scoped_name_tables_.size() == 1) {
+ name = "";
+ } else {
+ return false;
+ }
}
std::pair<HloInstruction*, LocTy>* instruction =
FindInstruction(name, shape);
@@ -2299,9 +2366,17 @@ bool HloParser::ParseAttributeHelper(
return true;
}
case AttrTy::kHloComputation: {
- HloComputation* result;
- if (!ParseComputationName(&result)) {
- return false;
+ HloComputation* result = nullptr;
+ if (lexer_.GetKind() == TokKind::kLbrace) {
+ // This means it is a nested computation.
+ if (!ParseInstructionList(&result, /*computation_name=*/"_")) {
+ return false;
+ }
+ } else {
+ // This means it is a computation name.
+ if (!ParseComputationName(&result)) {
+ return false;
+ }
}
static_cast<optional<HloComputation*>*>(attr_out_ptr)->emplace(result);
return true;
@@ -3134,7 +3209,7 @@ bool HloParser::EatIfPresent(TokKind kind) {
bool HloParser::AddInstruction(const string& name, HloInstruction* instruction,
LocTy name_loc) {
- auto result = instruction_pool_.insert({name, {instruction, name_loc}});
+ auto result = current_name_table().insert({name, {instruction, name_loc}});
if (!result.second) {
Error(name_loc, StrCat("instruction already exists: ", name));
return Error(/*loc=*/result.first->second.second,
@@ -3204,36 +3279,51 @@ StatusOr<PaddingConfig> HloParser::ParsePaddingConfigOnly() {
return padding_config;
}
-Status HloParser::ParseSingleInstruction(HloComputation::Builder* builder,
- string* root_name) {
- TF_RET_CHECK(missing_instruction_hook_ == nullptr);
+Status HloParser::ParseSingleInstruction(HloModule* module) {
+ TF_RET_CHECK(create_missing_instruction_ == nullptr);
+ TF_RET_CHECK(scoped_name_tables_.empty());
+ HloComputation::Builder builder(module->name());
// The missing instruction hook we register creates the shaped instruction on
// the fly as a parameter and returns it.
int64 parameter_count = 0;
- missing_instruction_hook_ =
- [this, builder, &parameter_count](
- string name,
- const optional<Shape>& shape) -> std::pair<HloInstruction*, LocTy>* {
- if (!shape.has_value()) {
- Error(lexer_.GetLoc(),
- StrCat("Operand ", name,
- " had no shape in HLO text; cannot create parameter for "
- "single-instruction module."));
- return nullptr;
- }
- HloInstruction* parameter = builder->AddInstruction(
- HloInstruction::CreateParameter(parameter_count++, *shape, name));
- instruction_pool_[name] = {parameter, lexer_.GetLoc()};
- return tensorflow::gtl::FindOrNull(instruction_pool_, name);
+ create_missing_instruction_ =
+ [this, &builder, &parameter_count](
+ const string& name,
+ const Shape& shape) -> std::pair<HloInstruction*, LocTy>* {
+ string new_name = name.empty() ? StrCat("_", parameter_count) : name;
+ HloInstruction* parameter = builder.AddInstruction(
+ HloInstruction::CreateParameter(parameter_count++, shape, new_name));
+ current_name_table()[new_name] = {parameter, lexer_.GetLoc()};
+ return tensorflow::gtl::FindOrNull(current_name_table(), new_name);
};
// Prime the lexer.
lexer_.Lex();
// Parse the instruction with the registered hook.
- if (!ParseInstruction(builder, root_name)) {
- return InvalidArgument("Syntax error:\n%s", GetError());
+ Scope scope(&scoped_name_tables_);
+ if (CanBeShape()) {
+ // This means that the instruction's left-hand side is probably omitted,
+ // e.g.
+ //
+ // f32[10] fusion(...), calls={...}
+ if (!ParseInstruciontRhs(&builder, module->name(), lexer_.GetLoc())) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ } else {
+ // This means that the instruction's left-hand side might exist, e.g.
+ //
+ // foo = f32[10] fusion(...), calls={...}
+ string root_name;
+ if (!ParseInstruction(&builder, &root_name)) {
+ return InvalidArgument("Syntax error:\n%s", GetError());
+ }
+ }
+
+ module->AddEntryComputation(builder.Build());
+ for (auto& comp : computations_) {
+ module->AddEmbeddedComputation(std::move(comp));
}
return Status::OK();
}
@@ -3271,12 +3361,8 @@ Status ParseHloString(absl::string_view str, HloModule* module) {
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name) {
HloParser parser(str);
- auto builder = absl::make_unique<HloComputation::Builder>(string(name));
- string root_name;
- TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(builder.get(), &root_name));
- std::unique_ptr<HloComputation> computation = builder->Build();
auto module = absl::make_unique<HloModule>(string(name), HloModuleConfig());
- module->AddEntryComputation(std::move(computation));
+ TF_RETURN_IF_ERROR(parser.ParseSingleInstruction(module.get()));
return std::move(module);
}
diff --git a/tensorflow/compiler/xla/service/hlo_parser.h b/tensorflow/compiler/xla/service/hlo_parser.h
index 3696035514..97d6f0117e 100644
--- a/tensorflow/compiler/xla/service/hlo_parser.h
+++ b/tensorflow/compiler/xla/service/hlo_parser.h
@@ -40,8 +40,9 @@ StatusOr<std::unique_ptr<HloModule>> ParseHloString(
// point to an empty module (no computations).
Status ParseHloString(absl::string_view str, HloModule* module);
-// Parses the text for a single HLO operation into an HLO module with a function
-// that runs that operation (with the same parameters) as its entry computation.
+// Parses the text for a single HLO instruction into an HLO module with an
+// entry computation that runs that instruction (with the same parameters) as
+// its root instruction.
StatusOr<std::unique_ptr<HloModule>> ParseHloOpToModule(
absl::string_view str, absl::string_view name = "single_op");
diff --git a/tensorflow/compiler/xla/service/hlo_parser_test.cc b/tensorflow/compiler/xla/service/hlo_parser_test.cc
index 96db96bdb9..d10acf3814 100644
--- a/tensorflow/compiler/xla/service/hlo_parser_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_parser_test.cc
@@ -1163,49 +1163,80 @@ ENTRY Sort {
// clang-format on
}
-class HloParserTest : public ::testing::Test,
- public ::testing::WithParamInterface<TestData> {
+// The test class for those tests defined above which round-trip through the
+// parser and ToString is templatized on two bool parameters:
+//
+// short_form : used for the "short" test cases which use the ShortParsable
+// output form.
+// proto_round_trip : whether the module should also be round-tripped through
+// HloProto form. This provides much better coverage for the proto
+// serialization/deserialization.
+//
+// The proto_round_trip=true case also technically covers the Parser->ToString
+// roundtrip as well, but separating out the Parser->ToString roundtrip as its
+// own test provides better isolation and could conceivably catch weirdo bugs
+// which are hidden by interaction between the textual and proto roundtripping.
+template <bool short_form, bool proto_round_trip>
+class HloParameterizedParserTest
+ : public ::testing::Test,
+ public ::testing::WithParamInterface<TestData> {
protected:
- static void ExpectHasSubstr(string_view s, string_view expected) {
- EXPECT_TRUE(absl::StrContains(s, expected))
- << "'" << s << "' does not contain '" << expected << "'";
- }
-
// Expects "ToString(ParseHloString(string)) == string", that is, parses the
// string, asserts that it succeeded, stringifies the parsed module, and
// checks that the it equals the original string.
void ExpectEqual() {
const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original, result.ValueOrDie()->ToString(
- HloPrintOptions().set_print_large_constants(true)));
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseHloString(original));
+ if (proto_round_trip) {
+ TF_ASSERT_OK_AND_ASSIGN(module, HloModule::CreateFromProto(
+ module->ToProto(), module->config()));
+ }
+ if (short_form) {
+ EXPECT_EQ(original, module->ToString(HloPrintOptions::ShortParsable()));
+ } else {
+ EXPECT_EQ(
+ original,
+ module->ToString(HloPrintOptions().set_print_large_constants(true)));
+ }
}
};
-class HloParserShortTest : public HloParserTest {
- protected:
- void ExpectEqualShort() {
- const string& original = GetParam().module_string;
- auto result = ParseHloString(original);
- TF_ASSERT_OK(result.status());
- EXPECT_EQ(original,
- result.ValueOrDie()->ToString(HloPrintOptions::ShortParsable()));
- }
-};
+// These using shenanigans are required because the TEST_P macro doesn't like
+// template instantiations which contain commas.
+using HloParserTestLong = HloParameterizedParserTest<false, false>;
+using HloParserTestLongProto = HloParameterizedParserTest<false, true>;
+using HloParserTestShort = HloParameterizedParserTest<true, false>;
+using HloParserTestShortProto = HloParameterizedParserTest<true, true>;
-TEST_P(HloParserTest, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLong, Run) { ExpectEqual(); }
+TEST_P(HloParserTestLongProto, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShort, Run) { ExpectEqual(); }
+TEST_P(HloParserTestShortProto, Run) { ExpectEqual(); }
-TEST_P(HloParserShortTest, Run) { ExpectEqualShort(); }
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestLong,
::testing::ValuesIn(CreateTestCases()),
TestDataToString);
-
-INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserShortTest,
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestLongProto,
+ ::testing::ValuesIn(CreateTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation, HloParserTestShort,
+ ::testing::ValuesIn(CreateShortTestCases()),
+ TestDataToString);
+INSTANTIATE_TEST_CASE_P(HloParserTestSuccessInstantiation,
+ HloParserTestShortProto,
::testing::ValuesIn(CreateShortTestCases()),
TestDataToString);
+class HloParserTest : public ::testing::Test {
+ protected:
+ static void ExpectHasSubstr(string_view s, string_view expected) {
+ EXPECT_TRUE(absl::StrContains(s, expected))
+ << "'" << s << "' does not contain '" << expected << "'";
+ }
+};
+
TEST_F(HloParserTest, Empty) {
const string original = "";
auto result = ParseHloString(original);
@@ -1732,6 +1763,25 @@ ENTRY entry {
"was parsing 8:39: error: instruction does not exist: aparam");
}
+TEST_F(HloParserTest, SameNameDiffComputations) {
+ const string original = R"(HloModule same_names:
+add {
+ p0 = f32[] parameter(0)
+ p1 = f32[] parameter(1)
+ ROOT result = f32[] add(p0, p1)
+}
+
+ENTRY ReduceR3ToR2 {
+ p0 = f32[8,16,256]{2,1,0} parameter(0)
+ p1 = f32[] constant(0)
+ ROOT result = f32[8,16]{1,0} reduce(p0, p1), dimensions={2}, to_apply=add
+}
+)";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(original));
+ ASSERT_NE(module->entry_computation(), nullptr);
+ EXPECT_THAT(module->entry_computation()->root_instruction(), op::Reduce());
+}
+
TEST_F(HloParserTest, ParseSharding) {
const string original = "{maximal device=42}";
TF_ASSERT_OK_AND_ASSIGN(HloSharding sharding, ParseSharding(original));
@@ -1792,14 +1842,129 @@ TEST(HloParserSingleOpTest, SingleOp) {
op::Multiply(op::Parameter(0), op::Parameter(1)));
}
-TEST(HloParserSingleOpTest, SingleOpNoShapesProducesError) {
+TEST(HloParserSingleOpTest, SingleOpNoShapeProducesError) {
+ const string text = "multiply(f32[2,4]{1,0} %broadcast, f32[2,4]{1,0} %x)";
+ StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
+ ASSERT_TRUE(!module.status().ok());
+ LOG(INFO) << "Status: " << module.status();
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("expects '=' in instruction"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoOperandShapesProducesError) {
const string text = "%multiply = f32[2,4]{1,0} multiply(%broadcast, %x)";
StatusOr<std::unique_ptr<HloModule>> module = ParseHloOpToModule(text);
ASSERT_TRUE(!module.status().ok());
LOG(INFO) << "Status: " << module.status();
- EXPECT_THAT(
- module.status().ToString(),
- ::testing::HasSubstr("Operand broadcast had no shape in HLO text"));
+ EXPECT_THAT(module.status().ToString(),
+ ::testing::HasSubstr("Operand had no shape in HLO text"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpNoNames) {
+ const string text =
+ "%multiply = f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, CanonicalOp) {
+ const string text = "f32[2,4]{1,0} multiply(f32[2,4]{1,0}, f32[2,4]{1,0})";
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Multiply(op::Parameter(0), op::Parameter(1)));
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, CanonicalOpWithNested) {
+ const string text =
+ R"(f32[5,20]{1,0} while(f32[5,10]{1,0}), condition=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+}, body=
+{
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ ROOT tmp_2 = f32[5,20]{1,0} fusion(f32[5,10]{1,0} tmp_0, f32[20,10]{1,0} tmp_1), kind=kLoop, calls=
+ {
+ tmp_0 = f32[5,10]{1,0} parameter(0)
+ tmp_1 = f32[20,10]{1,0} parameter(1)
+ tmp_2 = f32[10,20]{1,0} transpose(f32[20,10]{1,0} tmp_1), dimensions={1,0}
+ ROOT tmp_3 = f32[5,20]{1,0} dot(f32[5,10]{1,0} tmp_0, f32[10,20]{1,0} tmp_2), lhs_contracting_dims={1}, rhs_contracting_dims={0}
+ }
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_EQ(
+ computation->root_instruction()->ToString(HloPrintOptions::Canonical()),
+ text);
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested) {
+ const string text =
+ R"(%fusion = f32[3,2,1,1]{3,2,1,0} fusion(f32[3,2,1,1]{3,2,1,0} %p0, f32[2]{0} %p1), kind=kLoop, calls=
+{
+ %param_0 = f32[3,2,1,1]{3,2,1,0} parameter(0)
+ %param_1 = f32[2]{0} parameter(1)
+ %broadcast = f32[3,2,1,1]{3,2,1,0} broadcast(f32[2]{0} %param_1), dimensions={1}
+ ROOT %subtract = f32[3,2,1,1]{3,2,1,0} subtract(f32[3,2,1,1]{3,2,1,0} %param_0, f32[3,2,1,1]{3,2,1,0} %broadcast)
+})";
+
+ TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloOpToModule(text));
+ const HloComputation* computation = module->entry_computation();
+ ASSERT_NE(computation, nullptr);
+ EXPECT_THAT(computation->root_instruction(),
+ op::Fusion(op::Parameter(0), op::Parameter(1)));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_DoesNotExist) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(),
+ ::testing::HasSubstr("does not exist: x"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoLhs) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ f32[] add(f32[] x, f32[] y)
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
+}
+
+TEST(HloParserSingleOpTest, SingleOpWithNested_NoOperandName) {
+ const string text =
+ R"(reduce = f32[] reduce(f32[10], f32[]), dimensions={1}, to_apply=
+{
+ result = f32[] add(f32[], f32[])
+})";
+ auto status = ParseHloOpToModule(text).status();
+ ASSERT_FALSE(status.ok());
+ EXPECT_THAT(status.error_message(), ::testing::HasSubstr("expects name"));
}
TEST(HloParserSingleOpTest, ConvolutionTrivialFeatureGroupCount) {
diff --git a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
index 8c2f928ca1..5e004ce78a 100644
--- a/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
+++ b/tensorflow/compiler/xla/service/hlo_pass_pipeline.cc
@@ -17,6 +17,8 @@ limitations under the License.
#include <functional>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
@@ -24,7 +26,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -74,8 +75,8 @@ StatusOr<bool> HloPassPipeline::RunPassesInternal(
std::vector<HloPassInterface*> HloPassPipeline::GetEnabledPasses(
const DebugOptions& debug_options) {
auto repeated_field = debug_options.xla_disable_hlo_passes();
- tensorflow::gtl::FlatSet<string> disabled_pass_names(repeated_field.begin(),
- repeated_field.end());
+ absl::flat_hash_set<string> disabled_pass_names(repeated_field.begin(),
+ repeated_field.end());
if (!disabled_pass_names.empty()) {
VLOG(1) << "Passes disabled by --xla_disable_hlo_passes: "
<< absl::StrJoin(disabled_pass_names, ", ");
@@ -98,7 +99,7 @@ void HloPassPipeline::MaybeDumpHlo(const HloModule& module,
if (!proto_dump_path.empty()) {
static tensorflow::mutex mu(tensorflow::LINKER_INITIALIZED);
static auto* const module_id_to_pass_number =
- new tensorflow::gtl::FlatMap<int64, int64>();
+ new absl::flat_hash_map<int64, int64>();
tensorflow::mutex_lock lock(mu);
const int64 pass_number = (*module_id_to_pass_number)[module.unique_id()]++;
diff --git a/tensorflow/compiler/xla/service/hlo_reachability.h b/tensorflow/compiler/xla/service/hlo_reachability.h
index b66a2aa4bd..5a5f01f8fd 100644
--- a/tensorflow/compiler/xla/service/hlo_reachability.h
+++ b/tensorflow/compiler/xla/service/hlo_reachability.h
@@ -19,11 +19,11 @@ limitations under the License.
#include <list>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -154,7 +154,7 @@ class HloReachabilityMap {
// Dense assignment from HloInstruction* to number. These numbers index
// into the bit_vectors_ vector and into the bits within a BitVector.
- tensorflow::gtl::FlatMap<const HloInstruction*, int> indices_;
+ absl::flat_hash_map<const HloInstruction*, int> indices_;
// Bitvectors holding the reachability to each instruction. The bit vector for
// instruction X includes ones for each instruction which X is reachable from.
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index a438671936..5ac43808ee 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -20,6 +20,8 @@ limitations under the License.
#include <set>
#include <string>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -75,7 +77,7 @@ bool IsRematerializable(const HloInstruction* instruction) {
// cache before, and eventually calling the IsRematerializable() API.
bool CanBeRematerialized(
const HloInstruction* instruction,
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
auto it = remat_able->find(instruction);
if (it != remat_able->end()) {
return it->second;
@@ -268,7 +270,7 @@ class InstructionList {
Item* first_;
// Item for each instruction.
- tensorflow::gtl::FlatMap<const HloInstruction*, Item*> item_map_;
+ absl::flat_hash_map<const HloInstruction*, Item*> item_map_;
};
// Return the items which use the given LogicalBuffer. Sets
@@ -503,7 +505,7 @@ MemoryUsageTracker::MemoryUsageTracker(
PointsToSet::BufferSet live_out_set =
points_to_analysis.GetPointsToSet(computation_->root_instruction())
.CreateFlattenedSet();
- tensorflow::gtl::FlatMap<const LogicalBuffer*, BufferId>
+ absl::flat_hash_map<const LogicalBuffer*, BufferId>
logical_buffer_to_buffer_id;
for (auto* item = instruction_list_.first(); item != nullptr;
@@ -854,7 +856,7 @@ int64 RematerializationCost(const HloInstruction* instruction,
Item* PickRematerializationCandidate(
const MemoryUsageTracker& memory_tracker,
const InstructionList& instruction_list, int64 memory_limit_bytes,
- tensorflow::gtl::FlatMap<const HloInstruction*, bool>* remat_able) {
+ absl::flat_hash_map<const HloInstruction*, bool>* remat_able) {
Item* best_item = nullptr;
int64 best_cost = 0;
@@ -980,10 +982,10 @@ StatusOr<bool> HloRematerialization::RematerializeComputation(
// rematerialization is essentially a move). If the next rematerialization of
// the instruction is also a move then the rematerialization is added to the
// blacklist.
- tensorflow::gtl::FlatSet<const HloInstruction*> remat_move_instructions;
+ absl::flat_hash_set<const HloInstruction*> remat_move_instructions;
// The map from instructions to their rematerializable status.
- tensorflow::gtl::FlatMap<const HloInstruction*, bool> remat_able;
+ absl::flat_hash_map<const HloInstruction*, bool> remat_able;
// The peak memory of the computation at any point in the instruction
// sequence.
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.h b/tensorflow/compiler/xla/service/hlo_rematerialization.h
index 7330d73c09..70d83c04f0 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.h
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.h
@@ -15,6 +15,8 @@
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_REMATERIALIZATION_H_
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/buffer_liveness.h"
#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -115,14 +117,13 @@ class HloRematerialization : public HloModulePass {
// computations called from sequential context
// (CallContext::kSequential). These values are updated as rematerialization
// occurs.
- tensorflow::gtl::FlatMap<const HloComputation*, int64>
- computation_peak_memory_;
+ absl::flat_hash_map<const HloComputation*, int64> computation_peak_memory_;
std::unique_ptr<TuplePointsToAnalysis> points_to_analysis_;
// Set of computations which have had rematerialization
// applied. Rematerialization is only applied once per computation.
- tensorflow::gtl::FlatSet<const HloComputation*> rematerialized_computations_;
+ absl::flat_hash_set<const HloComputation*> rematerialized_computations_;
// Count of the total instructions rematerialized.
int64 instructions_rematerialized_ = 0;
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.cc b/tensorflow/compiler/xla/service/hlo_schedule.cc
index 3fc5dbeb02..9972eb2077 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.cc
+++ b/tensorflow/compiler/xla/service/hlo_schedule.cc
@@ -18,6 +18,8 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/map_util.h"
@@ -30,7 +32,7 @@ namespace xla {
/* static */ StatusOr<HloSchedule> HloSchedule::CreateFromProto(
const HloModule* module, const HloScheduleProto& proto) {
- tensorflow::gtl::FlatMap<int64, const HloComputation*> id_to_computation;
+ absl::flat_hash_map<int64, const HloComputation*> id_to_computation;
for (const HloComputation* computation : module->computations()) {
id_to_computation[computation->unique_id()] = computation;
}
@@ -44,7 +46,7 @@ namespace xla {
<< "No computation exists in HLO module with id " << computation_id;
const HloComputation* computation = comp_it->second;
- tensorflow::gtl::FlatMap<int64, const HloInstruction*> id_to_instruction;
+ absl::flat_hash_map<int64, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
id_to_instruction[instruction->unique_id()] = instruction;
}
@@ -112,13 +114,13 @@ Status HloSchedule::UpdateComputationSchedule(
const HloComputation* computation) {
// Map from unique ID to HloInstruction pointer for instructions in the
// computation.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> id_to_instruction;
+ absl::flat_hash_map<int, const HloInstruction*> id_to_instruction;
for (const HloInstruction* instruction : computation->instructions()) {
InsertOrDie(&id_to_instruction, instruction->unique_id(), instruction);
}
// Set of all HloInstructions in the schedule.
- tensorflow::gtl::FlatSet<int> ids_in_schedule;
+ absl::flat_hash_set<int> ids_in_schedule;
for (int id : sequences_.at(computation->unique_id()).ids()) {
InsertOrDie(&ids_in_schedule, id);
}
@@ -126,15 +128,13 @@ Status HloSchedule::UpdateComputationSchedule(
// Map from HloInstruction X to newly added instructions (instruction is in
// computation, but not in schedule) which use X. If an instruction is not in
// the map, then it has no users which are newly added instructions.
- tensorflow::gtl::FlatMap<const HloInstruction*,
- std::vector<const HloInstruction*>>
+ absl::flat_hash_map<const HloInstruction*, std::vector<const HloInstruction*>>
new_instruction_uses;
// For each newly added instruction, this is the count of the instruction's
// operands that have not yet been scheduled. When this value reaches zero,
// then the instruction may be placed in the schedule.
- tensorflow::gtl::FlatMap<const HloInstruction*, int>
- unscheduled_operand_count;
+ absl::flat_hash_map<const HloInstruction*, int> unscheduled_operand_count;
// Create a worklist of newly added instructions which are ready to be added
// to the schedule. Initialize worklist with those that have zero operands.
@@ -211,15 +211,15 @@ Status HloSchedule::Update() {
if (sequences_.size() > nonfusion_computations.size()) {
// Schedule contains some computations which have been removed from the
// HloModule. Remove them from the schedule as well.
- tensorflow::gtl::FlatSet<int64> nonfusion_computations_ids;
+ absl::flat_hash_set<int64> nonfusion_computations_ids;
for (const HloComputation* computation : nonfusion_computations) {
nonfusion_computations_ids.insert(computation->unique_id());
}
for (auto it = sequences_.begin(); it != sequences_.end();) {
if (nonfusion_computations_ids.count(it->first) == 0) {
- it = sequences_.erase(it);
+ sequences_.erase(it++);
} else {
- it++;
+ ++it;
}
}
}
@@ -254,7 +254,7 @@ Status HloSchedule::Verify() const {
// For each computation verify the set of instructions is the same and that
// each dependency and control edge is honored.
for (const HloComputation* computation : nonfusion_computations) {
- tensorflow::gtl::FlatMap<const HloInstruction*, int> instruction_position;
+ absl::flat_hash_map<const HloInstruction*, int> instruction_position;
int pos = 0;
for (const HloInstruction* instruction :
sequence(computation).instructions()) {
diff --git a/tensorflow/compiler/xla/service/hlo_schedule.h b/tensorflow/compiler/xla/service/hlo_schedule.h
index 270fe6039f..0a714101ee 100644
--- a/tensorflow/compiler/xla/service/hlo_schedule.h
+++ b/tensorflow/compiler/xla/service/hlo_schedule.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -103,8 +104,7 @@ class HloSchedule {
// Returns a map from HloComputation unique ID to instruction sequence. The
// map contains all sequences in the schedule.
- const tensorflow::gtl::FlatMap<int64, HloInstructionSequence>& sequences()
- const {
+ const absl::flat_hash_map<int64, HloInstructionSequence>& sequences() const {
return sequences_;
}
@@ -148,7 +148,7 @@ class HloSchedule {
// A map from computation unique ID to instruction sequence. Unique IDs are
// used rather than HloComputation pointers because HLO pointers are not
// unique across HLO transformations because pointers may be recycled.
- tensorflow::gtl::FlatMap<int64, HloInstructionSequence> sequences_;
+ absl::flat_hash_map<int64, HloInstructionSequence> sequences_;
};
std::ostream& operator<<(std::ostream& out, const HloSchedule& schedule);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index de7e6b53d4..94c7bafd3b 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -369,10 +369,14 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
return HloSharding(tuple_shardings);
} else if (proto.type() == OpSharding::Type::OpSharding_Type_REPLICATED) {
return Replicate();
- } else if (proto.type() == OpSharding::Type::OpSharding_Type_MAXIMAL ||
- proto.tile_assignment_devices().size() == 1) {
+ } else if (proto.tile_assignment_devices().size() == 1) {
return HloSharding(proto.tile_assignment_devices(0));
}
+
+ TF_RET_CHECK(proto.type() != OpSharding::Type::OpSharding_Type_MAXIMAL)
+ << "Maximal sharding is expected to have single device assignment, but "
+ << proto.tile_assignment_devices().size() << " has provided.";
+
// Some versions of gcc cannot infer the TileAssignment constructor from a
// braced initializer-list, so create one manually.
std::vector<int64> devices(proto.tile_assignment_devices().begin(),
diff --git a/tensorflow/compiler/xla/service/hlo_value.cc b/tensorflow/compiler/xla/service/hlo_value.cc
index 8549487702..59594ab2f0 100644
--- a/tensorflow/compiler/xla/service/hlo_value.cc
+++ b/tensorflow/compiler/xla/service/hlo_value.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -31,7 +32,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -167,7 +167,7 @@ void HloValue::SetPositionsAndComputeUses(
positions_.insert(positions_.end(), positions.begin(), positions.end());
// Gather the computation roots at which this value appears.
- tensorflow::gtl::FlatSet<HloInstruction*> root_positions;
+ absl::flat_hash_set<HloInstruction*> root_positions;
for (const HloPosition& position : positions_) {
if (position.instruction ==
position.instruction->parent()->root_instruction()) {
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 6eb6658904..a7727824fe 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include <set>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/service/hlo_casting_utils.h"
#include "tensorflow/compiler/xla/service/hlo_instructions.h"
@@ -23,7 +24,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -993,7 +993,7 @@ Status CheckSameIsHostTransfer(const HloInstruction* instr1,
// Checks various invariants of send and recv instructions.
Status VerifySendsAndRecvs(const HloModule& module) {
- tensorflow::gtl::FlatMap<int64, const HloInstruction*> host_channels;
+ absl::flat_hash_map<int64, const HloInstruction*> host_channels;
// Host send/recv instructions must have their own unique channel.
auto check_unique_host_channel = [&](const HloInstruction* instruction) {
const HloSendRecvInstruction* sendrecv =
@@ -1061,7 +1061,7 @@ StatusOr<bool> HloVerifier::Run(HloModule* module) {
TF_RETURN_IF_ERROR(VerifyHloStructure(module));
TF_RETURN_IF_ERROR(VerifySendsAndRecvs(*module));
- tensorflow::gtl::FlatMap<string, const HloInstruction*> instructions;
+ absl::flat_hash_map<string, const HloInstruction*> instructions;
for (auto* computation : module->computations()) {
for (const auto& instruction : computation->instructions()) {
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.cc b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
index 06f0e1ed25..1ebb331977 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.cc
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.cc
@@ -16,6 +16,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/indexed_array_analysis.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
@@ -23,7 +25,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace gtl = ::tensorflow::gtl;
@@ -95,7 +96,7 @@ Status IndexedArrayAnalysis::TraverseAndPopulateCache(
absl::InlinedVector<const HloInstruction*, 4> stack;
enum DfsState { kDiscovered, kVisited };
- gtl::FlatMap<const HloInstruction*, DfsState> dfs_state_map;
+ absl::flat_hash_map<const HloInstruction*, DfsState> dfs_state_map;
stack.push_back(root);
InsertOrDie(&dfs_state_map, root, kDiscovered);
diff --git a/tensorflow/compiler/xla/service/indexed_array_analysis.h b/tensorflow/compiler/xla/service/indexed_array_analysis.h
index 3e238f97a0..e5aa67fd85 100644
--- a/tensorflow/compiler/xla/service/indexed_array_analysis.h
+++ b/tensorflow/compiler/xla/service/indexed_array_analysis.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <type_traits>
+#include "absl/container/flat_hash_map.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/util/ptr_util.h"
namespace xla {
@@ -360,7 +360,7 @@ class IndexedArrayAnalysis {
std::vector<std::unique_ptr<Array>> owned_tensors_;
std::vector<Literal> owned_literals_;
- tensorflow::gtl::FlatMap<const HloInstruction*, Array*> cache_;
+ absl::flat_hash_map<const HloInstruction*, Array*> cache_;
};
// A pass that prints all non-trivial results returned by IndexedArrayAnalysis.
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.cc b/tensorflow/compiler/xla/service/instruction_fusion.cc
index e884122fcb..5a99c40df4 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.cc
+++ b/tensorflow/compiler/xla/service/instruction_fusion.cc
@@ -22,11 +22,11 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/memory/memory.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -189,7 +189,7 @@ bool InstructionFusion::EffectivelyAtMostUnary(HloInstruction* hlo) {
bool InstructionFusion::CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
const HloInstructionSet& do_not_fuse,
- tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
result_cache) {
if (consumer == producer) {
return true;
@@ -241,7 +241,7 @@ InstructionFusion::ComputeGloballyUnfusible(
// fusing operations that require duplication later depending on
// is_expensive_().
HloInstructionSet do_not_duplicate;
- tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>, bool>
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>
can_fuse_on_all_paths_result_cache;
for (HloInstruction* consumer : post_order) {
for (HloInstruction* producer : consumer->operands()) {
@@ -430,7 +430,7 @@ class ReversePostOrderFusionQueue : public FusionQueue {
private:
std::vector<HloInstruction*> post_order_;
- tensorflow::gtl::FlatMap<HloInstruction*, int> post_order_index_;
+ absl::flat_hash_map<HloInstruction*, int> post_order_index_;
};
} // namespace
diff --git a/tensorflow/compiler/xla/service/instruction_fusion.h b/tensorflow/compiler/xla/service/instruction_fusion.h
index c1ec3b18a1..da2032f6c7 100644
--- a/tensorflow/compiler/xla/service/instruction_fusion.h
+++ b/tensorflow/compiler/xla/service/instruction_fusion.h
@@ -1,3 +1,4 @@
+#include "absl/container/flat_hash_map.h"
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
@@ -158,8 +159,8 @@ class InstructionFusion : public HloModulePass {
bool CanFuseOnAllPaths(
HloInstruction* producer, HloInstruction* consumer,
const HloInstructionSet& do_not_fuse,
- tensorflow::gtl::FlatMap<std::pair<HloInstruction*, HloInstruction*>,
- bool>* result_cache);
+ absl::flat_hash_map<std::pair<HloInstruction*, HloInstruction*>, bool>*
+ result_cache);
// Computes the set of nodes that we do not want to fuse into any of their
// consumers based on a global analysis of the HLO graph.
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 082bf8bffe..25d5327561 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -498,6 +498,22 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(
constraints->SetBufferLayout(new_shape.layout(), *buffer));
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ CHECK(get_channel_constraints(instruction))
+ << "Multi-module layout assignment requires ChannelLayoutConstraints";
+ int64 all_reduce_id = instruction->all_reduce_id().value();
+ if (!get_channel_constraints(instruction)
+ ->IsChannelConstrained(all_reduce_id)) {
+ continue;
+ }
+ // TODO(b/68493863): Change to use SetOperandLayout().
+ const Shape& buffer_shape = instruction->operand(0)->shape();
+ TF_RET_CHECK(ShapeUtil::IsArray(buffer_shape));
+ Shape new_buffer_shape =
+ get_channel_constraints(instruction)
+ ->LayoutShapeForChannel(buffer_shape, all_reduce_id);
+ TF_RETURN_IF_ERROR(
+ constraints->SetInstructionLayout(new_buffer_shape, instruction));
}
}
@@ -1512,19 +1528,6 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
// Verify all layouts in the shape have been set.
TF_RET_CHECK(LayoutUtil::HasLayout(instruction->shape()));
}
-
- // Copy the root instruction's result if its layout does not match the result
- // layout constraint.
- if (constraints.ResultLayout() != nullptr &&
- !constraints.ResultLayout()->MatchesLayoutInShape(
- computation->root_instruction()->shape())) {
- TF_ASSIGN_OR_RETURN(
- HloInstruction * new_root,
- CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
- computation->root_instruction()));
- computation->set_root_instruction(new_root);
- }
-
return Status::OK();
}
@@ -1654,6 +1657,18 @@ Status LayoutAssignment::RunOnComputation(
TF_RETURN_IF_ERROR(
ConstrainChannelLayouts(computation, channel_constraints));
}
+
+ // Copy the root instruction's result if its layout does not match the result
+ // layout constraint.
+ if (constraints.ResultLayout() != nullptr &&
+ !constraints.ResultLayout()->MatchesLayoutInShape(
+ computation->root_instruction()->shape())) {
+ TF_ASSIGN_OR_RETURN(
+ HloInstruction * new_root,
+ CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
+ computation->root_instruction()));
+ computation->set_root_instruction(new_root);
+ }
return Status::OK();
}
@@ -1709,6 +1724,30 @@ Status LayoutAssignment::ConstrainChannelLayouts(
ShapeUtil::GetMutableSubshape(instruction->mutable_shape(), {0});
*send_shape = shape;
}
+ } else if (instruction->IsCrossModuleAllReduce()) {
+ const Layout* layout =
+ get_channel_constraints(instruction)
+ ->ConstrainChannel(instruction->all_reduce_id().value(),
+ instruction->shape().layout());
+ if (layout != nullptr) {
+ // We found an already constrained layout which does not match the one
+ // the channel wants to impose. Either add a new kCopy, or use the
+ // existing one to marshal the correct shape.
+ HloInstruction* operand = instruction->mutable_operand(0);
+ Shape shape = operand->shape();
+ *shape.mutable_layout() = *layout;
+ if (operand->opcode() != HloOpcode::kCopy) {
+ HloInstruction* copy = operand->parent()->AddInstruction(
+ HloInstruction::CreateUnary(shape, HloOpcode::kCopy, operand));
+ RegisterAddedCopy(copy);
+ SetupCopiedInstruction(*operand, copy, {});
+ TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(0, copy));
+ operand = copy;
+ } else {
+ *operand->mutable_shape() = shape;
+ }
+ *instruction->mutable_shape() = shape;
+ }
}
}
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index e29c199c42..15f0adcaaf 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -25,6 +25,8 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -38,8 +40,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -228,8 +228,8 @@ class LayoutConstraints {
// Array-shaped buffers which have not yet been constrained.
std::set<LogicalBuffer::Id> unconstrained_buffer_ids_;
- mutable tensorflow::gtl::FlatMap<const HloInstruction*,
- std::unique_ptr<PointsToSet::BufferSet>>
+ mutable absl::flat_hash_map<const HloInstruction*,
+ std::unique_ptr<PointsToSet::BufferSet>>
buffer_sets_cache_;
HloComputation* computation_;
@@ -504,7 +504,7 @@ class LayoutAssignment : public HloModulePass {
// Every copy added to the module by the layout assignment pass is registered
// here.
- tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
+ absl::flat_hash_set<HloInstruction*> added_copies_;
// The pointer to the channel layout constraints passed in with the
// constructor. If not nullptr, this is an input/output argument.
@@ -521,8 +521,7 @@ class LayoutAssignment : public HloModulePass {
// The set of HLO instructions which lacked any layout constraint, thus
// receiving propagated default layouts.
- tensorflow::gtl::FlatSet<const HloInstruction*>
- unconstrained_layout_instructions_;
+ absl::flat_hash_set<const HloInstruction*> unconstrained_layout_instructions_;
};
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment_test.cc b/tensorflow/compiler/xla/service/layout_assignment_test.cc
index 752a61476d..10f9a95121 100644
--- a/tensorflow/compiler/xla/service/layout_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment_test.cc
@@ -860,6 +860,50 @@ TEST_F(LayoutAssignmentTest, ChannelLayoutMismatch) {
ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {1, 0})));
}
+TEST_F(LayoutAssignmentTest, AllReduceLayoutMissmatch) {
+ // Pin non matching layouts to parameter and root.
+ const char* module_str = R"(
+ HloModule test_module
+
+ add {
+ lhs = f32[] parameter(0)
+ rhs = f32[] parameter(1)
+ ROOT add = f32[] add(lhs, rhs)
+ }
+
+ ENTRY entry_computation {
+ param = (f32[2,2]) parameter(0)
+ gte = f32[2,2] get-tuple-element(param), index=0
+ ar.0 = f32[2,2] cross-replica-sum(gte),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=0}
+ const = f32[2,2] constant(f32[2,2]{{0,1},{2,3}})
+ ROOT ar.1 = f32[2,2] cross-replica-sum(const),
+ all_reduce_id=0, replica_groups={{0}}, to_apply=add,
+ sharding={maximal device=1}
+ })";
+ TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
+ ParseAndReturnVerifiedModule(module_str));
+ ComputationLayout computation_layout(
+ module->entry_computation()->ComputeProgramShape());
+ Shape param_shape = ShapeUtil::MakeTupleShape(
+ {ShapeUtil::MakeShapeWithLayout(F32, {2, 2}, {0, 1})});
+ TF_ASSERT_OK(
+ computation_layout.mutable_parameter_layout(0)->CopyLayoutFromShape(
+ param_shape));
+ computation_layout.mutable_result_layout()->ResetLayout(
+ LayoutUtil::MakeLayout({1, 0}));
+
+ ChannelLayoutConstraints channel_constraints;
+ AssignLayouts(module.get(), &computation_layout, &channel_constraints);
+
+ EXPECT_THAT(LayoutOf(module.get(), "gte"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.0"), ElementsAre(0, 1));
+ EXPECT_THAT(LayoutOf(module.get(), "ar.1"), ElementsAre(0, 1));
+ const HloInstruction* root = module->entry_computation()->root_instruction();
+ EXPECT_THAT(root->shape().layout().minor_to_major(), ElementsAre(1, 0));
+}
+
TEST_F(LayoutAssignmentTest, CopySliceOperandToAvoidImplicitLayoutChange) {
const char* module_str = R"(
HloModule CopySliceOperandToAvoidImplicitLayoutChange
diff --git a/tensorflow/compiler/xla/service/llvm_ir/BUILD b/tensorflow/compiler/xla/service/llvm_ir/BUILD
index 540bbb7c7a..6223a34b12 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/BUILD
+++ b/tensorflow/compiler/xla/service/llvm_ir/BUILD
@@ -38,6 +38,8 @@ cc_library(
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:logical_buffer",
"//tensorflow/core:lib",
+ "@com_google_absl//absl/container:flat_hash_map",
+ "@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@llvm//:core",
],
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
index e5370eca56..643ecd0fba 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h"
-#include <unordered_set>
+#include <map>
#include "llvm/IR/MDBuilder.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
@@ -164,9 +164,7 @@ llvm::MDNode* AliasAnalysis::GetNoaliasMetadataForBuffer(
add_buffers_to_worklist(operand);
}
- tensorflow::gtl::FlatSet<BufferAllocation::Slice,
- BufferAllocation::Slice::Hasher>
- buffers;
+ std::set<BufferAllocation::Slice> buffers;
for (const LogicalBuffer* buffer : worklist) {
// Skip buffers which cannot be added to the noalias set.
if (!assignment.HasAllocation(*buffer) ||
diff --git a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
index 8d9fa99d82..2b46b3c396 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/alias_analysis.h
@@ -16,14 +16,13 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_LLVM_IR_ALIAS_ANALYSIS_H_
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/str_cat.h"
#include "llvm/IR/Module.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
namespace llvm_ir {
@@ -77,14 +76,14 @@ class AliasAnalysis {
// A map from a buffer slice to metadata corresponding to its alias.scope
// metadata. The index kParameterAliasSet is used to hold aliasing
// information for parameters.
- tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
- BufferAllocation::Slice::Hasher>
+ absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
+ BufferAllocation::Slice::Hasher>
alias_scope_metadata_;
// A map from a buffer slice to metadata corresponding to its noalias
// metadata.
- tensorflow::gtl::FlatMap<BufferAllocation::Slice, llvm::MDNode*,
- BufferAllocation::Slice::Hasher>
+ absl::flat_hash_map<BufferAllocation::Slice, llvm::MDNode*,
+ BufferAllocation::Slice::Hasher>
noalias_metadata_;
};
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.cc b/tensorflow/compiler/xla/service/multi_output_fusion.cc
index b9ec31c497..2ca527bc4c 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.cc
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.cc
@@ -15,10 +15,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/multi_output_fusion.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -50,7 +50,7 @@ StatusOr<bool> MultiOutputFusion::Run(HloModule* module) {
all_fusion_candidates_.push_back(instruction);
std::vector<HloInstruction*> candidates;
- tensorflow::gtl::FlatSet<HloInstruction*> candidates_set;
+ absl::flat_hash_set<HloInstruction*> candidates_set;
VLOG(10) << "Looking at instruction: " << instruction->name();
for (auto operand : instruction->operands()) {
// Filter out the non-interesting instructions -- they
@@ -172,7 +172,7 @@ void MultiOutputFusion::Update(HloInstruction* instr1, HloInstruction* instr2) {
// Update the fusible list for fusion. Variable new_fusibles keeps
// track of the new or changed entries.
std::vector<std::pair<HloInstruction*, int64>> new_fusibles;
- tensorflow::gtl::FlatSet<HloInstruction*> in_list;
+ absl::flat_hash_set<HloInstruction*> in_list;
auto it = fusion_node.fusibles.begin();
while (it != fusion_node.fusibles.end()) {
HloInstruction* instr = it->first;
diff --git a/tensorflow/compiler/xla/service/multi_output_fusion.h b/tensorflow/compiler/xla/service/multi_output_fusion.h
index 0344626b26..9508ab2ed1 100644
--- a/tensorflow/compiler/xla/service/multi_output_fusion.h
+++ b/tensorflow/compiler/xla/service/multi_output_fusion.h
@@ -19,6 +19,7 @@ limitations under the License.
#include <queue>
#include <vector>
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
@@ -126,7 +127,7 @@ class MultiOutputFusion : public HloModulePass {
std::vector<FusionCandidate> candidates_;
// A map that maps an instruction to the index_.
- tensorflow::gtl::FlatMap<HloInstruction*, int> candidates_index_;
+ absl::flat_hash_map<HloInstruction*, int> candidates_index_;
// The reachability map of current computation.
std::unique_ptr<HloReachabilityMap> reachability_;
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index 6dd89c240f..8909d0f4fe 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -18,10 +18,10 @@ limitations under the License.
#include <string>
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
namespace xla {
@@ -69,7 +69,7 @@ class NameUniquer {
int64 next_ = 0;
// Set of all the identifiers which has been used.
- tensorflow::gtl::FlatSet<int64> used_;
+ absl::flat_hash_set<int64> used_;
};
// The string to use to separate the prefix of the name from the uniquing
@@ -78,7 +78,7 @@ class NameUniquer {
// Map from name prefix to the generator data structure which tracks used
// identifiers and generates new ones.
- tensorflow::gtl::FlatMap<string, SequentialIdGenerator> generated_names_;
+ absl::flat_hash_map<string, SequentialIdGenerator> generated_names_;
TF_DISALLOW_COPY_AND_ASSIGN(NameUniquer);
};
diff --git a/tensorflow/compiler/xla/service/reduce_precision_insertion.h b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
index 4bb22428f3..0b4e82e8d6 100644
--- a/tensorflow/compiler/xla/service/reduce_precision_insertion.h
+++ b/tensorflow/compiler/xla/service/reduce_precision_insertion.h
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_module.h"
#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
#include "tensorflow/compiler/xla/service/hlo_pass_pipeline.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 7194b2cafd..e379911462 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include <string>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/str_join.h"
@@ -33,7 +34,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
@@ -577,7 +577,7 @@ Status ValidateDotDimensionNumbers(
// Check that dimension numbers are unique.
auto dims_unique = [](absl::Span<const int64> contracting_dims,
absl::Span<const int64> batch_dims) -> bool {
- tensorflow::gtl::FlatSet<int64> dim_set;
+ absl::flat_hash_set<int64> dim_set;
auto is_unique = [&dim_set](int64 i) -> bool {
return dim_set.insert(i).second;
};
@@ -2380,7 +2380,9 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(HloOpcode operation,
!std::is_permutation(dimensions.begin(), dimensions.end(),
indices.begin())) {
return InvalidArgument(
- "Transpose dimensions not a permutation of the operand dimensions.");
+ "Transpose dimensions [%s] are not a permutation of the operand "
+ "dimensions (operand shape is %s).",
+ StrJoin(dimensions, ","), ShapeUtil::HumanString(operand));
}
// Permute(dimensions,input) computes output[dimensions[i]]=input[i]. However,
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 921a984589..56952e3ada 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <string>
#include <utility>
+#include "absl/container/flat_hash_set.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@@ -26,7 +27,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/logging.h"
namespace xla {
@@ -147,7 +147,7 @@ void ScopedShapedBuffer::Deallocate() {
// Deallocate all non-null buffers. A buffer may appear in more than one spot
// in the shape (eg, a tuple with a repeated element) so keep track of what
// has been deallocated.
- tensorflow::gtl::FlatSet<void*> deallocated_ptrs;
+ absl::flat_hash_set<void*> deallocated_ptrs;
for (auto& pair : buffers_) {
se::DeviceMemoryBase& memory_base = pair.second;
if (!memory_base.is_null() &&
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index f952e64af2..9199e32d0f 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -89,6 +89,16 @@ class TransferManager {
const LiteralSlice& literal,
const ShapedBuffer& device_buffer);
+ // Hint type given to TransferLiteralToDeviceAsync.
+ enum TransferToDeviceHint {
+ // No hint available.
+ kNoHint,
+
+ // The destination buffer is undefined on the device, meaning it can be
+ // transferred to eagerly rather than waiting for Stream ordering.
+ kBufferUndefined,
+ };
+
// Transfers the given literal into the previously allocated device memory
// represented by the given ShapedBuffer using the given executor. The shape
// of the ShapedBuffer and DeviceShape(literal.shape()) must be compatible,
@@ -96,9 +106,13 @@ class TransferManager {
//
// This operation is performed asynchronously on the given stream. It returns
// once the transfer is enqueued.
+ //
+ // The optional hint can allow implementations to optimize transfers. It is
+ // not mandatory for an implementation to obey the hint.
virtual Status TransferLiteralToDeviceAsync(
se::Stream* stream, const LiteralSlice& literal,
- const ShapedBuffer& device_buffer) = 0;
+ const ShapedBuffer& device_buffer,
+ TransferToDeviceHint hint = kNoHint) = 0;
// Convenience methods for transferring an array to or from the device at a
// known address. This avoids having to construct a ShapedBuffer just to
diff --git a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
index a9e8a51e09..64ad1dc80e 100644
--- a/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
+++ b/tensorflow/compiler/xla/service/tuple_points_to_analysis.h
@@ -36,8 +36,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/compactptrset.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
index 56145822be..067cfcc17d 100644
--- a/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
+++ b/tensorflow/compiler/xla/service/while_loop_constant_sinking.cc
@@ -18,7 +18,6 @@ limitations under the License.
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
diff --git a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
index e8fe33e626..9795b2830b 100644
--- a/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
+++ b/tensorflow/compiler/xla/service/while_loop_invariant_code_motion.cc
@@ -15,18 +15,18 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/while_loop_invariant_code_motion.h"
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/container/inlined_vector.h"
#include "tensorflow/compiler/xla/service/tuple_util.h"
#include "tensorflow/compiler/xla/service/while_util.h"
#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
namespace xla {
+using absl::flat_hash_map;
+using absl::flat_hash_set;
using absl::InlinedVector;
-using tensorflow::gtl::FlatMap;
-using tensorflow::gtl::FlatSet;
// Copies `to_hoist` to the computation containing `while_instr`, hoisting its
// operands as needed. All of its transitive operands are expected to be either
@@ -34,8 +34,8 @@ using tensorflow::gtl::FlatSet;
// function hoists the operands in `unhoisted_invariant_instructions` and moves
// them into `hoisted_instructions`.
static void CreateLoopInvariantCopy(
- FlatMap<HloInstruction*, HloInstruction*>* hoisted_instructions,
- FlatSet<HloInstruction*>* unhoisted_invariant_instructions,
+ flat_hash_map<HloInstruction*, HloInstruction*>* hoisted_instructions,
+ flat_hash_set<HloInstruction*>* unhoisted_invariant_instructions,
HloInstruction* while_instr, HloInstruction* to_hoist) {
HloComputation* parent_of_while = while_instr->parent();
HloComputation* while_body = while_instr->while_body();
@@ -147,13 +147,13 @@ WhileLoopInvariantCodeMotion::TryHoistingInvariantInstructionsFromWhileBody(
// Maps instructions in the while body to instructions hoisted outside the
// while that compute the same value.
- FlatMap<HloInstruction*, HloInstruction*> hoisted_instructions;
+ flat_hash_map<HloInstruction*, HloInstruction*> hoisted_instructions;
// Contains instructions that can be legally hoisted, but were deemed to be
// unprofitable to be hoisted alone by NotWorthHoistingIndividually. When we
// hoist an instruction in this set, we move it from
// unhoisted_invariant_instructions to hoisted_instructions.
- FlatSet<HloInstruction*> unhoisted_invariant_instructions;
+ flat_hash_set<HloInstruction*> unhoisted_invariant_instructions;
// Invariant GTE's axiomatically satisfy the constraints for
// unhoisted_invariant_instructions -- they can be legally hoisted, but there
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index 9a74f22395..630d71e5ca 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -14,12 +14,13 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/while_loop_simplifier.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/call_inliner.h"
#include "tensorflow/compiler/xla/service/while_loop_analysis.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
namespace xla {
@@ -114,7 +115,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
return false;
}
- tensorflow::gtl::FlatSet<int64> used_tuple_indices;
+ absl::flat_hash_set<int64> used_tuple_indices;
for (HloComputation* comp : {while_body, while_cond}) {
// The HLO verifier ensures that while_input's shape matches while_init's
// shape, which we verified above is a tuple.
@@ -181,7 +182,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
used_tuple_indices.end());
std::sort(new_to_old_tuple_idx.begin(), new_to_old_tuple_idx.end());
- tensorflow::gtl::FlatMap<int64, int64> old_to_new_tuple_idx;
+ absl::flat_hash_map<int64, int64> old_to_new_tuple_idx;
for (int64 new_idx = 0; new_idx < new_to_old_tuple_idx.size(); ++new_idx) {
int64 old_idx = new_to_old_tuple_idx[new_idx];
old_to_new_tuple_idx[old_idx] = new_idx;
@@ -405,7 +406,7 @@ static StatusOr<bool> TryPropagateConstant(HloInstruction* while_op) {
// build a map from the tuple element index to the constant value. Limit this
// to scalar constant values because propagating array constants can regress
// performance by forcing us to copy constants.
- tensorflow::gtl::FlatMap<int, const HloInstruction*> index_to_constant;
+ absl::flat_hash_map<int, const HloInstruction*> index_to_constant;
for (int i = 0; i < root_operands.size(); i++) {
HloInstruction* instr = root_operands[i];
if (instr->opcode() == HloOpcode::kGetTupleElement &&
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 020c167ee9..476a9fe868 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -831,7 +831,8 @@ StatusOr<Shape> ParseShapeStringInternal(absl::string_view* s) {
/* static */ Status ShapeUtil::ValidateShapeWithOptionalLayoutInternal(
const Shape& shape) {
- if (shape.element_type() == PRIMITIVE_TYPE_INVALID) {
+ if (shape.element_type() == PRIMITIVE_TYPE_INVALID ||
+ !PrimitiveType_IsValid(shape.element_type())) {
return InvalidArgument("shape has invalid element type: %s",
shape.ShortDebugString());
}
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index f474ecb18c..8a0ae33042 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -422,6 +422,7 @@ xla_test(
"//tensorflow/core:regexp_internal",
"//tensorflow/core:test",
"@com_google_absl//absl/algorithm:container",
+ "@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/strings",
],
)
@@ -2145,11 +2146,11 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:xla_builder",
- "//tensorflow/compiler/xla/client:xla_computation",
"//tensorflow/compiler/xla/service:hlo_parser",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
"//tensorflow/core:test",
+ "@com_google_absl//absl/container:flat_hash_set",
],
)
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index 0171f51583..6c0847a875 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -394,6 +394,10 @@ class ParametricDotTestWithoutLayoutAssignment : public ParametricDotTest {
ParametricDotTestWithoutLayoutAssignment() {
execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
"layout-assignment");
+ // Disable algebraic simplification because the pass may replace a dot
+ // instruction with a layout-changing multiplication instruction.
+ execution_options_.mutable_debug_options()->add_xla_disable_hlo_passes(
+ "algsimp");
}
};
@@ -404,31 +408,18 @@ std::vector<DotTestParam> CreateNoLayoutAssignmentDotTestParameters() {
for (bool lhs_row_major : {true, false}) {
for (bool rhs_row_major : {true, false}) {
for (bool has_addend : {true, false}) {
+ // The addend needs to be row major to match the result of the dot.
params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
/*dot_lhs_row_major=*/lhs_row_major,
/*dot_rhs_row_major=*/rhs_row_major,
/*has_addend=*/has_addend,
/*addend_row_major=*/true});
- if (has_addend) {
- params.push_back({/*m=*/1, /*k=*/k, /*n=*/n,
- /*dot_lhs_row_major=*/lhs_row_major,
- /*dot_rhs_row_major=*/rhs_row_major,
- /*has_addend=*/has_addend,
- /*addend_row_major=*/false});
- }
if (n != 1) {
params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
/*dot_lhs_row_major=*/lhs_row_major,
/*dot_rhs_row_major=*/rhs_row_major,
/*has_addend=*/has_addend,
/*addend_row_major=*/true});
- if (has_addend) {
- params.push_back({/*m=*/n, /*k=*/k, /*n=*/1,
- /*dot_lhs_row_major=*/lhs_row_major,
- /*dot_rhs_row_major=*/rhs_row_major,
- /*has_addend=*/has_addend,
- /*addend_row_major=*/false});
- }
}
}
}
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index 181e5cbe29..bc433eac8f 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
+#include "absl/container/flat_hash_set.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -145,7 +146,7 @@ ENTRY %sort.148.1589 (parameter.0: f32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<uint32> key_set;
+ absl::flat_hash_set<uint32> key_set;
for (const float& value : key_arg.data<float>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
@@ -168,7 +169,7 @@ ENTRY %sort.148.1589 (parameter.0: s32[1048576], parameter.1: s32[1048576]) -> (
ASSERT_EQ(args.size(), 2);
const Literal& key_arg = args[0];
- tensorflow::gtl::FlatSet<int32> key_set;
+ absl::flat_hash_set<int32> key_set;
for (const int32& value : key_arg.data<int32>()) {
EXPECT_TRUE(key_set.insert(tensorflow::bit_cast<uint32>(value)).second);
}
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index db5a824de0..a6e70eb6ca 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include <vector>
#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_split.h"
@@ -32,7 +33,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/compiler/xla/tests/test_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
-#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -83,7 +83,7 @@ struct ParsedProfileOutputLine {
Status ParseOneProfileOutputLine(
const string& line, bool expect_hlo,
- gtl::FlatMap<string, ParsedProfileOutputLine>* parsed_results,
+ absl::flat_hash_map<string, ParsedProfileOutputLine>* parsed_results,
absl::Span<const absl::string_view> opcodes_to_ignore = {}) {
string separator = "[^:]*:: +";
string match_percentage = R"(\d+\.\d*% +\d+Σ)";
@@ -208,7 +208,7 @@ XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
std::vector<string> profile_output_lines =
absl::StrSplit(profile_output, '\n');
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
TF_ASSERT_OK(ParseOneProfileOutputLine(
profile_output_lines[1], /*expect_hlo=*/false, &parsed_profile_lines));
@@ -314,7 +314,7 @@ XLA_TEST_F(HloProfileTest, ProfileWhileComputation) {
ASSERT_NE(while_body_profile_end, profile_output_lines.end());
- gtl::FlatMap<string, ParsedProfileOutputLine> parsed_profile_lines;
+ absl::flat_hash_map<string, ParsedProfileOutputLine> parsed_profile_lines;
for (auto while_body_profile_i = while_body_profile_start + 1;
while_body_profile_i != while_body_profile_end; while_body_profile_i++) {
diff --git a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
index fda4c31298..40ec1b0ba9 100644
--- a/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
+++ b/tensorflow/compiler/xrt/ops/xrt_execute_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
namespace tensorflow {
REGISTER_OP("XRTExecute")
- .Attr("Ninputs: int")
+ .Attr("Ninputs: int >= 0")
.Input("computation_handle: int64")
.Input("execution_config: string")
.Input("input_handles: Ninputs * int64")
diff --git a/tensorflow/compiler/xrt/tests/raw_api_test.cc b/tensorflow/compiler/xrt/tests/raw_api_test.cc
index 2952feb16a..f590fbf0d9 100644
--- a/tensorflow/compiler/xrt/tests/raw_api_test.cc
+++ b/tensorflow/compiler/xrt/tests/raw_api_test.cc
@@ -108,6 +108,14 @@ bool CompareLiteralToLiteralProto(const xla::Literal& a,
return equal;
}
+xla::XlaComputation OnePlusTwo() {
+ xla::XlaBuilder builder("OnePlusTwo");
+ auto c0 = xla::ConstantR0(&builder, 1.0f);
+ auto c1 = xla::ConstantR0(&builder, 2.0f);
+ xla::Add(c0, c1);
+ return builder.Build().ValueOrDie();
+}
+
xla::XlaComputation AddAndScale() {
xla::XlaBuilder builder("AddAndScale");
auto p0 = xla::Parameter(&builder, 0,
@@ -346,6 +354,39 @@ TEST(RawApiTest, CompileAndExecute) {
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
+TEST(RawApiTest, CompileAndExecuteZeroArg) {
+ xrt::XLAComputation c;
+ auto config = c.mutable_config();
+ auto shapes = config->mutable_program_shape();
+ *shapes->mutable_result() = xla::ShapeUtil::MakeShape(xla::F32, {});
+
+ xrt::XRTExecutionConfig e;
+ e.set_release_input_handles(true);
+ e.set_release_compilation_handle(true);
+ StoreComputationSnapshot(OnePlusTwo(), c.mutable_hlo_snapshot());
+
+ Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
+ auto e_config =
+ ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
+ auto computation =
+ ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
+ auto c_handle = ops::XRTCompile(root, computation);
+ auto result = ops::XRTExecute(root, c_handle, e_config,
+ std::initializer_list<Input>({}));
+ auto read_back = ops::XRTReadLiteralAndRelease(root, result);
+ TF_ASSERT_OK(root.status());
+
+ ClientSession session(root);
+ std::vector<Tensor> outputs;
+ TF_EXPECT_OK(session.Run({read_back}, &outputs));
+
+ xla::LiteralProto response;
+ EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
+
+ auto expected = xla::LiteralUtil::CreateR0<float>(3.0f);
+ EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
+}
+
TEST(RawApiTest, CompileAndExecuteReturnTuple) {
xrt::XLAAllocation p0;
p0.set_device_ordinal(0);
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index 1a9ae8ac3a..fbe0573d5d 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -123,6 +123,11 @@ py_library(
"//tensorflow/contrib/tensorrt:init_py",
"//tensorflow/contrib/ffmpeg:ffmpeg_ops_py",
],
+ }) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite",
+ ],
+ "//conditions:default": [],
}),
)
@@ -132,7 +137,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_kernels",
"//tensorflow/contrib/coder:all_kernels",
- "//tensorflow/contrib/data/kernels:dataset_kernels",
"//tensorflow/contrib/factorization/kernels:all_kernels",
"//tensorflow/contrib/hadoop:dataset_kernels",
"//tensorflow/contrib/input_pipeline:input_pipeline_ops_kernels",
@@ -163,8 +167,6 @@ cc_library(
deps = [
"//tensorflow/contrib/boosted_trees:boosted_trees_ops_op_lib",
"//tensorflow/contrib/coder:all_ops",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
"//tensorflow/contrib/factorization:all_ops",
"//tensorflow/contrib/framework:all_ops",
"//tensorflow/contrib/hadoop:dataset_ops_op_lib",
@@ -187,5 +189,10 @@ cc_library(
"//tensorflow/contrib/kinesis:dataset_ops_op_lib",
"//tensorflow/contrib/tensorrt:trt_engine_op_op_lib",
],
+ }) + select({
+ "//tensorflow:with_ignite_support": [
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ "//conditions:default": [],
}),
)
diff --git a/tensorflow/contrib/batching/BUILD b/tensorflow/contrib/batching/BUILD
index b27a19b16c..648f3ebb05 100644
--- a/tensorflow/contrib/batching/BUILD
+++ b/tensorflow/contrib/batching/BUILD
@@ -7,64 +7,6 @@ package(
licenses(["notice"]) # Apache 2.0
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "batch_scheduler_hdrs",
- hdrs = ["batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:batch_scheduler_hdrs",
- ],
-)
-
-cc_library(
- name = "batch_scheduler",
- hdrs = ["batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:batch_scheduler",
- ],
-)
-
-cc_library(
- name = "shared_batch_scheduler_hdrs",
- hdrs = ["shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:shared_batch_scheduler_hdrs",
- ],
-)
-
-cc_library(
- name = "shared_batch_scheduler",
- hdrs = ["shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:shared_batch_scheduler",
- ],
- alwayslink = 1,
-)
-
-cc_library(
- name = "adaptive_shared_batch_scheduler",
- hdrs = ["adaptive_shared_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:adaptive_shared_batch_scheduler",
- ],
-)
-
-cc_library(
- name = "serial_device_batch_scheduler",
- hdrs = ["serial_device_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:serial_device_batch_scheduler",
- ],
-)
-
-cc_library(
- name = "basic_batch_scheduler",
- hdrs = ["basic_batch_scheduler.h"],
- deps = [
- "//tensorflow/core/kernels/batching_util:basic_batch_scheduler",
- ],
-)
-
load(
"//tensorflow:tensorflow.bzl",
"py_test",
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
deleted file mode 100644
index 86250e6692..0000000000
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_ADAPTIVE_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
deleted file mode 100644
index d9b37da693..0000000000
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/basic_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_BASIC_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
deleted file mode 100644
index 8e94e1fd8b..0000000000
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/serial_device_batch_scheduler.h b/tensorflow/contrib/batching/serial_device_batch_scheduler.h
deleted file mode 100644
index bf6b708361..0000000000
--- a/tensorflow/contrib/batching/serial_device_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/serial_device_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_SERIAL_DEVICE_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
deleted file mode 100644
index 83a59695d7..0000000000
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-#define TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
-
-#include "tensorflow/core/kernels/batching_util/shared_batch_scheduler.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_SHARED_BATCH_SCHEDULER_H_
diff --git a/tensorflow/contrib/batching/test_util/BUILD b/tensorflow/contrib/batching/test_util/BUILD
deleted file mode 100644
index 7cb2d8079b..0000000000
--- a/tensorflow/contrib/batching/test_util/BUILD
+++ /dev/null
@@ -1,19 +0,0 @@
-# Description: Utilities to aid testing.
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-exports_files(["LICENSE"])
-
-cc_library(
- name = "fake_clock_env",
- testonly = 1,
- hdrs = ["fake_clock_env.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core/kernels/batching_util:fake_clock_env",
- ],
-)
diff --git a/tensorflow/contrib/batching/test_util/fake_clock_env.h b/tensorflow/contrib/batching/test_util/fake_clock_env.h
deleted file mode 100644
index 40a39a5569..0000000000
--- a/tensorflow/contrib/batching/test_util/fake_clock_env.h
+++ /dev/null
@@ -1,21 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-#define TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
-
-#include "tensorflow/core/kernels/batching_util/fake_clock_env.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_TEST_UTIL_FAKE_CLOCK_ENV_H_
diff --git a/tensorflow/contrib/batching/util/BUILD b/tensorflow/contrib/batching/util/BUILD
deleted file mode 100644
index 8f81b6702f..0000000000
--- a/tensorflow/contrib/batching/util/BUILD
+++ /dev/null
@@ -1,28 +0,0 @@
-# Description: Utilities.
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-licenses(["notice"]) # Apache 2.0
-
-load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-
-cc_library(
- name = "periodic_function_dynamic",
- hdrs = ["periodic_function.h"],
- visibility = ["//visibility:public"],
- deps = [
- "//tensorflow/core/kernels/batching_util:periodic_function_dynamic",
- "//third_party/eigen3",
- ],
-)
-
-cc_library(
- name = "periodic_function",
- visibility = ["//visibility:public"],
- deps = [
- ":periodic_function_dynamic",
- "//tensorflow/core/kernels/batching_util:periodic_function",
- ],
-)
diff --git a/tensorflow/contrib/batching/util/periodic_function.h b/tensorflow/contrib/batching/util/periodic_function.h
deleted file mode 100644
index aa2ed0a385..0000000000
--- a/tensorflow/contrib/batching/util/periodic_function.h
+++ /dev/null
@@ -1,20 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-#define TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
-
-#include "tensorflow/core/kernels/batching_util/periodic_function.h"
-
-#endif // TENSORFLOW_CONTRIB_BATCHING_UTIL_PERIODIC_FUNCTION_H_
diff --git a/tensorflow/contrib/bigtable/README.md b/tensorflow/contrib/bigtable/README.md
index f33eaf7e3d..2c44abed5e 100644
--- a/tensorflow/contrib/bigtable/README.md
+++ b/tensorflow/contrib/bigtable/README.md
@@ -203,7 +203,7 @@ def interleave_fn(index):
start = tf.string_join(['training_data_', start_idx_str])
end = tf.string_join(['training_data_', end_idx_str])
return table.scan_range(start_idx, end_idx, columns=columns)
-ds = ds.apply(tf.contrib.data.parallel_interleave(
+ds = ds.apply(tf.data.experimental.parallel_interleave(
interleave_fn, cycle_length=NUM_PARALLEL_READS, prefetch_input_elements=1))
```
@@ -249,7 +249,7 @@ def make_row_key_dataset():
- ...
- fake-data-23498103
"""
- counter_dataset = tf.contrib.data.Counter()
+ counter_dataset = tf.data.experimental.Counter()
width = 8
row_key_prefix = 'fake-data-'
ds = counter_dataset.map(lambda index: tf.as_string(index,
diff --git a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
index cf56822ff4..7c87b0daeb 100644
--- a/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
+++ b/tensorflow/contrib/bigtable/python/ops/bigtable_api.py
@@ -31,8 +31,8 @@ from six import iteritems
from six import string_types
from tensorflow.contrib.bigtable.ops import gen_bigtable_ops
-from tensorflow.contrib.data.python.ops import interleave_ops
from tensorflow.contrib.util import loader
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -228,7 +228,7 @@ class BigtableTable(object):
"""Retrieves a sampling of row keys from the Bigtable table.
This dataset is most often used in conjunction with
- `tf.contrib.data.parallel_interleave` to construct a set of ranges for
+ `tf.data.experimental.parallel_interleave` to construct a set of ranges for
scanning in parallel.
Returns:
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
index 6b6fe9663a..839eedd3a8 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/dnn_tree_combined_estimator_test.py
@@ -188,9 +188,8 @@ class CoreDNNBoostedTreeCombinedTest(test_util.TensorFlowTestCase):
# Train for a few steps.
est.train(input_fn=_train_input_fn, steps=1000)
- # 10 steps for dnn + 3 for 1 tree of depth 3 + 1 after the tree finished
- # + 1 for resource variables.
- self._assert_checkpoint(est.model_dir, global_step=15)
+ # 10 steps for dnn, 3 for 1 tree of depth 3 + 1 after the tree finished
+ self._assert_checkpoint(est.model_dir, global_step=14)
res = est.evaluate(input_fn=_eval_input_fn, steps=1)
self.assertLess(0.5, res["auc"])
est.predict(input_fn=_eval_input_fn)
diff --git a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
index d7b14e00ba..c155128c0e 100644
--- a/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
+++ b/tensorflow/contrib/boosted_trees/estimator_batch/estimator_test.py
@@ -238,8 +238,8 @@ class BoostedTreeEstimatorTest(test_util.TensorFlowTestCase):
output_leaf_index=False)
classifier.fit(input_fn=_train_input_fn, steps=15)
- # When no override of global steps, 6 steps were used.
- self._assert_checkpoint(classifier.model_dir, global_step=6)
+ # When no override of global steps, 5 steps were used.
+ self._assert_checkpoint(classifier.model_dir, global_step=5)
def testOverridesGlobalSteps(self):
learner_config = learner_pb2.LearnerConfig()
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index c7eb2493a8..8531e97f90 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -402,13 +402,13 @@ class GradientBoostedDecisionTreeModel(object):
self._feature_columns = feature_columns
self._learner_config_serialized = learner_config.SerializeToString()
self._num_quantiles = num_quantiles
- self._max_tree_depth = variables.Variable(
+ self._max_tree_depth = variables.VariableV1(
initial_value=self._learner_config.constraints.max_tree_depth)
- self._attempted_trees = variables.Variable(
+ self._attempted_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="attempted_trees")
- self._finalized_trees = variables.Variable(
+ self._finalized_trees = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
trainable=False,
name="finalized_trees")
@@ -770,28 +770,28 @@ class GradientBoostedDecisionTreeModel(object):
fc_name_idx += 1
# Create ensemble stats variables.
- num_layer_examples = variables.Variable(
+ num_layer_examples = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_examples",
trainable=False)
- num_layer_steps = variables.Variable(
+ num_layer_steps = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layer_steps",
trainable=False)
- num_layers = variables.Variable(
+ num_layers = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="num_layers",
trainable=False)
- active_tree = variables.Variable(
+ active_tree = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_tree",
trainable=False)
- active_layer = variables.Variable(
+ active_layer = variables.VariableV1(
initial_value=array_ops.zeros([], dtypes.int64),
name="active_layer",
trainable=False)
# Variable that becomes false once bias centering is done.
- continue_centering = variables.Variable(
+ continue_centering = variables.VariableV1(
initial_value=self._center_bias,
name="continue_centering",
trainable=False)
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
index 9d9941f696..6d20a2e7f4 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch_test.py
@@ -239,7 +239,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -503,7 +503,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -607,7 +607,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -711,7 +711,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -783,7 +783,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -847,7 +847,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1090,7 +1090,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1194,7 +1194,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1299,7 +1299,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
weights = array_ops.ones([batch_size, 1], dtypes.float32)
partition_ids = array_ops.zeros([batch_size], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1405,7 +1405,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1524,7 +1524,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
@@ -1656,7 +1656,7 @@ class GbdtTest(test_util.TensorFlowTestCase):
predictions = array_ops.constant(
[[0.0], [1.0], [0.0], [2.0]], dtype=dtypes.float32)
partition_ids = array_ops.zeros([4], dtypes.int32)
- ensemble_stamp = variables.Variable(
+ ensemble_stamp = variables.VariableV1(
initial_value=0,
name="ensemble_stamp",
trainable=False,
diff --git a/tensorflow/contrib/cmake/CMakeLists.txt b/tensorflow/contrib/cmake/CMakeLists.txt
index c6d6f04168..f675c135f4 100644
--- a/tensorflow/contrib/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/CMakeLists.txt
@@ -30,7 +30,6 @@ endif()
option(tensorflow_ENABLE_GRPC_SUPPORT "Enable gRPC support" ON)
option(tensorflow_ENABLE_HDFS_SUPPORT "Enable HDFS support" OFF)
-option(tensorflow_ENABLE_JEMALLOC_SUPPORT "Enable jemalloc support" OFF)
option(tensorflow_BUILD_CC_EXAMPLE "Build the C++ tutorial example" ON)
option(tensorflow_BUILD_PYTHON_BINDINGS "Build the Python bindings" ON)
option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
@@ -218,10 +217,6 @@ if (tensorflow_WIN_CPU_SIMD_OPTIONS)
endif()
endif()
-if (tensorflow_ENABLE_JEMALLOC_SUPPORT)
- add_definitions(-DTENSORFLOW_USE_JEMALLOC -DJEMALLOC_EXPORT=)
-endif()
-
# External dependencies
include(zlib)
include(gif)
@@ -329,12 +324,6 @@ if(tensorflow_ENABLE_GRPC_SUPPORT)
list(APPEND tensorflow_EXTERNAL_DEPENDENCIES boringssl)
endif()
endif()
-if(tensorflow_ENABLE_JEMALLOC_SUPPORT)
- include(jemalloc)
- list(APPEND tensorflow_EXTERNAL_LIBRARIES ${jemalloc_STATIC_LIBRARIES})
- list(APPEND tensorflow_EXTERNAL_DEPENDENCIES jemalloc)
- include_directories(${jemalloc_INCLUDE_DIRS})
-endif()
if(tensorflow_ENABLE_SNAPPY_SUPPORT)
include(snappy)
list(APPEND tensorflow_EXTERNAL_LIBRARIES ${snappy_STATIC_LIBRARIES})
diff --git a/tensorflow/contrib/cmake/external/jemalloc.cmake b/tensorflow/contrib/cmake/external/jemalloc.cmake
deleted file mode 100644
index afadcc007d..0000000000
--- a/tensorflow/contrib/cmake/external/jemalloc.cmake
+++ /dev/null
@@ -1,50 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-include (ExternalProject)
-
-set(jemalloc_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include)
-set(jemalloc_URL https://mirror.bazel.build/github.com/jemalloc/jemalloc-cmake/archive/jemalloc-cmake.4.3.1.tar.gz)
-set(jemalloc_HASH SHA256=f9be9a05fe906deb5c1c8ca818071a7d2e27d66fd87f5ba9a7bf3750bcedeaf0)
-set(jemalloc_BUILD ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc)
-
-if (WIN32)
- set(jemalloc_INCLUDE_DIRS
- ${jemalloc_INCLUDE_DIRS}
- ${CMAKE_CURRENT_BINARY_DIR}/jemalloc/src/jemalloc/include/msvc_compat
- )
- if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.lib)
- else()
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/jemalloc.lib)
- endif()
-else()
- set(jemalloc_STATIC_LIBRARIES ${jemalloc_BUILD}/Release/jemalloc.a)
-endif()
-
-ExternalProject_Add(jemalloc
- PREFIX jemalloc
- URL ${jemalloc_URL}
- URL_HASH ${jemalloc_HASH}
- DOWNLOAD_DIR "${DOWNLOAD_LOCATION}"
- BUILD_IN_SOURCE 1
- BUILD_BYPRODUCTS ${jemalloc_STATIC_LIBRARIES}
- BUILD_COMMAND ${CMAKE_COMMAND} --build . --config Release --target jemalloc
- INSTALL_COMMAND ${CMAKE_COMMAND} -E echo "Skipping install step."
- CMAKE_CACHE_ARGS
- -DCMAKE_BUILD_TYPE:STRING=Release
- -DCMAKE_VERBOSE_MAKEFILE:BOOL=OFF
- -Dwith-jemalloc-prefix:STRING=jemalloc_
- -Dwithout-export:BOOL=ON
-)
diff --git a/tensorflow/contrib/cmake/external/protobuf.cmake b/tensorflow/contrib/cmake/external/protobuf.cmake
index f56fb35a0f..56a57a2340 100644
--- a/tensorflow/contrib/cmake/external/protobuf.cmake
+++ b/tensorflow/contrib/cmake/external/protobuf.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(PROTOBUF_INCLUDE_DIRS ${CMAKE_CURRENT_BINARY_DIR}/protobuf/src/protobuf/src)
set(PROTOBUF_URL https://github.com/google/protobuf.git)
-set(PROTOBUF_TAG v3.6.0)
+set(PROTOBUF_TAG v3.6.1)
if(WIN32)
if(${CMAKE_GENERATOR} MATCHES "Visual Studio.*")
diff --git a/tensorflow/contrib/cmake/make.bat b/tensorflow/contrib/cmake/make.bat
new file mode 100644
index 0000000000..d52b24e01d
--- /dev/null
+++ b/tensorflow/contrib/cmake/make.bat
@@ -0,0 +1,38 @@
+%echo off
+
+cd /d %~dp0
+
+if exist _build rd /s /q _build
+
+mkdir _build
+chdir _build
+
+
+rem cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install
+
+CALL :NORMALIZEPATH "..\..\..\.."
+SET SOURCE_DIR=%RETVAL%
+
+echo %SOURCE_DIR%
+
+SET SOURCE_DIR=F:\frameworks\tensorflow\
+
+CALL :NORMALIZEPATH "../../../tools/git/gen_git_source.py"
+SET SOURCE_PYTHON_SCRIPT=%RETVAL%
+
+CALL :NORMALIZEPATH "../../../core/util/version_info.cc"
+SET SOURCE_VERSION_CC=%RETVAL%
+
+python %SOURCE_PYTHON_SCRIPT% --raw_generate %SOURCE_VERSION_CC% --source_dir %SOURCE_DIR% --git_tag_override=
+
+cmake ../ -G "Visual Studio 15 Win64" -DCMAKE_GENERATOR_TOOLSET=v141,host=x64 -DCMAKE_INSTALL_PREFIX:PATH=.\install
+
+EXIT /B
+
+:NORMALIZEPATH
+ SET RETVAL=%~dpfn1
+ EXIT /B
+
+
+
+ \ No newline at end of file
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index c0763f4c0e..6e72670142 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -132,10 +132,8 @@ tensorflow/contrib/cudnn_rnn/python
tensorflow/contrib/cudnn_rnn/python/layers
tensorflow/contrib/cudnn_rnn/python/ops
tensorflow/contrib/data
-tensorflow/contrib/data/kernels
tensorflow/contrib/data/python
tensorflow/contrib/data/python/kernel_tests
-tensorflow/contrib/data/python/kernel_tests/serialization
tensorflow/contrib/data/python/ops
tensorflow/contrib/decision_trees
tensorflow/contrib/decision_trees/proto
@@ -207,6 +205,8 @@ tensorflow/contrib/integrate/python
tensorflow/contrib/integrate/python/ops
tensorflow/contrib/kafka/python
tensorflow/contrib/kafka/python/ops
+tensorflow/contrib/ignite/python
+tensorflow/contrib/ignite/python/ops
tensorflow/contrib/keras
tensorflow/contrib/keras/api
tensorflow/contrib/keras/api/keras
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index 067c299a71..7e806685b8 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -258,14 +258,21 @@ add_dependencies(tf_core_lib ${tensorflow_EXTERNAL_DEPENDENCIES} tf_protos_cc)
# force_rebuild always runs forcing ${VERSION_INFO_CC} target to run
# ${VERSION_INFO_CC} would cache, but it depends on a phony never produced
# target.
-set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
-add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC})
-add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo)
-add_custom_command(OUTPUT
- ${VERSION_INFO_CC}
- COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py
- ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE}
- DEPENDS __force_rebuild)
+# This code forces rebuild every time, not needed as version from git is fetched only once
+# move to make.bat which mimicks make.sh
+
+if (NOT WIN32)
+
+ set(VERSION_INFO_CC ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
+ add_custom_target(force_rebuild_target ALL DEPENDS ${VERSION_INFO_CC})
+ add_custom_command(OUTPUT __force_rebuild COMMAND ${CMAKE_COMMAND} -E echo)
+ add_custom_command(OUTPUT
+ ${VERSION_INFO_CC}
+ COMMAND ${PYTHON_EXECUTABLE} ${tensorflow_source_dir}/tensorflow/tools/git/gen_git_source.py
+ ARGS --raw_generate ${VERSION_INFO_CC} --source_dir ${tensorflow_source_dir} --git_tag_override=${GIT_TAG_OVERRIDE}
+ DEPENDS __force_rebuild)
+endif()
+
set(tf_version_srcs ${tensorflow_source_dir}/tensorflow/core/util/version_info.cc)
########################################################
diff --git a/tensorflow/contrib/crf/python/ops/crf.py b/tensorflow/contrib/crf/python/ops/crf.py
index 2a91dcb63a..43bb43129b 100644
--- a/tensorflow/contrib/crf/python/ops/crf.py
+++ b/tensorflow/contrib/crf/python/ops/crf.py
@@ -56,7 +56,6 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.layers import utils
from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
@@ -214,10 +213,11 @@ def crf_log_norm(inputs, sequence_lengths, transition_params):
log_norm)
return log_norm
- max_seq_len = array_ops.shape(inputs)[1]
- return control_flow_ops.cond(pred=math_ops.equal(max_seq_len, 1),
- true_fn=_single_seq_fn,
- false_fn=_multi_seq_fn)
+ return utils.smart_cond(
+ pred=math_ops.equal(inputs.shape[1].value or
+ array_ops.shape(inputs)[1], 1),
+ true_fn=_single_seq_fn,
+ false_fn=_multi_seq_fn)
def crf_log_likelihood(inputs,
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index 9f710613dd..38f1c65a4d 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -4,17 +4,6 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load(
- "//tensorflow:tensorflow.bzl",
- "tf_custom_op_library",
- "tf_gen_op_libs",
- "if_not_windows",
-)
-load(
- "//tensorflow/core:platform/default/build_config_root.bzl",
- "if_static",
-)
-
py_library(
name = "data",
srcs = ["__init__.py"],
@@ -25,30 +14,3 @@ py_library(
"//tensorflow/python:util",
],
)
-
-cc_library(
- name = "lib_proto_parsing_for_dataset_ops",
- deps = if_not_windows(["//tensorflow/core:lib_proto_parsing"]),
-)
-
-tf_custom_op_library(
- name = "_dataset_ops.so",
- srcs = [
- "ops/dataset_ops.cc",
- "ops/indexed_dataset_ops.cc",
- ],
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/contrib/data/kernels:indexed_dataset",
- ] + if_static(
- extra_deps = [":lib_proto_parsing_for_dataset_ops"],
- otherwise = [],
- ),
-)
-
-tf_gen_op_libs(
- op_lib_names = [
- "dataset_ops",
- "indexed_dataset_ops",
- ],
-)
diff --git a/tensorflow/contrib/data/README.md b/tensorflow/contrib/data/README.md
index 848782e8d8..90be7a66ca 100644
--- a/tensorflow/contrib/data/README.md
+++ b/tensorflow/contrib/data/README.md
@@ -1,10 +1,12 @@
`tf.contrib.data` API
=====================
-NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead.
-We are continuing to support existing code using the `tf.contrib.data` APIs in
-the current version of TensorFlow, but will eventually remove support. The
-`tf.data` APIs are subject to backwards compatibility guarantees.
+NOTE: The `tf.contrib.data` module has been deprecated. Use `tf.data` instead,
+or `tf.data.experimental` for the experimental transformations previously hosted
+in this module. We are continuing to support existing code using the
+`tf.contrib.data` APIs in the current version of TensorFlow, but will eventually
+remove support. The non-experimental `tf.data` APIs are subject to backwards
+compatibility guarantees.
Porting your code to `tf.data`
------------------------------
@@ -25,13 +27,13 @@ instead apply them using `Dataset.apply()` transformation. The full list of
changes is as follows:
* `dataset.dense_to_sparse_batch(...)` is now
- `dataset.apply(tf.contrib.data.dense_to_sparse_batch(...)`.
+ `dataset.apply(tf.data.experimental.dense_to_sparse_batch(...)`.
* `dataset.enumerate(...)` is now
- `dataset.apply(tf.contrib.data.enumerate_dataset(...))`.
+ `dataset.apply(tf.data.experimental.enumerate_dataset(...))`.
* `dataset.group_by_window(...)` is now
- `dataset.apply(tf.contrib.data.group_by_window(...))`.
+ `dataset.apply(tf.data.experimental.group_by_window(...))`.
* `dataset.ignore_errors()` is now
- `dataset.apply(tf.contrib.data.ignore_errors())`.
+ `dataset.apply(tf.data.experimental.ignore_errors())`.
* `dataset.unbatch()` is now `dataset.apply(tf.contrib.data.unbatch())`.
The `Dataset.make_dataset_resource()` and `Iterator.dispose_op()` methods have
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 3cb51279c3..c3d3e981fa 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -96,10 +96,6 @@ from tensorflow.contrib.data.python.ops.interleave_ops import sample_from_datase
from tensorflow.contrib.data.python.ops.interleave_ops import sloppy_interleave
from tensorflow.contrib.data.python.ops.iterator_ops import CheckpointInputPipelineHook
from tensorflow.contrib.data.python.ops.iterator_ops import make_saveable_from_iterator
-
-# Optimization constant that can be used to enable auto-tuning.
-from tensorflow.contrib.data.python.ops.optimization import AUTOTUNE
-
from tensorflow.contrib.data.python.ops.parsing_ops import parse_example_dataset
from tensorflow.contrib.data.python.ops.prefetching_ops import copy_to_device
from tensorflow.contrib.data.python.ops.prefetching_ops import prefetch_to_device
@@ -114,11 +110,12 @@ from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.contrib.data.python.ops.sliding import sliding_window_batch
-from tensorflow.contrib.data.python.ops.stats_ops import latency_stats
-from tensorflow.contrib.data.python.ops.stats_ops import set_stats_aggregator
-from tensorflow.contrib.data.python.ops.stats_ops import StatsAggregator
from tensorflow.contrib.data.python.ops.unique import unique
from tensorflow.contrib.data.python.ops.writers import TFRecordWriter
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
from tensorflow.python.data.ops.optional_ops import Optional
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc b/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
deleted file mode 100644
index cd9b7c68a0..0000000000
--- a/tensorflow/contrib/data/ops/indexed_dataset_ops.cc
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-#include "tensorflow/core/framework/common_shape_fns.h"
-#include "tensorflow/core/framework/op.h"
-
-namespace tensorflow {
-
-REGISTER_OP("IdentityIndexedDataset")
- .Input("size: uint64")
- .Output("handle: variant")
- .SetIsStateful()
- .SetShapeFn(
- shape_inference::ScalarShape); // TODO(saeta): check input shapes.
-
-///////////////////////////////////////////////////////////////////////////////
-// IndexedDataset Internals
-///////////////////////////////////////////////////////////////////////////////
-
-// Creates the handle.
-REGISTER_OP("MaterializedIndexDatasetHandle")
- .Output("handle: resource")
- .Attr("container: string")
- .Attr("shared_name: string")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape);
-
-// Actually materialize the materialize handle.
-REGISTER_OP("IndexedDatasetMaterialize")
- .Input("dataset: variant")
- .Input("materialized: resource")
- .SetShapeFn(shape_inference::NoOutputs);
-
-namespace {
-
-Status GetShapeFn(shape_inference::InferenceContext* c) {
- shape_inference::ShapeHandle unused;
- TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
- std::vector<PartialTensorShape> output_shapes;
- TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
- if (output_shapes.size() != c->num_outputs()) {
- return errors::InvalidArgument(
- "`output_shapes` must be the same length as `output_types` (",
- output_shapes.size(), " vs. ", c->num_outputs());
- }
- for (size_t i = 0; i < output_shapes.size(); ++i) {
- shape_inference::ShapeHandle output_shape_handle;
- TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
- output_shapes[i], &output_shape_handle));
- c->set_output(static_cast<int>(i), output_shape_handle);
- }
- return Status::OK();
-}
-
-} // namespace
-
-REGISTER_OP("IndexedDatasetGet")
- .Input("materialized: resource")
- .Input("index: uint64")
- .Output("components: output_types")
- .Attr("output_types: list(type) >= 1")
- .Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(GetShapeFn)
- .Doc(R"doc(
-Gets the element at `index` from `materialized` IndexedDataset.
-)doc");
-
-} // namespace tensorflow
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index ce52c990ce..42f538b4ba 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -8,193 +8,26 @@ load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "py_test")
py_test(
- name = "batch_dataset_op_test",
- size = "medium",
- srcs = ["batch_dataset_op_test.py"],
+ name = "assert_element_shape_test",
+ srcs = ["assert_element_shape_test.py"],
srcs_version = "PY2AND3",
- tags = [
- "no_oss", # (b/79552534)
- "no_pip",
- ],
deps = [
"//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
"//tensorflow/python:script_ops",
- "//tensorflow/python:session",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "bucketing_test",
- size = "medium",
- srcs = ["bucketing_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
"//tensorflow/python:string_ops",
"//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
)
py_test(
- name = "csv_dataset_op_test",
- size = "medium",
- srcs = ["csv_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:platform_test",
- "//tensorflow/python:session",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "dataset_constructor_op_test",
- size = "medium",
- srcs = ["dataset_constructor_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "manual",
- "nomac", # b/62040583
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_test(
- name = "directed_interleave_dataset_test",
- size = "medium",
- srcs = ["directed_interleave_dataset_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:random_seed",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "get_single_element_test",
- size = "small",
- srcs = ["get_single_element_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:get_single_element",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "indexed_dataset_ops_test",
- srcs = ["indexed_dataset_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
- "//tensorflow/contrib/data/python/ops:gen_dataset_ops",
- "//tensorflow/contrib/data/python/ops:indexed_dataset_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "interleave_dataset_op_test",
- size = "medium",
- srcs = ["interleave_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_oss",
- "no_pip",
- "notap",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:script_ops",
- "//tensorflow/python:sparse_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "iterator_ops_test",
- size = "small",
- srcs = ["iterator_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python:variables",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/estimator:estimator_py",
- ],
-)
-
-py_test(
name = "lmdb_dataset_op_test",
size = "medium",
srcs = ["lmdb_dataset_op_test.py"],
@@ -215,247 +48,24 @@ py_test(
"//tensorflow/python:platform",
"//tensorflow/python:platform_test",
"//tensorflow/python:session",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//third_party/py/numpy",
],
)
py_test(
- name = "map_dataset_op_test",
- size = "medium",
- srcs = ["map_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "noasan", # times out
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:error_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "filter_dataset_op_test",
- size = "medium",
- srcs = ["filter_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:io_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "map_defun_op_test",
- size = "small",
- srcs = ["map_defun_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:map_defun",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:data_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:functional_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:session",
- ],
-)
-
-py_test(
- name = "parsing_ops_test",
- size = "small",
- srcs = ["parsing_ops_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:parsing_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:platform",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-cuda_py_test(
- name = "prefetching_ops_test",
- size = "small",
- srcs = ["prefetching_ops_test.py"],
- additional_deps = [
- "//tensorflow/contrib/data/python/ops:prefetching_ops",
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:function",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/compat:compat",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- ],
- tags = ["no_windows_gpu"],
-)
-
-py_test(
- name = "range_dataset_op_test",
+ name = "reduce_dataset_test",
size = "small",
- srcs = ["range_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/contrib/data/python/ops:counter",
- "//tensorflow/contrib/data/python/ops:enumerate_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_library(
- name = "reader_dataset_ops_test_base",
- testonly = 1,
- srcs = [
- "reader_dataset_ops_test_base.py",
- ],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
+ srcs = ["reduce_dataset_test.py"],
deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/core:protos_all_py",
+ "//tensorflow/contrib/data/python/ops:get_single_element",
+ "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_test(
- name = "reader_dataset_ops_test",
- size = "medium",
- srcs = ["reader_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "resample_test",
- size = "medium",
- srcs = ["resample_test.py"],
- shard_count = 2,
- srcs_version = "PY2AND3",
- tags = [
- "noasan",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:resampling",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python:string_ops",
- "//tensorflow/python:util",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
- "@six_archive//:six",
- ],
-)
-
-py_test(
- name = "scan_dataset_op_test",
- size = "small",
- srcs = ["scan_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/eager:context",
- "//third_party/py/numpy",
- ],
-)
-
-py_test(
- name = "shuffle_dataset_op_test",
- size = "medium",
- srcs = ["shuffle_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- "optonly",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
],
)
@@ -471,151 +81,9 @@ py_test(
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
],
)
-
-py_library(
- name = "sql_dataset_op_test_base",
- srcs = ["sql_dataset_op_test_base.py"],
- srcs_version = "PY2AND3",
- visibility = [
- "//tensorflow/contrib/data/python/kernel_tests:__pkg__",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:__pkg__",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:readers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "@org_sqlite//:python",
- ],
-)
-
-py_test(
- name = "sql_dataset_op_test",
- size = "small",
- srcs = ["sql_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":sql_dataset_op_test_base",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- ],
-)
-
-py_test(
- name = "stats_dataset_ops_test",
- size = "medium",
- srcs = ["stats_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- ":reader_dataset_ops_test_base",
- ":stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- ],
-)
-
-py_library(
- name = "stats_dataset_test_base",
- srcs = ["stats_dataset_test_base.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/core:protos_all_py",
- "//tensorflow/python:client_testlib",
- ],
-)
-
-py_test(
- name = "threadpool_dataset_ops_test",
- size = "small",
- srcs = ["threadpool_dataset_ops_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:threadpool",
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:script_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "unique_dataset_op_test",
- size = "small",
- srcs = ["unique_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = ["no_pip"],
- deps = [
- "//tensorflow/contrib/data/python/ops:unique",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:errors",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-py_test(
- name = "window_dataset_op_test",
- size = "medium",
- srcs = ["window_dataset_op_test.py"],
- srcs_version = "PY2AND3",
- tags = [
- "no_pip",
- ],
- deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:grouping",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
- "@absl_py//absl/testing:parameterized",
- ],
-)
-
-py_test(
- name = "writer_ops_test",
- size = "small",
- srcs = ["writer_ops_test.py"],
- deps = [
- "//tensorflow/contrib/data/python/ops:writers",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:lib",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:readers",
- ],
-)
-
-py_library(
- name = "test_utils",
- srcs = ["test_utils.py"],
- deps = [
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:errors",
- "//tensorflow/python/data/util:nest",
- ],
-)
diff --git a/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
new file mode 100644
index 0000000000..0456463a19
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/assert_element_shape_test.py
@@ -0,0 +1,226 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import script_ops
+from tensorflow.python.platform import test
+
+
+class AssertElementShapeTest(test_base.DatasetTestBase):
+
+ def test_assert_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(expected_shapes, dataset.output_shapes)
+
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+ def test_assert_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(5).map(create_dataset)
+ partial_expected_shape = (
+ tensor_shape.TensorShape(None), # Unknown shape
+ tensor_shape.TensorShape((None, 4))) # Partial shape
+ result = dataset.apply(
+ batching.assert_element_shape(partial_expected_shape))
+ # Partial shapes are merged with actual shapes:
+ actual_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((3, 4)))
+ self.assertEqual(actual_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape(self):
+
+ def create_dataset(_):
+ return (array_ops.ones(2, dtype=dtypes.float32),
+ array_ops.zeros((3, 4), dtype=dtypes.int32))
+
+ dataset = dataset_ops.Dataset.range(3).map(create_dataset)
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ with self.assertRaises(ValueError):
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+
+ def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ expected_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 4)))
+ result = dataset.apply(batching.assert_element_shape(expected_shapes))
+ self.assertEqual(expected_shapes, result.output_shapes)
+
+ iterator = result.make_initializable_iterator()
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ for _ in range(5):
+ sess.run(get_next)
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next)
+
+ def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
+
+ def create_unknown_shape_dataset(x):
+ return script_ops.py_func(
+ lambda _: ( # pylint: disable=g-long-lambda
+ np.ones(2, dtype=np.float32),
+ np.zeros((3, 4), dtype=np.int32)),
+ [x],
+ [dtypes.float32, dtypes.int32])
+
+ dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
+ unknown_shapes = (tensor_shape.TensorShape(None),
+ tensor_shape.TensorShape(None))
+ self.assertEqual(unknown_shapes, dataset.output_shapes)
+
+ wrong_shapes = (tensor_shape.TensorShape(2),
+ tensor_shape.TensorShape((None, 10)))
+ iterator = (
+ dataset.apply(batching.assert_element_shape(wrong_shapes))
+ .make_initializable_iterator())
+ init_op = iterator.initializer
+ get_next = iterator.get_next()
+ with self.cached_session() as sess:
+ sess.run(init_op)
+ with self.assertRaises(errors.InvalidArgumentError):
+ sess.run(get_next)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
index 1cc5ddc9a2..d2a72272db 100644
--- a/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/lmdb_dataset_op_test.py
@@ -22,6 +22,7 @@ import os
import shutil
from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,7 @@ from tensorflow.python.util import compat
prefix_path = "tensorflow/core/lib"
-class LMDBDatasetTest(test.TestCase):
+class LMDBDatasetTest(test_base.DatasetTestBase):
def setUp(self):
super(LMDBDatasetTest, self).setUp()
diff --git a/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
new file mode 100644
index 0000000000..e7281d5318
--- /dev/null
+++ b/tensorflow/contrib/data/python/kernel_tests/reduce_dataset_test.py
@@ -0,0 +1,62 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for the experimental input pipeline ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from absl.testing import parameterized
+import numpy as np
+
+from tensorflow.contrib.data.python.ops import get_single_element
+from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.platform import test
+
+
+class ReduceDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
+
+ @parameterized.named_parameters(
+ ("SumZero", 0),
+ ("SumOne", 1),
+ ("SumFive", 5),
+ ("SumTen", 10),
+ )
+ def testReduceDataset(self, stop):
+ def init_fn(_):
+ return np.int64(0)
+
+ def reduce_fn(state, value):
+ return state + value
+
+ def finalize_fn(state):
+ return state
+
+ sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
+
+ stop_t = array_ops.placeholder(dtypes.int64, shape=[])
+ dataset = dataset_ops.Dataset.range(stop_t)
+ element = get_single_element.reduce_dataset(dataset, sum_reducer)
+
+ with self.cached_session() as sess:
+ value = sess.run(element, feed_dict={stop_t: stop})
+ self.assertEqual(stop * (stop - 1) / 2, value)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
index 90d18dca2a..c5a7862322 100644
--- a/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/slide_dataset_op_test.py
@@ -21,6 +21,7 @@ from absl.testing import parameterized
import numpy as np
from tensorflow.contrib.data.python.ops import sliding
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class SlideDatasetTest(test.TestCase, parameterized.TestCase):
+class SlideDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("1", 20, 14, 7, 1),
@@ -197,11 +198,6 @@ class SlideDatasetTest(test.TestCase, parameterized.TestCase):
sliding.sliding_window_batch(
window_size=1, stride=1, window_shift=1, window_stride=1))
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
-
def testSlideSparse(self):
def _sparse(i):
diff --git a/tensorflow/contrib/data/python/kernel_tests/test_utils.py b/tensorflow/contrib/data/python/kernel_tests/test_utils.py
deleted file mode 100644
index 4c3353fe40..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/test_utils.py
+++ /dev/null
@@ -1,73 +0,0 @@
-# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Test utilities for tf.data functionality."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import re
-
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import errors
-from tensorflow.python.platform import test
-
-
-class DatasetTestBase(test.TestCase):
- """Base class for dataset tests."""
-
- def _assert_datasets_equal(self, dataset1, dataset2):
- # TODO(rachelim): support sparse tensor outputs
- next1 = dataset1.make_one_shot_iterator().get_next()
- next2 = dataset2.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- while True:
- try:
- op1 = sess.run(next1)
- except errors.OutOfRangeError:
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next2)
- break
- op2 = sess.run(next2)
-
- op1 = nest.flatten(op1)
- op2 = nest.flatten(op2)
- assert len(op1) == len(op2)
- for i in range(len(op1)):
- self.assertAllEqual(op1[i], op2[i])
-
- def _assert_datasets_raise_same_error(self,
- dataset1,
- dataset2,
- exception_class,
- replacements=None):
- # We are defining next1 and next2 in the same line so that we get identical
- # file:line_number in the error messages
- # pylint: disable=line-too-long
- next1, next2 = dataset1.make_one_shot_iterator().get_next(), dataset2.make_one_shot_iterator().get_next()
- # pylint: enable=line-too-long
- with self.cached_session() as sess:
- try:
- sess.run(next1)
- raise ValueError(
- "Expected dataset to raise an error of type %s, but it did not." %
- repr(exception_class))
- except exception_class as e:
- expected_message = e.message
- for old, new, count in replacements:
- expected_message = expected_message.replace(old, new, count)
- # Check that the first segment of the error messages are the same.
- with self.assertRaisesRegexp(exception_class,
- re.escape(expected_message)):
- sess.run(next2)
diff --git a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
deleted file mode 100644
index 8b7b3ac0f7..0000000000
--- a/tensorflow/contrib/data/python/kernel_tests/window_dataset_op_test.py
+++ /dev/null
@@ -1,526 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for the experimental input pipeline ops."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from absl.testing import parameterized
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import grouping
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.platform import test
-
-
-class WindowDatasetTest(test.TestCase, parameterized.TestCase):
-
- def _structuredDataset(self, structure, shape, dtype):
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredElement(self, structure, shape, dtype):
- if structure is None:
- return array_ops.zeros(shape, dtype=dtype)
- else:
- return tuple([
- self._structuredElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- def _assertEqual(self, xs, ys):
- self.assertEqual(type(xs), type(ys))
- if isinstance(xs, tuple) and isinstance(ys, tuple):
- self.assertEqual(len(xs), len(ys))
- for x, y in zip(xs, ys):
- self._assertEqual(x, y)
- elif isinstance(xs, np.ndarray) and isinstance(ys, np.ndarray):
- self.assertAllEqual(xs, ys)
- else:
- self.assertEqual(xs, ys)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetFlatMap(self, structure, shape, dtype):
- """Tests windowing by chaining it with flat map.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return args[0]
- return dataset_ops.Dataset.zip(
- tuple([fn(*arg) if isinstance(arg, tuple) else arg for arg in args]))
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).flat_map(fn)
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(self._structuredElement(structure, shape, dtype))
- for _ in range(5):
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchDense(self, structure, shape, dtype):
- """Tests batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredDataset(structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredElement(structure, np.concatenate(
- ([5], shape), axis=0), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchDenseDynamicShape(self, shape):
- """Tests batching of dynamically shaped dense tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(
- array_ops.zeros(shape_t)).repeat(5).apply(
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredElement(None, np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _make_dense_to_sparse_fn(self, is_scalar):
-
- def dense_to_sparse_scalar(tensor):
- indices = [[]]
- values = array_ops.expand_dims(tensor, 0)
- shape = []
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- def dense_to_sparse_non_scalar(tensor):
- indices = array_ops.where(array_ops.ones_like(tensor, dtype=dtypes.bool))
- values = array_ops.gather_nd(tensor, indices)
- shape = array_ops.shape(tensor, out_type=dtypes.int64)
- return sparse_tensor.SparseTensorValue(indices, values, shape)
-
- if is_scalar:
- return dense_to_sparse_scalar
- return dense_to_sparse_non_scalar
-
- def _structuredSparseDataset(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dataset_ops.Dataset.from_tensors(
- dense_to_sparse(array_ops.zeros(shape, dtype=dtype)))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredSparseDataset(substructure, shape, dtype)
- for substructure in structure
- ]))
-
- def _structuredSparseElement(self, structure, shape, dtype):
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- if structure is None:
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- else:
- return tuple([
- self._structuredSparseElement(substructure, shape, dtype)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int32([]), dtypes.bool),
- ("2", None, np.int32([]), dtypes.int32),
- ("3", None, np.int32([]), dtypes.float32),
- ("4", None, np.int32([]), dtypes.string),
- ("5", None, np.int32([2]), dtypes.int32),
- ("6", None, np.int32([2, 2]), dtypes.int32),
- ("7", (None, None, None), np.int32([]), dtypes.int32),
- ("8", (None, (None, None)), np.int32([]), dtypes.int32),
- )
- def testWindowDatasetBatchSparse(self, structure, shape, dtype):
- """Tests batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shape: the input shape
- dtype: the input data type
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.batch_window(args[0])
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.batch_window(arg)
- for arg in args
- ])
-
- dataset = self._structuredSparseDataset(
- structure, shape, dtype).repeat(5).apply(
- grouping.window_dataset(5)).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredSparseElement(structure,
- np.concatenate(([5], shape), axis=0),
- dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([])),
- ("2", np.int32([1])),
- ("3", np.int32([1, 2, 3])),
- )
- def testWindowDatasetBatchSparseDynamicShape(self, shape):
- """Tests batching of dynamically shaped sparse tensor windows.
-
- Args:
- shape: the input shape
- """
-
- shape_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensors(array_ops.zeros(shape_t)).map(
- self._make_dense_to_sparse_fn(len(shape) == 0)).repeat(5).apply( # pylint: disable=g-explicit-length-test
- grouping.window_dataset(5)).apply(
- grouping._map_x_dataset(batching.batch_window))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shape_t: shape})
- expected = sess.run(
- self._structuredSparseElement(None,
- np.concatenate(([5], shape), axis=0),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- def _structuredRaggedDataset(self, structure, shapes, dtype):
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtype))
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- @parameterized.named_parameters(
- ("1", None, np.int32([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int32([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int32([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int32([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int32([[3, 1, 3], [1, 3, 1]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int32([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int32([[1], [2], [3]]), dtypes.int32, np.int32([10])),
- )
- def testWindowDatasetPaddedBatchDense(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of dense tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedDataset(structure, shapes, dtype).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- structure,
- np.concatenate((np.int32([len(shapes)]), expected_shape)), dtype))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1], [2], [3]]), [-1]),
- ("2", np.int32([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int32([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchDenseDynamicShape(self, shapes, padded_shape):
- """Tests padded batching of dynamically shaped dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- expected = sess.run(
- self._structuredElement(
- None, np.concatenate((np.int32([len(shapes)]), expected_shape)),
- dtypes.int32))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int32([[1]]), np.int32([0])),
- ("2", np.int32([[10], [20]]), np.int32([15])),
- )
- def testWindowDatasetPaddedBatchDenseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of dense tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).apply(
- grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def _structuredRaggedSparseDataset(self, structure, shapes, dtype):
-
- def map_fn(shape):
- dense_to_sparse = self._make_dense_to_sparse_fn(False)
- return dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
-
- if structure is None:
- return dataset_ops.Dataset.from_tensor_slices(shapes).map(map_fn)
- else:
- return dataset_ops.Dataset.zip(
- tuple([
- self._structuredRaggedSparseDataset(substructure, shapes, dtype)
- for substructure in structure
- ]))
-
- def _structuredRaggedSparseElement(self, structure, shapes, dtype,
- padded_shape):
- if structure is None:
- dense_shape = np.maximum(np.amax(shapes, axis=0), padded_shape)
- values = []
- for shape in shapes:
- dense_to_sparse = self._make_dense_to_sparse_fn(len(shape) == 0) # pylint: disable=g-explicit-length-test
- sparse = dense_to_sparse(array_ops.zeros(shape, dtype=dtype))
- padded_sparse = sparse_tensor.SparseTensor(sparse.indices,
- sparse.values, dense_shape)
- reshaped_sparse = sparse_ops.sparse_reshape(
- padded_sparse,
- array_ops.concat([np.array([1], dtype=np.int64), dense_shape], 0))
- values.append(reshaped_sparse)
- return sparse_ops.sparse_concat(0, values)
- else:
- return tuple([
- self._structuredRaggedSparseElement(substructure, shapes, dtype,
- padded_shape)
- for substructure in structure
- ])
-
- @parameterized.named_parameters(
- ("1", None, np.int64([[1], [2], [3]]), dtypes.bool, [-1]),
- ("2", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("3", None, np.int64([[1], [2], [3]]), dtypes.float32, [-1]),
- ("4", None, np.int64([[1], [2], [3]]), dtypes.string, [-1]),
- ("5", None, np.int64([[1, 3], [2, 2], [3, 1]]), dtypes.int32, [-1, -1]),
- ("6", None, np.int64([[1, 3, 1], [3, 1, 3]]), dtypes.int32, [-1, -1, -1]),
- ("7", (None, None, None), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("8", (None,
- (None, None)), np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("9", None, np.int64([[1], [2], [3]]), dtypes.int32, [-1]),
- ("10", None, np.int64([[1], [2], [3]]), dtypes.int32, np.int64([10])),
- )
- def testWindowDatasetPaddedBatchSparse(self, structure, shapes, dtype,
- padded_shape):
- """Tests padded batching of sparse tensor windows.
-
- Args:
- structure: the input structure
- shapes: the input shapes
- dtype: the input data type
- padded_shape: the shape to pad the output to
- """
-
- def fn(*args):
- if len(args) == 1 and not isinstance(args[0], tuple):
- return batching.padded_batch_window(args[0], padded_shape)
-
- return tuple([
- fn(*arg) if isinstance(arg, tuple) else batching.padded_batch_window(
- arg, padded_shape) for arg in args
- ])
-
- dataset = self._structuredRaggedSparseDataset(
- structure, shapes, dtype).apply(grouping.window_dataset(
- len(shapes))).apply(grouping._map_x_dataset(fn))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- expected = sess.run(
- self._structuredRaggedSparseElement(structure, shapes, dtype,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1], [2], [3]]), [-1]),
- ("2", np.int64([[1, 3], [2, 2], [3, 1]]), [-1, -1]),
- ("3", np.int64([[3, 1, 3], [1, 3, 1]]), [-1, -1, -1]),
- )
- def testWindowDatasetPaddedBatchSparseDynamicShape(self, shapes,
- padded_shape):
- """Tests padded batching of dynamically shaped sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- shapes_t = array_ops.placeholder(dtypes.int32)
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes_t).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- iterator = dataset.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op, {shapes_t: shapes})
- expected = sess.run(
- self._structuredRaggedSparseElement(None, shapes, dtypes.int32,
- padded_shape))
- actual = sess.run(get_next)
- self._assertEqual(expected, actual)
-
- @parameterized.named_parameters(
- ("1", np.int64([[1]]), [0]),
- ("2", np.int64([[10], [20]]), [15]),
- )
- def testWindowDatasetPaddedBatchSparseInvalid(self, shapes, padded_shape):
- """Tests invalid padded batching of sparse tensor windows.
-
- Args:
- shapes: the input shapes
- padded_shape: the shape to pad the output to
- """
-
- dataset = dataset_ops.Dataset.from_tensor_slices(shapes).map(
- lambda shape: array_ops.zeros(shape, dtype=dtypes.int32)).map(
- self._make_dense_to_sparse_fn(False)
- ).apply(grouping.window_dataset(len(shapes))).apply(
- grouping._map_x_dataset(
- lambda x: batching.padded_batch_window(x, padded_shape)))
- get_next = dataset.make_one_shot_iterator().get_next()
- with self.cached_session() as sess:
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
-if __name__ == "__main__":
- test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index a14781cd93..34dc2379d0 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -16,10 +16,7 @@ py_library(
srcs = ["counter.py"],
srcs_version = "PY2AND3",
deps = [
- ":scan_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:counter",
],
)
@@ -28,12 +25,7 @@ py_library(
srcs = ["get_single_element.py"],
srcs_version = "PY2AND3",
deps = [
- ":grouping",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
],
)
@@ -44,10 +36,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:training",
- "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
],
)
@@ -58,15 +47,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:constant_op",
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:random_seed",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:random_ops",
],
)
@@ -78,19 +59,19 @@ py_library(
srcs_version = "PY2AND3",
deps = [
":batching",
- ":gen_dataset_ops",
":interleave_ops",
- ":optimization",
":parsing_ops",
":shuffle_ops",
"//tensorflow/python:constant_op",
"//tensorflow/python:dataset_ops_gen",
"//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
"//tensorflow/python:framework_ops",
"//tensorflow/python:lib",
"//tensorflow/python:platform",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:readers",
"//tensorflow/python/data/util:convert",
@@ -106,7 +87,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
],
)
@@ -125,6 +106,7 @@ py_library(
"//tensorflow/python:math_ops",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/util:convert",
"//tensorflow/python/data/util:nest",
@@ -138,8 +120,7 @@ py_library(
srcs = ["enumerate_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
],
)
@@ -148,11 +129,7 @@ py_library(
srcs = ["error_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:error_ops",
],
)
@@ -161,16 +138,7 @@ py_library(
srcs = ["grouping.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:array_ops",
- "//tensorflow/python:check_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:grouping",
],
)
@@ -179,32 +147,7 @@ py_library(
srcs = ["interleave_ops.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- ":random_ops",
- "//tensorflow/contrib/stateless",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:util",
- "//tensorflow/python/data/ops:readers",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
- name = "optimization",
- srcs = ["optimization.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
],
)
@@ -213,25 +156,7 @@ py_library(
srcs = ["parsing_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:parsing_ops",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python:tensor_shape",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- ],
-)
-
-py_library(
- name = "map_defun",
- srcs = ["map_defun.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
],
)
@@ -240,18 +165,7 @@ py_library(
srcs = ["resampling.py"],
srcs_version = "PY2AND3",
deps = [
- ":batching",
- ":interleave_ops",
- ":scan_ops",
- "//tensorflow/python:array_ops",
- "//tensorflow/python:control_flow_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:logging_ops",
- "//tensorflow/python:math_ops",
- "//tensorflow/python:random_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//third_party/py/numpy",
+ "//tensorflow/python/data/experimental/ops:resampling",
],
)
@@ -260,12 +174,7 @@ py_library(
srcs = ["scan_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:function",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
],
)
@@ -285,32 +194,11 @@ py_library(
)
py_library(
- name = "stats_ops",
- srcs = ["stats_ops.py"],
- srcs_version = "PY2AND3",
- deps = [
- "//tensorflow/python:dataset_ops_gen",
- "//tensorflow/python:dtypes",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/ops:iterator_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- ],
-)
-
-py_library(
name = "threadpool",
srcs = ["threadpool.py"],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- "//tensorflow/python:resource_variable_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
- "//tensorflow/python/eager:context",
+ "//tensorflow/python/data/experimental/ops:threadpool",
],
)
@@ -321,12 +209,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:unique",
],
)
@@ -337,56 +220,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/python:dtypes",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_gen_op_wrapper_py(
- name = "gen_dataset_ops",
- out = "gen_dataset_ops.py",
- deps = [
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- ],
-)
-
-tf_kernel_library(
- name = "dataset_ops_kernels",
- deps = [
- "//tensorflow/contrib/data/kernels:dataset_kernels",
- "//tensorflow/core:framework",
- ],
- alwayslink = 1,
-)
-
-tf_custom_op_py_library(
- name = "contrib_op_loader",
- srcs = ["contrib_op_loader.py"],
- dso = ["//tensorflow/contrib/data:_dataset_ops.so"],
- kernels = [
- ":dataset_ops_kernels",
- "//tensorflow/contrib/data:indexed_dataset_ops_op_lib",
- "//tensorflow/contrib/data:dataset_ops_op_lib",
- ],
- srcs_version = "PY2AND3",
- deps = [
- ":gen_dataset_ops",
- "//tensorflow/contrib/util:util_py",
- "//tensorflow/python:platform",
- ],
-)
-
-py_library(
- name = "indexed_dataset_ops",
- srcs = ["indexed_dataset_ops.py"],
- deps = [
- ":contrib_op_loader",
- ":gen_dataset_ops",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:writers",
],
)
@@ -394,11 +228,7 @@ py_library(
name = "prefetching_ops",
srcs = ["prefetching_ops.py"],
deps = [
- ":contrib_op_loader",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python/data/ops:dataset_ops",
- "//tensorflow/python/data/util:nest",
- "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
],
)
@@ -411,17 +241,14 @@ py_library(
":error_ops",
":get_single_element",
":grouping",
- ":indexed_dataset_ops",
":interleave_ops",
- ":map_defun",
- ":optimization",
":prefetching_ops",
+ ":random_ops",
":readers",
":resampling",
":scan_ops",
":shuffle_ops",
":sliding",
- ":stats_ops",
":threadpool",
":unique",
":writers",
diff --git a/tensorflow/contrib/data/python/ops/batching.py b/tensorflow/contrib/data/python/ops/batching.py
index 7a0f221284..8c60459ca8 100644
--- a/tensorflow/contrib/data/python/ops/batching.py
+++ b/tensorflow/contrib/data/python/ops/batching.py
@@ -17,134 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
from tensorflow.contrib.framework import with_shape
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import sparse_ops
from tensorflow.python.util import deprecation
-def batch_window(dataset):
- """Batches a window of tensors.
-
- Args:
- dataset: the input dataset.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
- """
- if isinstance(dataset.output_classes, tuple):
- raise TypeError("Input dataset expected to have a single component")
- if dataset.output_classes is ops.Tensor:
- return _batch_dense_window(dataset)
- elif dataset.output_classes is sparse_tensor.SparseTensor:
- return _batch_sparse_window(dataset)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _batch_dense_window(dataset):
- """Batches a window of dense tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return array_ops.shape(first_element)
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, array_ops.shape(value))
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat([[0], shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _batch_sparse_window(dataset):
- """Batches a window of sparse tensors."""
-
- def key_fn(_):
- return np.int64(0)
-
- def shape_init_fn(_):
- return first_element.dense_shape
-
- def shape_reduce_fn(state, value):
- check_ops.assert_equal(state, value.dense_shape)
- return state
-
- def finalize_fn(state):
- return state
-
- if dataset.output_shapes.is_fully_defined():
- shape = dataset.output_shapes
- else:
- first_element = get_single_element.get_single_element(dataset.take(1))
- shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
- finalize_fn)
- shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64),
- math_ops.cast(shape, dtypes.int64)], 0))
-
- def batch_reduce_fn(state, value):
- return sparse_ops.sparse_concat(0, [state, value])
-
- def reshape_fn(value):
- return sparse_ops.sparse_reshape(
- value,
- array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(reshape_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.dense_to_sparse_batch(...)`.")
def dense_to_sparse_batch(batch_size, row_shape):
"""A transformation that batches ragged elements into `tf.SparseTensor`s.
@@ -187,201 +67,10 @@ def dense_to_sparse_batch(batch_size, row_shape):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
-
- return _apply_fn
-
-
-def padded_batch_window(dataset, padded_shape, padding_value=None):
- """Batches a window of tensors with padding.
-
- Args:
- dataset: the input dataset.
- padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
- object representing the shape to which the input elements should be padded
- prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
- `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
- maximum size of that dimension in each batch.
- padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
- padding value to use. Defaults are `0` for numeric types and the empty
- string for string types. If `dataset` contains `tf.SparseTensor`, this
- value is ignored.
-
- Returns:
- A `Tensor` representing the batch of the entire input dataset.
-
- Raises:
- ValueError: if invalid arguments are provided.
- """
- if not issubclass(dataset.output_classes,
- (ops.Tensor, sparse_tensor.SparseTensor)):
- raise TypeError("Input dataset expected to have a single tensor component")
- if issubclass(dataset.output_classes, (ops.Tensor)):
- return _padded_batch_dense_window(dataset, padded_shape, padding_value)
- elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
- if padding_value is not None:
- raise ValueError("Padding value not allowed for sparse tensors")
- return _padded_batch_sparse_window(dataset, padded_shape)
- else:
- raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
-
-
-def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
- """Batches a window of dense tensors with padding."""
-
- padded_shape = math_ops.cast(
- convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return padded_shape
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(array_ops.shape(value), padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ",
- array_ops.shape(value), padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, array_ops.shape(value))
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- if padding_value is None:
- if dataset.output_types == dtypes.string:
- padding_value = ""
- elif dataset.output_types == dtypes.bool:
- padding_value = False
- elif dataset.output_types == dtypes.variant:
- raise TypeError("Unable to create padding for field of type 'variant'")
- else:
- padding_value = 0
-
- def batch_init_fn(_):
- batch_shape = array_ops.concat(
- [np.array([0], dtype=np.int32), padded_shape], 0)
- return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
-
- def batch_reduce_fn(state, value):
- return array_ops.concat([state, [value]], 0)
-
- def pad_fn(value):
- shape = array_ops.shape(value)
- left = array_ops.zeros_like(shape)
- right = padded_shape - shape
- return array_ops.pad(
- value, array_ops.stack([left, right], 1), constant_values=padding_value)
-
- batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.map(pad_fn).apply(
- grouping.group_by_reducer(key_fn, batch_reducer)))
-
-
-def _padded_batch_sparse_window(dataset, padded_shape):
- """Batches a window of sparse tensors with padding."""
-
- def key_fn(_):
- return np.int64(0)
-
- def max_init_fn(_):
- return convert.partial_shape_to_tensor(padded_shape)
-
- def max_reduce_fn(state, value):
- """Computes the maximum shape to pad to."""
- condition = math_ops.reduce_all(
- math_ops.logical_or(
- math_ops.less_equal(value.dense_shape, padded_shape),
- math_ops.equal(padded_shape, -1)))
- assert_op = control_flow_ops.Assert(condition, [
- "Actual shape greater than padded shape: ", value.dense_shape,
- padded_shape
- ])
- with ops.control_dependencies([assert_op]):
- return math_ops.maximum(state, value.dense_shape)
-
- def finalize_fn(state):
- return state
-
- # Compute the padded shape.
- max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
- padded_shape = get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
-
- def batch_init_fn(_):
- indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
- 0)
- return sparse_tensor.SparseTensor(
- indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
- values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
- dense_shape=array_ops.concat(
- [np.array([0], dtype=np.int64), padded_shape], 0))
-
- def batch_reduce_fn(state, value):
- padded_value = sparse_tensor.SparseTensor(
- indices=value.indices, values=value.values, dense_shape=padded_shape)
- reshaped_value = sparse_ops.sparse_reshape(
- padded_value,
- array_ops.concat(
- [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
- return sparse_ops.sparse_concat(0, [state, reshaped_value])
-
- reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
- return get_single_element.get_single_element(
- dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
-
-
-class _UnbatchDataset(dataset_ops.UnaryDataset):
- """A dataset that splits the elements of its input into multiple elements."""
-
- def __init__(self, input_dataset):
- """See `unbatch()` for more details."""
- super(_UnbatchDataset, self).__init__(input_dataset)
- flat_shapes = nest.flatten(input_dataset.output_shapes)
- if any(s.ndims == 0 for s in flat_shapes):
- raise ValueError("Cannot unbatch an input with scalar components.")
- known_batch_dim = tensor_shape.Dimension(None)
- for s in flat_shapes:
- try:
- known_batch_dim = known_batch_dim.merge_with(s[0])
- except ValueError:
- raise ValueError("Cannot unbatch an input whose components have "
- "different batch sizes.")
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.unbatch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda s: s[1:],
- self._input_dataset.output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return batching.dense_to_sparse_batch(batch_size, row_shape)
+@deprecation.deprecated(None, "Use `tf.data.experimental.unbatch()`.")
def unbatch():
"""Splits elements of a dataset into multiple elements on the batch dimension.
@@ -403,39 +92,7 @@ def unbatch():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- if not sparse.any_sparse(dataset.output_classes):
- return _UnbatchDataset(dataset)
-
- # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
- # are normalized to the rank-1 dense representation, so that the
- # sparse-oblivious unbatching logic will slice them
- # appropriately. This leads to a somewhat inefficient re-encoding step
- # for all SparseTensor components.
- # TODO(mrry): Consider optimizing this in future
- # if it turns out to be a bottleneck.
- def normalize(arg, *rest):
- if rest:
- return sparse.serialize_many_sparse_tensors((arg,) + rest)
- else:
- return sparse.serialize_many_sparse_tensors(arg)
-
- normalized_dataset = dataset.map(normalize)
-
- # NOTE(mrry): Our `map()` has lost information about the sparseness
- # of any SparseTensor components, so re-apply the structure of the
- # original dataset.
- restructured_dataset = _RestructuredDataset(
- normalized_dataset,
- dataset.output_types,
- dataset.output_shapes,
- dataset.output_classes,
- allow_unsafe_cast=True)
- return _UnbatchDataset(restructured_dataset)
-
- return _apply_fn
+ return batching.unbatch()
@deprecation.deprecated(
@@ -514,135 +171,8 @@ def padded_batch_and_drop_remainder(batch_size,
return _apply_fn
-class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
-
- def __init__(self, input_dataset, batch_size, row_shape):
- """See `Dataset.dense_to_sparse_batch()` for more details."""
- super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
- if not isinstance(input_dataset.output_types, dtypes.DType):
- raise TypeError("DenseToSparseDataset requires an input whose elements "
- "have a single component, whereas the input has %r." %
- input_dataset.output_types)
- self._input_dataset = input_dataset
- self._batch_size = batch_size
- self._row_shape = row_shape
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.dense_to_sparse_batch_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._batch_size,
- row_shape=convert.partial_shape_to_tensor(self._row_shape),
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return sparse_tensor.SparseTensor
-
- @property
- def output_shapes(self):
- return tensor_shape.vector(None).concatenate(self._row_shape)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _RestructuredDataset(dataset_ops.UnaryDataset):
- """An internal helper for changing the structure and shape of a dataset."""
-
- def __init__(self,
- dataset,
- output_types,
- output_shapes=None,
- output_classes=None,
- allow_unsafe_cast=False):
- """Creates a new dataset with the given output types and shapes.
-
- The given `dataset` must have a structure that is convertible:
- * `dataset.output_types` must be the same as `output_types` module nesting.
- * Each shape in `dataset.output_shapes` must be compatible with each shape
- in `output_shapes` (if given).
-
- Note: This helper permits "unsafe casts" for shapes, equivalent to using
- `tf.Tensor.set_shape()` where domain-specific knowledge is available.
-
- Args:
- dataset: A `Dataset` object.
- output_types: A nested structure of `tf.DType` objects.
- output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
- If omitted, the shapes will be inherited from `dataset`.
- output_classes: (Optional.) A nested structure of class types.
- If omitted, the class types will be inherited from `dataset`.
- allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
- reported output types and shapes of the restructured dataset, e.g. to
- switch a sparse tensor represented as `tf.variant` to its user-visible
- type and shape.
-
- Raises:
- ValueError: If either `output_types` or `output_shapes` is not compatible
- with the structure of `dataset`.
- """
- super(_RestructuredDataset, self).__init__(dataset)
- self._input_dataset = dataset
-
- if not allow_unsafe_cast:
- # Validate that the types are compatible.
- output_types = nest.map_structure(dtypes.as_dtype, output_types)
- flat_original_types = nest.flatten(dataset.output_types)
- flat_new_types = nest.flatten(output_types)
- if flat_original_types != flat_new_types:
- raise ValueError(
- "Dataset with output types %r cannot be restructured to have "
- "output types %r" % (dataset.output_types, output_types))
-
- self._output_types = output_types
-
- if output_shapes is None:
- # Inherit shapes from the original `dataset`.
- self._output_shapes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_shapes))
- else:
- if not allow_unsafe_cast:
- # Validate that the shapes are compatible.
- nest.assert_same_structure(output_types, output_shapes)
- flat_original_shapes = nest.flatten(dataset.output_shapes)
- flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
-
- for original_shape, new_shape in zip(flat_original_shapes,
- flat_new_shapes):
- if not original_shape.is_compatible_with(new_shape):
- raise ValueError(
- "Dataset with output shapes %r cannot be restructured to have "
- "incompatible output shapes %r" % (dataset.output_shapes,
- output_shapes))
- self._output_shapes = nest.map_structure_up_to(
- output_types, tensor_shape.as_shape, output_shapes)
- if output_classes is None:
- # Inherit class types from the original `dataset`.
- self._output_classes = nest.pack_sequence_as(output_types,
- nest.flatten(
- dataset.output_classes))
- else:
- self._output_classes = output_classes
-
- def _as_variant_tensor(self):
- return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
-
+# TODO(b/116817045): Move this to `tf.data.experimental` when the `with_shape()`
+# function is available in the core.
def assert_element_shape(expected_shapes):
"""Assert the shape of this `Dataset`.
@@ -687,7 +217,8 @@ def assert_element_shape(expected_shapes):
def _apply_fn(dataset):
output_shapes = _merge_output_shapes(dataset.output_shapes,
expected_shapes)
- return _RestructuredDataset(
+ # pylint: disable=protected-access
+ return batching._RestructuredDataset(
dataset.map(_check_shape),
dataset.output_types,
output_shapes=output_shapes,
@@ -696,49 +227,7 @@ def assert_element_shape(expected_shapes):
return _apply_fn
-class _MapAndBatchDataset(dataset_ops.MapDataset):
- """A `Dataset` that maps a function over a batch of elements."""
-
- def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
- drop_remainder):
- """See `Dataset.map()` for details."""
- super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
- self._batch_size_t = ops.convert_to_tensor(
- batch_size, dtype=dtypes.int64, name="batch_size")
- self._num_parallel_calls_t = ops.convert_to_tensor(
- num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
- self._drop_remainder_t = ops.convert_to_tensor(
- drop_remainder, dtype=dtypes.bool, name="drop_remainder")
-
- self._batch_size = batch_size
- self._drop_remainder = drop_remainder
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.map_and_batch_dataset_v2(
- input_resource,
- self._map_func.captured_inputs,
- f=self._map_func,
- batch_size=self._batch_size_t,
- num_parallel_calls=self._num_parallel_calls_t,
- drop_remainder=self._drop_remainder_t,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_shapes(self):
- dim = self._batch_size if self._drop_remainder else None
- return nest.pack_sequence_as(self._output_shapes, [
- tensor_shape.vector(dim).concatenate(s)
- for s in nest.flatten(self._output_shapes)
- ])
-
- @property
- def output_types(self):
- return self._output_types
-
-
+@deprecation.deprecated(None, "Use `tf.data.experimental.map_and_batch(...)`.")
def map_and_batch(map_func,
batch_size,
num_parallel_batches=None,
@@ -779,17 +268,5 @@ def map_and_batch(map_func,
ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
specified.
"""
-
- if num_parallel_batches is None and num_parallel_calls is None:
- num_parallel_calls = batch_size
- elif num_parallel_batches is not None and num_parallel_calls is None:
- num_parallel_calls = batch_size * num_parallel_batches
- elif num_parallel_batches is not None and num_parallel_calls is not None:
- raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
- "arguments are mutually exclusive.")
-
- def _apply_fn(dataset):
- return _MapAndBatchDataset(dataset, map_func, batch_size,
- num_parallel_calls, drop_remainder)
-
- return _apply_fn
+ return batching.map_and_batch(map_func, batch_size, num_parallel_batches,
+ drop_remainder, num_parallel_calls)
diff --git a/tensorflow/contrib/data/python/ops/counter.py b/tensorflow/contrib/data/python/ops/counter.py
index 6ef65f9624..4ff5bf3e39 100644
--- a/tensorflow/contrib/data/python/ops/counter.py
+++ b/tensorflow/contrib/data/python/ops/counter.py
@@ -17,13 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import scan_ops
-
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.Counter(...)`.")
def Counter(start=0, step=1, dtype=dtypes.int64):
"""Creates a `Dataset` that counts from `start` in steps of size `step`.
@@ -46,8 +45,4 @@ def Counter(start=0, step=1, dtype=dtypes.int64):
Returns:
A `Dataset` of scalar `dtype` elements.
"""
- with ops.name_scope("counter"):
- start = ops.convert_to_tensor(start, dtype=dtype, name="start")
- step = ops.convert_to_tensor(step, dtype=dtype, name="step")
- return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
- scan_ops.scan(start, lambda state, _: (state + step, state)))
+ return counter.Counter(start, step, dtype)
diff --git a/tensorflow/contrib/data/python/ops/enumerate_ops.py b/tensorflow/contrib/data/python/ops/enumerate_ops.py
index 490281e0d2..a21da4d3ec 100644
--- a/tensorflow/contrib/data/python/ops/enumerate_ops.py
+++ b/tensorflow/contrib/data/python/ops/enumerate_ops.py
@@ -17,12 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.enumerate_dataset(...)`.")
def enumerate_dataset(start=0):
"""A transformation that enumerate the elements of a dataset.
@@ -49,10 +50,4 @@ def enumerate_dataset(start=0):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
- return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
- dataset))
-
- return _apply_fn
+ return enumerate_ops.enumerate_dataset(start)
diff --git a/tensorflow/contrib/data/python/ops/error_ops.py b/tensorflow/contrib/data/python/ops/error_ops.py
index 615dbcabd4..0559a2e09c 100644
--- a/tensorflow/contrib/data/python/ops/error_ops.py
+++ b/tensorflow/contrib/data/python/ops/error_ops.py
@@ -17,11 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.ignore_errors()`.")
def ignore_errors():
"""Creates a `Dataset` from another `Dataset` and silently ignores any errors.
@@ -44,34 +44,4 @@ def ignore_errors():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _IgnoreErrorsDataset(dataset)
-
- return _apply_fn
-
-
-class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that silently ignores errors when computing its input."""
-
- def __init__(self, input_dataset):
- """See `Dataset.ignore_errors()` for details."""
- super(_IgnoreErrorsDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.ignore_errors_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return error_ops.ignore_errors()
diff --git a/tensorflow/contrib/data/python/ops/get_single_element.py b/tensorflow/contrib/data/python/ops/get_single_element.py
index a6713b017a..58ad9eea90 100644
--- a/tensorflow/contrib/data/python/ops/get_single_element.py
+++ b/tensorflow/contrib/data/python/ops/get_single_element.py
@@ -19,13 +19,13 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element as experimental_get_single_element
from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.get_single_element(...)`.")
def get_single_element(dataset):
"""Returns the single element in `dataset` as a nested structure of tensors.
@@ -61,18 +61,10 @@ def get_single_element(dataset):
InvalidArgumentError (at runtime): if `dataset` does not contain exactly
one element.
"""
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
-
- nested_ret = nest.pack_sequence_as(
- dataset.output_types, gen_dataset_ops.dataset_to_single_element(
- dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(dataset)))
- return sparse.deserialize_sparse_tensors(
- nested_ret, dataset.output_types, dataset.output_shapes,
- dataset.output_classes)
+ return experimental_get_single_element.get_single_element(dataset)
+@deprecation.deprecated(None, "Use `tf.data.Dataset.reduce(...)`.")
def reduce_dataset(dataset, reducer):
"""Returns the result of reducing the `dataset` using `reducer`.
@@ -90,11 +82,4 @@ def reduce_dataset(dataset, reducer):
if not isinstance(dataset, dataset_ops.Dataset):
raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- # The sentinel dataset is used in case the reduced dataset is empty.
- sentinel_dataset = dataset_ops.Dataset.from_tensors(
- reducer.finalize_func(reducer.init_func(np.int64(0))))
- reduced_dataset = dataset.apply(
- grouping.group_by_reducer(lambda x: np.int64(0), reducer))
-
- return get_single_element(
- reduced_dataset.concatenate(sentinel_dataset).take(1))
+ return dataset.reduce(reducer.init_func(np.int64(0)), reducer.reduce_func)
diff --git a/tensorflow/contrib/data/python/ops/grouping.py b/tensorflow/contrib/data/python/ops/grouping.py
index 7cae33beb3..a99dc2f29a 100644
--- a/tensorflow/contrib/data/python/ops/grouping.py
+++ b/tensorflow/contrib/data/python/ops/grouping.py
@@ -17,20 +17,13 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_reducer(...)`.")
def group_by_reducer(key_func, reducer):
"""A transformation that groups elements and performs a reduction.
@@ -52,14 +45,11 @@ def group_by_reducer(key_func, reducer):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByReducerDataset(dataset, key_func, reducer)
-
- return _apply_fn
+ return grouping.group_by_reducer(key_func, reducer)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.group_by_window(...)`.")
def group_by_window(key_func,
reduce_func,
window_size=None,
@@ -98,27 +88,12 @@ def group_by_window(key_func,
ValueError: if neither or both of {`window_size`, `window_size_func`} are
passed.
"""
- if (window_size is not None and window_size_func or
- not (window_size is not None or window_size_func)):
- raise ValueError("Must pass either window_size or window_size_func.")
-
- if window_size is not None:
-
- def constant_window_func(unused_key):
- return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
-
- window_size_func = constant_window_func
-
- assert window_size_func is not None
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _GroupByWindowDataset(dataset, key_func, reduce_func,
- window_size_func)
-
- return _apply_fn
+ return grouping.group_by_window(key_func, reduce_func, window_size,
+ window_size_func)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.bucket_by_sequence_length(...)`.")
def bucket_by_sequence_length(element_length_func,
bucket_boundaries,
bucket_batch_sizes,
@@ -163,342 +138,12 @@ def bucket_by_sequence_length(element_length_func,
Raises:
ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
"""
- with ops.name_scope("bucket_by_seq_length"):
- if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
- raise ValueError(
- "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
-
- batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
-
- def element_to_bucket_id(*args):
- """Return int64 id of the length bucket for this element."""
- seq_length = element_length_func(*args)
-
- boundaries = list(bucket_boundaries)
- buckets_min = [np.iinfo(np.int32).min] + boundaries
- buckets_max = boundaries + [np.iinfo(np.int32).max]
- conditions_c = math_ops.logical_and(
- math_ops.less_equal(buckets_min, seq_length),
- math_ops.less(seq_length, buckets_max))
- bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
-
- return bucket_id
-
- def window_size_fn(bucket_id):
- # The window size is set to the batch size for this bucket
- window_size = batch_sizes[bucket_id]
- return window_size
-
- def make_padded_shapes(shapes, none_filler=None):
- padded = []
- for shape in nest.flatten(shapes):
- shape = tensor_shape.TensorShape(shape)
- shape = [
- none_filler if d.value is None else d
- for d in shape
- ]
- padded.append(shape)
- return nest.pack_sequence_as(shapes, padded)
-
- def batching_fn(bucket_id, grouped_dataset):
- """Batch elements in dataset."""
- batch_size = window_size_fn(bucket_id)
- if no_padding:
- return grouped_dataset.batch(batch_size)
- none_filler = None
- if pad_to_bucket_boundary:
- err_msg = ("When pad_to_bucket_boundary=True, elements must have "
- "length < max(bucket_boundaries).")
- check = check_ops.assert_less(
- bucket_id,
- constant_op.constant(len(bucket_batch_sizes) - 1,
- dtype=dtypes.int64),
- message=err_msg)
- with ops.control_dependencies([check]):
- boundaries = constant_op.constant(bucket_boundaries,
- dtype=dtypes.int64)
- bucket_boundary = boundaries[bucket_id]
- none_filler = bucket_boundary - 1
- shapes = make_padded_shapes(
- padded_shapes or grouped_dataset.output_shapes,
- none_filler=none_filler)
- return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
-
- def _apply_fn(dataset):
- return dataset.apply(
- group_by_window(element_to_bucket_id, batching_fn,
- window_size_func=window_size_fn))
-
- return _apply_fn
-
-
-def _map_x_dataset(map_func):
- """A transformation that maps `map_func` across its input.
-
- This transformation is similar to `tf.data.Dataset.map`, but in addition to
- supporting dense and sparse tensor inputs, it also supports dataset inputs.
-
- Args:
- map_func: A function mapping a nested structure of tensors and/or datasets
- (having shapes and types defined by `self.output_shapes` and
- `self.output_types`) to another nested structure of tensors and/or
- datasets.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- return _MapXDataset(dataset, map_func)
-
- return _apply_fn
-
-
-# TODO(b/115382007) Remove this once canned reducers move to core.
-def window_dataset(window_size):
- """A transformation that creates window datasets from the input dataset.
-
- The resulting datasets will contain `window_size` elements (or
- `N % window_size` for the last dataset if `window_size` does not divide the
- number of input elements `N` evenly).
-
- Args:
- window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
- consecutive elements of the input dataset to combine into a window.
-
- Returns:
- Dataset: A `Dataset`.
- """
-
- def _apply_fn(dataset):
- return dataset_ops.WindowDataset(
- dataset,
- size=window_size,
- shift=window_size,
- stride=1,
- drop_remainder=False)
-
- return _apply_fn
-
-
-class _GroupByReducerDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a reduction."""
-
- def __init__(self, input_dataset, key_func, reducer):
- """See `group_by_reducer()` for details."""
- super(_GroupByReducerDataset, self).__init__(input_dataset)
+ return grouping.bucket_by_sequence_length(
+ element_length_func, bucket_boundaries, bucket_batch_sizes, padded_shapes,
+ padding_values, pad_to_bucket_boundary, no_padding)
- self._input_dataset = input_dataset
- self._make_key_func(key_func, input_dataset)
- self._make_init_func(reducer.init_func)
- self._make_reduce_func(reducer.reduce_func, input_dataset)
- self._make_finalize_func(reducer.finalize_func)
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func, "tf.contrib.data.group_by_reducer()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 tensor. "
- "Got type=%s and shape=%s"
- % (wrapped_func.output_types, wrapped_func.output_shapes))
- self._key_func = wrapped_func.function
-
- def _make_init_func(self, init_func):
- """Make wrapping Defun for init_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- init_func, "tf.contrib.data.group_by_reducer()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- self._init_func = wrapped_func.function
- self._state_classes = wrapped_func.output_classes
- self._state_shapes = wrapped_func.output_shapes
- self._state_types = wrapped_func.output_types
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
-
- # Iteratively rerun the reduce function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.group_by_reducer()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(wrapped_func.output_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, wrapped_func.output_classes))
-
- # Extract and validate type information from the returned values.
- for new_state_type, state_type in zip(
- nest.flatten(wrapped_func.output_types),
- nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, wrapped_func.output_types))
-
- # Extract shape information from the returned values.
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._reduce_func = wrapped_func.function
- self._reduce_func.add_to_graph(ops.get_default_graph())
-
- def _make_finalize_func(self, finalize_func):
- """Make wrapping Defun for finalize_func."""
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- finalize_func, "tf.contrib.data.group_by_reducer()",
- input_classes=self._state_classes, input_shapes=self._state_shapes,
- input_types=self._state_types)
- self._finalize_func = wrapped_func.function
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_reducer_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._init_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._finalize_func.captured_inputs,
- key_func=self._key_func,
- init_func=self._init_func,
- reduce_func=self._reduce_func,
- finalize_func=self._finalize_func,
- **dataset_ops.flat_structure(self))
-
-
-class _GroupByWindowDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that groups its input and performs a windowed reduction."""
-
- def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
- """See `group_by_window()` for details."""
- super(_GroupByWindowDataset, self).__init__(input_dataset)
-
- self._input_dataset = input_dataset
-
- self._make_key_func(key_func, input_dataset)
- self._make_reduce_func(reduce_func, input_dataset)
- self._make_window_size_func(window_size_func)
-
- def _make_window_size_func(self, window_size_func):
- """Make wrapping Defun for window_size_func."""
- def window_size_func_wrapper(key):
- return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- window_size_func_wrapper, "tf.contrib.data.group_by_window()",
- input_classes=ops.Tensor, input_shapes=tensor_shape.scalar(),
- input_types=dtypes.int64)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`window_size_func` must return a single tf.int64 scalar tensor.")
- self._window_size_func = wrapped_func.function
-
- def _make_key_func(self, key_func, input_dataset):
- """Make wrapping Defun for key_func."""
- def key_func_wrapper(*args):
- return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- key_func_wrapper, "tf.contrib.data.group_by_window()", input_dataset)
- if not (
- wrapped_func.output_types == dtypes.int64 and
- wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
- raise ValueError(
- "`key_func` must return a single tf.int64 scalar tensor.")
- self._key_func = wrapped_func.function
-
- def _make_reduce_func(self, reduce_func, input_dataset):
- """Make wrapping Defun for reduce_func."""
- nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- reduce_func, "tf.contrib.data.reduce_by_window()",
- input_classes=(ops.Tensor, nested_dataset),
- input_shapes=(tensor_shape.scalar(), nested_dataset),
- input_types=(dtypes.int64, nested_dataset),
- experimental_nested_dataset_support=True)
- if not isinstance(
- wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
- raise TypeError("`reduce_func` must return a `Dataset` object.")
- self._output_classes = wrapped_func.output_classes.output_classes
- self._output_types = wrapped_func.output_types.output_types
- self._output_shapes = wrapped_func.output_shapes.output_shapes
- self._reduce_func = wrapped_func.function
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.group_by_window_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._key_func.captured_inputs,
- self._reduce_func.captured_inputs,
- self._window_size_func.captured_inputs,
- key_func=self._key_func,
- reduce_func=self._reduce_func,
- window_size_func=self._window_size_func,
- **dataset_ops.flat_structure(self))
-
-
-class Reducer(object):
+class Reducer(grouping.Reducer):
"""A reducer is used for reducing a set of elements.
A reducer is represented as a tuple of the three functions:
@@ -507,58 +152,6 @@ class Reducer(object):
3) finalization function: state => result
"""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.Reducer(...)`.")
def __init__(self, init_func, reduce_func, finalize_func):
- self._init_func = init_func
- self._reduce_func = reduce_func
- self._finalize_func = finalize_func
-
- @property
- def init_func(self):
- return self._init_func
-
- @property
- def reduce_func(self):
- return self._reduce_func
-
- @property
- def finalize_func(self):
- return self._finalize_func
-
-
-class _MapXDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that maps a function over elements in its input."""
-
- def __init__(self, input_dataset, map_func):
- """See `map_x_dataset()` for details."""
- super(_MapXDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- map_func,
- "tf.contrib.data.map_x_dataset()",
- input_dataset,
- experimental_nested_dataset_support=True)
- self._output_classes = wrapped_func.output_classes
- self._output_shapes = wrapped_func.output_shapes
- self._output_types = wrapped_func.output_types
- self._map_func = wrapped_func.function
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.map_dataset(
- input_t,
- self._map_func.captured_inputs,
- f=self._map_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+ super(Reducer, self).__init__(init_func, reduce_func, finalize_func)
diff --git a/tensorflow/contrib/data/python/ops/interleave_ops.py b/tensorflow/contrib/data/python/ops/interleave_ops.py
index bfa3fdf543..f50da4d429 100644
--- a/tensorflow/contrib/data/python/ops/interleave_ops.py
+++ b/tensorflow/contrib/data/python/ops/interleave_ops.py
@@ -17,21 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib import stateless
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.contrib.data.python.ops import random_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import readers
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import math_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.parallel_interleave(...)`.")
def parallel_interleave(map_func,
cycle_length,
block_length=1,
@@ -81,12 +72,9 @@ def parallel_interleave(map_func,
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset, map_func, cycle_length, block_length, sloppy,
- buffer_output_elements, prefetch_input_elements)
-
- return _apply_fn
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy, buffer_output_elements,
+ prefetch_input_elements)
@deprecation.deprecated(
@@ -140,61 +128,12 @@ def sloppy_interleave(map_func, cycle_length, block_length=1):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return readers.ParallelInterleaveDataset(
- dataset,
- map_func,
- cycle_length,
- block_length,
- sloppy=True,
- buffer_output_elements=None,
- prefetch_input_elements=None)
-
- return _apply_fn
-
-
-class _DirectedInterleaveDataset(dataset_ops.Dataset):
- """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
-
- def __init__(self, selector_input, data_inputs):
- self._selector_input = selector_input
- self._data_inputs = list(data_inputs)
-
- for data_input in data_inputs[1:]:
- if (data_input.output_types != data_inputs[0].output_types or
- data_input.output_classes != data_inputs[0].output_classes):
- raise TypeError("All datasets must have the same type and class.")
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- return gen_dataset_ops.directed_interleave_dataset(
- self._selector_input._as_variant_tensor(),
- [data_input._as_variant_tensor() for data_input in self._data_inputs],
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- def _inputs(self):
- return [self._selector_input] + self._data_inputs
-
- @property
- def output_classes(self):
- return self._data_inputs[0].output_classes
-
- @property
- def output_shapes(self):
- ret = self._data_inputs[0].output_shapes
- for data_input in self._data_inputs[1:]:
- ret = nest.pack_sequence_as(ret, [
- ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
- nest.flatten(ret), nest.flatten(data_input.output_shapes))
- ])
- return ret
-
- @property
- def output_types(self):
- return self._data_inputs[0].output_types
+ return interleave_ops.parallel_interleave(
+ map_func, cycle_length, block_length, sloppy=True)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.sample_from_datasets(...)`.")
def sample_from_datasets(datasets, weights=None, seed=None):
"""Samples elements at random from the datasets in `datasets`.
@@ -218,64 +157,11 @@ def sample_from_datasets(datasets, weights=None, seed=None):
ValueError: If the `weights` argument is specified and does not match the
length of the `datasets` element.
"""
- num_datasets = len(datasets)
- if not isinstance(weights, dataset_ops.Dataset):
- if weights is None:
- # Select inputs with uniform probability.
- logits = [[1.0] * num_datasets]
-
- else:
- # Use the given `weights` as the probability of choosing the respective
- # input.
- weights = ops.convert_to_tensor(weights, name="weights")
- if weights.dtype not in (dtypes.float32, dtypes.float64):
- raise TypeError("`weights` must be convertible to a tensor of "
- "`tf.float32` or `tf.float64` elements.")
- if not weights.shape.is_compatible_with([num_datasets]):
- raise ValueError(
- "`weights` must be a vector of length `len(datasets)`.")
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed
- # to weights.
- logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
-
- # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
- # is a `Dataset`, it is possible that evaluating it has a side effect the
- # user depends on.
- if len(datasets) == 1:
- return datasets[0]
-
- def select_dataset_constant_logits(seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- selector_input = dataset_ops.MapDataset(
- random_ops.RandomDataset(seed).batch(2),
- select_dataset_constant_logits,
- use_inter_op_parallelism=False)
-
- else:
- # Use each element of the given `weights` dataset as the probability of
- # choosing the respective input.
-
- # The `stateless_multinomial()` op expects log-probabilities, as opposed to
- # weights.
- logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
-
- def select_dataset_varying_logits(logits, seed):
- return array_ops.squeeze(
- stateless.stateless_multinomial(logits, 1, seed=seed), axis=[0, 1])
-
- logits_and_seeds = dataset_ops.Dataset.zip(
- (logits_ds, random_ops.RandomDataset(seed).batch(2)))
- selector_input = dataset_ops.MapDataset(
- logits_and_seeds,
- select_dataset_varying_logits,
- use_inter_op_parallelism=False)
-
- return _DirectedInterleaveDataset(selector_input, datasets)
+ return interleave_ops.sample_from_datasets(datasets, weights, seed)
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.choose_from_datasets(...)`.")
def choose_from_datasets(datasets, choice_dataset):
"""Creates a dataset that deterministically chooses elements from `datasets`.
@@ -311,10 +197,4 @@ def choose_from_datasets(datasets, choice_dataset):
TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
type.
"""
- if not (choice_dataset.output_types == dtypes.int64
- and choice_dataset.output_shapes.is_compatible_with(
- tensor_shape.scalar())
- and choice_dataset.output_classes == ops.Tensor):
- raise TypeError("`choice_dataset` must be a dataset of scalar "
- "`tf.int64` tensors.")
- return _DirectedInterleaveDataset(choice_dataset, datasets)
+ return interleave_ops.choose_from_datasets(datasets, choice_dataset)
diff --git a/tensorflow/contrib/data/python/ops/iterator_ops.py b/tensorflow/contrib/data/python/ops/iterator_ops.py
index 18515e21ed..48c325c86f 100644
--- a/tensorflow/contrib/data/python/ops/iterator_ops.py
+++ b/tensorflow/contrib/data/python/ops/iterator_ops.py
@@ -16,15 +16,13 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import checkpoint_management
-from tensorflow.python.training import saver as saver_lib
-from tensorflow.python.training import session_run_hook
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.util import deprecation
+
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_saveable_from_iterator(...)`.")
def make_saveable_from_iterator(iterator):
"""Returns a SaveableObject for saving/restore iterator state using Saver.
@@ -60,27 +58,10 @@ def make_saveable_from_iterator(iterator):
Note: Not all iterators support checkpointing yet. Attempting to save the
state of an unsupported iterator will throw an error.
"""
- return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
-
-
-class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
- """SaveableObject for saving/restoring iterator state."""
+ return iterator_ops.make_saveable_from_iterator(iterator)
- def __init__(self, iterator_resource):
- serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
- specs = [
- saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
- iterator_resource.name + "-state")
- ]
- super(_Saveable, self).__init__(iterator_resource, specs,
- iterator_resource.name)
- def restore(self, restored_tensors, unused_restored_shapes):
- with ops.colocate_with(self.op):
- return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
-
-
-class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+class CheckpointInputPipelineHook(iterator_ops.CheckpointInputPipelineHook):
"""Checkpoints input pipeline state every N steps or seconds.
This hook saves the state of the iterators in the `Graph` so that when
@@ -125,135 +106,7 @@ class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
collector when building the eval graph.
"""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.CheckpointInputPipelineHook(...)`.")
def __init__(self, estimator):
- """Initializes a `CheckpointInputPipelineHook`.
-
- Args:
- estimator: Estimator.
-
- Raises:
- ValueError: One of `save_steps` or `save_secs` should be set.
- ValueError: At most one of saver or scaffold should be set.
- """
- # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
- # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
- # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
- # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
- # to be different to avoid conflicts with the model checkpoint.
-
- # pylint: disable=protected-access
- checkpoint_prefix = "input"
- if estimator._config.num_worker_replicas > 1:
- # Distributed setting.
- suffix = "_{}_{}".format(estimator._config.task_type,
- estimator._config.task_id)
- checkpoint_prefix += suffix
- # pylint: enable=protected-access
-
- # We use a composition paradigm instead of inheriting from
- # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
- # to check whether a `CheckpointSaverHook` is already present in the list
- # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
- # would thwart this behavior. This hook checkpoints *only the iterators*
- # and not the graph variables.
- self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
- estimator.model_dir,
- save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
- save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
- checkpoint_basename=checkpoint_prefix + ".ckpt")
-
- # Name for the protocol buffer file that will contain the list of most
- # recent checkpoints stored as a `CheckpointState` protocol buffer.
- # This file, kept in the same directory as the checkpoint files, is
- # automatically managed by the `Saver` to keep track of recent checkpoints.
- # The default name used by the `Saver` for this file is "checkpoint". Here
- # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
- # `checkpoint_dir` is the same as the model checkpoint directory, there are
- # no conflicts during restore.
- self._latest_filename = "checkpoint_" + checkpoint_prefix
- self._first_run = True
-
- def begin(self):
- # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
- # collection if no `Saver` or `Scaffold` is provided.
- # pylint: disable=protected-access
- if (self._checkpoint_saver_hook._saver is None and
- self._checkpoint_saver_hook._scaffold is None):
- iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
- saveables = [_Saveable(i) for i in iterators]
- self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
- self._latest_filename)
- # pylint: enable=protected-access
- self._checkpoint_saver_hook.begin()
-
- def _restore_or_save_initial_ckpt(self, session):
- # Ideally this should be run in after_create_session but is not for the
- # following reason:
- # Currently there is no way of enforcing an order of running the
- # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
- # is run *after* this hook. That is troublesome because
- # 1. If a checkpoint exists and this hook restores it, the initializer hook
- # will override it.
- # 2. If no checkpoint exists, this hook will try to save an initialized
- # iterator which will result in an exception.
- #
- # As a temporary fix we enter the following implicit contract between this
- # hook and the _DatasetInitializerHook.
- # 1. The _DatasetInitializerHook initializes the iterator in the call to
- # after_create_session.
- # 2. This hook saves the iterator on the first call to `before_run()`, which
- # is guaranteed to happen after `after_create_session()` of all hooks
- # have been run.
-
- # Check if there is an existing checkpoint. If so, restore from it.
- # pylint: disable=protected-access
- latest_checkpoint_path = checkpoint_management.latest_checkpoint(
- self._checkpoint_saver_hook._checkpoint_dir,
- latest_filename=self._latest_filename)
- if latest_checkpoint_path:
- self._checkpoint_saver_hook._get_saver().restore(session,
- latest_checkpoint_path)
- else:
- # The checkpoint saved here is the state at step "global_step".
- # Note: We do not save the GraphDef or MetaGraphDef here.
- global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
- self._checkpoint_saver_hook._save(session, global_step)
- self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
- # pylint: enable=protected-access
-
- def before_run(self, run_context):
- if self._first_run:
- self._restore_or_save_initial_ckpt(run_context.session)
- self._first_run = False
- return self._checkpoint_saver_hook.before_run(run_context)
-
- def after_run(self, run_context, run_values):
- self._checkpoint_saver_hook.after_run(run_context, run_values)
-
- def end(self, session):
- self._checkpoint_saver_hook.end(session)
-
-
-class _CustomSaver(saver_lib.Saver):
- """`Saver` with a different default `latest_filename`.
-
- This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
- the model ckpt saved by the `CheckpointSaverHook`.
- """
-
- def __init__(self, var_list, latest_filename):
- super(_CustomSaver, self).__init__(var_list)
- self._latest_filename = latest_filename
-
- def save(self,
- sess,
- save_path,
- global_step=None,
- latest_filename=None,
- meta_graph_suffix="meta",
- write_meta_graph=True,
- write_state=True,
- strip_default_attrs=False):
- return super(_CustomSaver, self).save(
- sess, save_path, global_step, latest_filename or self._latest_filename,
- meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+ super(CheckpointInputPipelineHook, self).__init__(estimator)
diff --git a/tensorflow/contrib/data/python/ops/parsing_ops.py b/tensorflow/contrib/data/python/ops/parsing_ops.py
index cfbba701b0..3aeee9d8e4 100644
--- a/tensorflow/contrib/data/python/ops/parsing_ops.py
+++ b/tensorflow/contrib/data/python/ops/parsing_ops.py
@@ -17,92 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.util import deprecation
-class _ParseExampleDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that parses `example` dataset into a `dict` dataset."""
-
- def __init__(self, input_dataset, features, num_parallel_calls):
- super(_ParseExampleDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if not all(types == dtypes.string
- for types in nest.flatten(input_dataset.output_types)):
- raise TypeError("Input dataset should be a dataset of vectors of strings")
- self._num_parallel_calls = num_parallel_calls
- # pylint: disable=protected-access
- self._features = parsing_ops._prepend_none_dimension(features)
- # sparse_keys and dense_keys come back sorted here.
- (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
- dense_shapes) = parsing_ops._features_to_raw_params(
- self._features, [
- parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
- parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
- ])
- # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
- (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
- dense_shape_as_shape) = parsing_ops._process_raw_parameters(
- None, dense_defaults, sparse_keys, sparse_types, dense_keys,
- dense_types, dense_shapes)
- # pylint: enable=protected-access
- self._sparse_keys = sparse_keys
- self._sparse_types = sparse_types
- self._dense_keys = dense_keys
- self._dense_defaults = dense_defaults_vec
- self._dense_shapes = dense_shapes
- self._dense_types = dense_types
- dense_output_shapes = [
- self._input_dataset.output_shapes.concatenate(shape)
- for shape in dense_shape_as_shape
- ]
- sparse_output_shapes = [
- self._input_dataset.output_shapes.concatenate([None])
- for _ in range(len(sparse_keys))
- ]
-
- self._output_shapes = dict(
- zip(self._dense_keys + self._sparse_keys,
- dense_output_shapes + sparse_output_shapes))
- self._output_types = dict(
- zip(self._dense_keys + self._sparse_keys,
- self._dense_types + self._sparse_types))
- self._output_classes = dict(
- zip(self._dense_keys + self._sparse_keys,
- [ops.Tensor for _ in range(len(self._dense_defaults))] +
- [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
- ]))
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.parse_example_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._num_parallel_calls,
- self._dense_defaults,
- self._sparse_keys,
- self._dense_keys,
- self._sparse_types,
- self._dense_shapes,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_classes(self):
- return self._output_classes
-
-
-# TODO(b/111553342): add arguments names and example names as well.
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.parse_example_dataset(...)`.")
def parse_example_dataset(features, num_parallel_calls=1):
"""A transformation that parses `Example` protos into a `dict` of tensors.
@@ -130,21 +50,4 @@ def parse_example_dataset(features, num_parallel_calls=1):
Raises:
ValueError: if features argument is None.
"""
- if features is None:
- raise ValueError("Missing: features was %s." % features)
-
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
- if any([
- isinstance(feature, parsing_ops.SparseFeature)
- for _, feature in features.items()
- ]):
- # pylint: disable=protected-access
- # pylint: disable=g-long-lambda
- out_dataset = out_dataset.map(
- lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
- features, x), num_parallel_calls=num_parallel_calls)
- return out_dataset
-
- return _apply_fn
+ return parsing_ops.parse_example_dataset(features, num_parallel_calls)
diff --git a/tensorflow/contrib/data/python/ops/prefetching_ops.py b/tensorflow/contrib/data/python/ops/prefetching_ops.py
index 58395879e6..adfb390cd9 100644
--- a/tensorflow/contrib/data/python/ops/prefetching_ops.py
+++ b/tensorflow/contrib/data/python/ops/prefetching_ops.py
@@ -17,320 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import warnings
-
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import device as framework_device
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import gen_dataset_ops as core_gen_dataset_ops
-from tensorflow.python.ops import resource_variable_ops
-
-
-def function_buffering_resource(string_arg,
- target_device,
- f,
- buffer_size,
- output_types,
- container="",
- shared_name=None,
- name=None):
- """Creates a FunctionBufferingResource.
-
- A FunctionBufferingResource fills up a buffer by calling a function `f` on
- `target_device`. `f` should take in only a single string argument as input.
-
- Args:
- string_arg: The single string argument to the function.
- target_device: The device to run `f` on.
- f: The function to be executed.
- buffer_size: Size of the buffer to be populated.
- output_types: The output types generated by the function.
- container: (Optional) string. Defaults to "".
- shared_name: (Optional) string.
- name: (Optional) string to name the op.
-
- Returns:
- Handle to a FunctionBufferingResource.
- """
- if shared_name is None:
- shared_name = ""
- return gen_dataset_ops.function_buffering_resource(
- string_arg=string_arg,
- target_device=target_device,
- shared_name=shared_name,
- f=f,
- buffer_size=buffer_size,
- container=container,
- name=name,
- output_types=output_types)
-
-
-def function_buffering_resource_get_next(function_buffer_resource,
- output_types,
- name=None):
- return gen_dataset_ops.function_buffering_resource_get_next(
- function_buffer_resource=function_buffer_resource,
- output_types=output_types,
- name=name)
-
-
-def function_buffering_resource_reset(function_buffer_resource, name=None):
- return gen_dataset_ops.function_buffering_resource_reset(
- function_buffer_resource=function_buffer_resource, name=name)
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- one_shot,
- device,
- buffer_size,
- shared_name=None):
- self._input_dataset = input_dataset
- self._get_next_call_count = 0
- self._one_shot = one_shot
- if shared_name is None:
- shared_name = ""
-
- if self._one_shot:
- self._input_iterator = input_dataset.make_one_shot_iterator()
- else:
- self._input_iterator = iterator_ops.Iterator.from_structure(
- self._input_dataset.output_types, self._input_dataset.output_shapes,
- shared_name, self._input_dataset.output_classes)
- input_iterator_handle = self._input_iterator.string_handle()
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self._input_iterator.output_types,
- self._input_iterator.output_shapes,
- self._input_iterator.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- iterator_device = gen_dataset_ops.iterator_get_device(
- self._input_iterator._iterator_resource)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- target_device=iterator_device,
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes)))
-
- if not self._one_shot:
- reset_op = function_buffering_resource_reset(self._buffering_resource)
- with ops.control_dependencies([reset_op]):
- self._initializer = self._input_iterator.make_initializer(
- self._input_dataset)
-
- def get_next(self, name=None):
- """See `tf.data.Iterator.get_next`."""
- self._get_next_call_count += 1
- if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
- warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
- flat_ret = gen_dataset_ops.function_buffering_resource_get_next(
- self._buffering_resource,
- output_types=nest.flatten(sparse.as_dense_types(
- self.output_types, self.output_classes)), name=name)
-
- ret = sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self.output_types, flat_ret),
- self.output_types, self.output_shapes, self.output_classes)
-
- for tensor, shape in zip(
- nest.flatten(ret), nest.flatten(self.output_shapes)):
- if isinstance(tensor, ops.Tensor):
- tensor.set_shape(shape)
-
- return ret
-
- @property
- def initializer(self):
- if self._one_shot:
- raise NotImplementedError("Can't initialize a one_shot_iterator")
- return self._initializer
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
- """A replacement for `tf.data.Iterator` that prefetches to another device.
-
- Args:
- input_dataset: The input dataset
- one_shot: If true, we make a one shot iterator that's already initialized.
- device: A fully specified device string where we want to prefetch to
- buffer_size: Size of the prefetching buffer.
- shared_name: (Optional.) If non-empty, the returned iterator will be
- shared under the given name across multiple sessions that share the
- same devices (e.g. when using a remote server).
-
- Returns:
- An Iterator type object.
- """
-
- def __init__(self,
- input_dataset,
- device,
- buffer_size):
- with ops.device("/device:CPU:0"):
- super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
- input_iterator_handle = core_gen_dataset_ops.iterator_to_string_handle(
- self._resource)
-
- self._device = device
-
- @function.Defun(dtypes.string)
- def _prefetch_fn(handle):
- """Prefetches one element from `input_iterator`."""
- remote_iterator = iterator_ops.Iterator.from_string_handle(
- handle, self.output_types, self.output_shapes, self.output_classes)
- ret = remote_iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- _prefetch_fn.add_to_graph(None)
-
- with ops.device(device):
- self._buffering_resource = function_buffering_resource(
- f=_prefetch_fn,
- output_types=self._flat_output_types,
- target_device=gen_dataset_ops.iterator_get_device(self._resource),
- string_arg=input_iterator_handle,
- buffer_size=buffer_size,
- shared_name=iterator_ops._generate_shared_name(
- "function_buffer_resource"))
-
- def _next_internal(self):
- """Returns a nested structure of `tf.Tensor`s containing the next element.
- """
- # This runs in sync mode as iterators use an error status to communicate
- # that there is no more data to iterate over.
- # TODO(b/77291417): Fix
- with context.execution_mode(context.SYNC):
- with ops.device(self._device):
- ret = gen_dataset_ops.function_buffering_resource_get_next(
- function_buffer_resource=self._buffering_resource,
- output_types=self._flat_output_types)
- return sparse.deserialize_sparse_tensors(
- nest.pack_sequence_as(self._output_types, ret), self._output_types,
- self._output_shapes, self._output_classes)
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` whose iterator prefetches elements to another device."""
-
- def __init__(self, input_dataset, device, buffer_size):
- super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._device = device
- self._buffer_size = buffer_size if buffer_size is not None else 1
-
- # The static analysis cannot tell that the eager iterator's superclass has
- # a `next()` method.
- # pylint: disable=non-iterator-returned
- def __iter__(self):
- """Creates an `Iterator` for enumerating the elements of this dataset.
-
- The returned iterator implements the Python iterator protocol and therefore
- can only be used in eager mode.
-
- Returns:
- An `Iterator` over the elements of this dataset.
-
- Raises:
- RuntimeError: If eager execution is enabled.
- """
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- raise RuntimeError("dataset.__iter__() is only supported when eager "
- "execution is enabled.")
- # pylint: enable=non-iterator-returned
-
- def make_one_shot_iterator(self):
- if context.executing_eagerly():
- return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
- self._buffer_size)
- else:
- return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
- device=self._device,
- buffer_size=self._buffer_size)
-
- def make_initializable_iterator(self, shared_name=None):
- return _PrefetchToDeviceIterator(
- self._input_dataset,
- one_shot=False,
- device=self._device,
- buffer_size=self._buffer_size,
- shared_name=shared_name)
-
- def _as_variant_tensor(self):
- # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
- # transformation methods is called.
- # TODO(mrry): Investigate support for chaining further transformations after
- # the prefetch, including GPU support.
- raise NotImplementedError("`prefetch_to_device()` must be the last "
- "transformation in a dataset pipeline.")
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.prefetch_to_device(...)`.")
def prefetch_to_device(device, buffer_size=None):
"""A transformation that prefetches dataset values to the given `device`.
@@ -346,12 +38,10 @@ def prefetch_to_device(device, buffer_size=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _PrefetchToDeviceDataset(dataset, device, buffer_size)
-
- return _apply_fn
+ return prefetching_ops.prefetch_to_device(device, buffer_size)
+@deprecation.deprecated(None, "Use `tf.data.experimental.copy_to_device(...)`.")
def copy_to_device(target_device, source_device="/cpu:0"):
"""A transformation that copies dataset elements to the given `target_device`.
@@ -363,165 +53,4 @@ def copy_to_device(target_device, source_device="/cpu:0"):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _CopyToDeviceDataset(
- dataset, target_device=target_device, source_device=source_device)
-
- return _apply_fn
-
-
-# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
-# all inputs to the Op are in host memory, thereby avoiding some unnecessary
-# Sends and Recvs.
-class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that copies elements to another device."""
-
- def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
- """Constructs a _CopyToDeviceDataset.
-
- Args:
- input_dataset: `Dataset` to be copied
- target_device: The name of the device to which elements would be copied.
- source_device: Device where input_dataset would be placed.
- """
- super(_CopyToDeviceDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._target_device = target_device
- spec = framework_device.DeviceSpec().from_string(self._target_device)
- self._is_gpu_target = (spec.device_type == "GPU")
- self._source_device_string = source_device
- self._source_device = ops.convert_to_tensor(source_device)
-
- self._flat_output_shapes = nest.flatten(
- sparse.as_dense_shapes(self._input_dataset.output_shapes,
- self._input_dataset.output_classes))
- self._flat_output_types = nest.flatten(
- sparse.as_dense_types(self._input_dataset.output_types,
- self._input_dataset.output_classes))
-
- @function.Defun()
- def _init_func():
- """Creates an iterator for the input dataset.
-
- Returns:
- A `string` tensor that encapsulates the iterator created.
- """
- # pylint: disable=protected-access
- ds_variant = self._input_dataset._as_variant_tensor()
- resource = core_gen_dataset_ops.anonymous_iterator(
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies(
- [core_gen_dataset_ops.make_iterator(ds_variant, resource)]):
- return core_gen_dataset_ops.iterator_to_string_handle(resource)
-
- @function.Defun()
- def _remote_init_func():
- return functional_ops.remote_call(
- target=self._source_device,
- args=_init_func.captured_inputs,
- Tout=[dtypes.string],
- f=_init_func)
-
- self._init_func = _remote_init_func
- self._init_captured_args = _remote_init_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _next_func(string_handle):
- """Calls get_next for created iterator.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- The elements generated from `input_dataset`
- """
- with ops.device(self._source_device_string):
- iterator = iterator_ops.Iterator.from_string_handle(
- string_handle, self.output_types, self.output_shapes,
- self.output_classes)
- ret = iterator.get_next()
- return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
- @function.Defun(dtypes.string)
- def _remote_next_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _next_func.captured_inputs,
- Tout=self._flat_output_types,
- f=_next_func)
-
- self._next_func = _remote_next_func
- self._next_captured_args = _remote_next_func.captured_inputs
-
- @function.Defun(dtypes.string)
- def _finalize_func(string_handle):
- """Destroys the iterator resource created.
-
- Args:
- string_handle: An iterator string handle created by _init_func
- Returns:
- Tensor constant 0
- """
- iterator_resource = core_gen_dataset_ops.iterator_from_string_handle_v2(
- string_handle,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
- with ops.control_dependencies([
- resource_variable_ops.destroy_resource_op(
- iterator_resource, ignore_lookup_error=True)]):
- return array_ops.constant(0, dtypes.int64)
-
- @function.Defun(dtypes.string)
- def _remote_finalize_func(string_handle):
- return functional_ops.remote_call(
- target=self._source_device,
- args=[string_handle] + _finalize_func.captured_inputs,
- Tout=[dtypes.int64],
- f=_finalize_func)
-
- self._finalize_func = _remote_finalize_func
- self._finalize_captured_args = _remote_finalize_func.captured_inputs
-
- g = ops.get_default_graph()
- _remote_init_func.add_to_graph(g)
- _remote_next_func.add_to_graph(g)
- _remote_finalize_func.add_to_graph(g)
- # pylint: enable=protected-scope
-
- # The one_shot_iterator implementation needs a 0 arg _make_dataset function
- # that thereby captures all the inputs required to create the dataset. Since
- # there are strings that are inputs to the GeneratorDataset which can't be
- # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
- # GPU
- def make_one_shot_iterator(self):
- if self._is_gpu_target:
- raise ValueError("Cannot create a one shot iterator when using "
- "`tf.contrib.data.copy_to_device()` on GPU. Please use "
- "`Dataset.make_initializable_iterator()` instead.")
- else:
- return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
-
- def _as_variant_tensor(self):
- with ops.device(self._target_device):
- return core_gen_dataset_ops.generator_dataset(
- self._init_captured_args,
- self._next_captured_args,
- self._finalize_captured_args,
- init_func=self._init_func,
- next_func=self._next_func,
- finalize_func=self._finalize_func,
- output_types=self._flat_output_types,
- output_shapes=self._flat_output_shapes)
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
+ return prefetching_ops.copy_to_device(target_device, source_device)
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
index 344a0763c8..2c95125636 100644
--- a/tensorflow/contrib/data/python/ops/random_ops.py
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -17,36 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.util import deprecation
-class RandomDataset(dataset_ops.DatasetSource):
+class RandomDataset(random_ops.RandomDataset):
"""A `Dataset` of pseudorandom values."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.RandomDataset(...)`.")
def __init__(self, seed=None):
- """A `Dataset` of pseudorandom values."""
- super(RandomDataset, self).__init__()
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.random_dataset(
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return ops.Tensor
-
- @property
- def output_shapes(self):
- return tensor_shape.scalar()
-
- @property
- def output_types(self):
- return dtypes.int64
+ super(RandomDataset, self).__init__(seed)
diff --git a/tensorflow/contrib/data/python/ops/readers.py b/tensorflow/contrib/data/python/ops/readers.py
index d9d06e2703..4601376dff 100644
--- a/tensorflow/contrib/data/python/ops/readers.py
+++ b/tensorflow/contrib/data/python/ops/readers.py
@@ -17,295 +17,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-import csv
-
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import parsing_ops
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers as core_readers
-from tensorflow.python.data.util import convert
from tensorflow.python.data.util import nest
-from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.lib.io import file_io
-from tensorflow.python.ops import gen_dataset_ops
-from tensorflow.python.platform import gfile
+from tensorflow.python.ops import gen_experimental_dataset_ops
from tensorflow.python.util import deprecation
-_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
- dtypes.int64, dtypes.string)
-
-
-def _is_valid_int32(str_val):
- try:
- # Checks equality to prevent int32 overflow
- return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
- str_val)
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_int64(str_val):
- try:
- dtypes.int64.as_numpy_dtype(str_val)
- return True
- except (ValueError, OverflowError):
- return False
-
-
-def _is_valid_float(str_val, float_dtype):
- try:
- return float_dtype.as_numpy_dtype(str_val) < np.inf
- except ValueError:
- return False
-
-
-def _infer_type(str_val, na_value, prev_type):
- """Given a string, infers its tensor type.
-
- Infers the type of a value by picking the least 'permissive' type possible,
- while still allowing the previous type inference for this column to be valid.
-
- Args:
- str_val: String value to infer the type of.
- na_value: Additional string to recognize as a NA/NaN CSV value.
- prev_type: Type previously inferred based on values of this column that
- we've seen up till now.
- Returns:
- Inferred dtype.
- """
- if str_val in ("", na_value):
- # If the field is null, it gives no extra information about its type
- return prev_type
-
- type_list = [
- dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
- ] # list of types to try, ordered from least permissive to most
-
- type_functions = [
- _is_valid_int32,
- _is_valid_int64,
- lambda str_val: _is_valid_float(str_val, dtypes.float32),
- lambda str_val: _is_valid_float(str_val, dtypes.float64),
- lambda str_val: True,
- ] # Corresponding list of validation functions
-
- for i in range(len(type_list)):
- validation_fn = type_functions[i]
- if validation_fn(str_val) and (prev_type is None or
- prev_type in type_list[:i + 1]):
- return type_list[i]
-
-
-def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
- """Generator that yields rows of CSV file(s) in order."""
- for fn in filenames:
- with file_io.FileIO(fn, "r") as f:
- rdr = csv.reader(
- f,
- delimiter=field_delim,
- quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
- if header:
- next(rdr) # Skip header lines
-
- for csv_row in rdr:
- if len(csv_row) != num_cols:
- raise ValueError(
- "Problem inferring types: CSV row has different number of fields "
- "than expected.")
- yield csv_row
-
-
-def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
- na_value, header, num_rows_for_inference,
- select_columns):
- """Infers column types from the first N valid CSV records of files."""
- if select_columns is None:
- select_columns = range(num_cols)
- inferred_types = [None] * len(select_columns)
-
- for i, csv_row in enumerate(
- _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
- if num_rows_for_inference is not None and i >= num_rows_for_inference:
- break
-
- for j, col_index in enumerate(select_columns):
- inferred_types[j] = _infer_type(csv_row[col_index], na_value,
- inferred_types[j])
-
- # Replace None's with a default type
- inferred_types = [t or dtypes.string for t in inferred_types]
- # Default to 0 or '' for null values
- return [
- constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
- for t in inferred_types
- ]
-
-
-def _infer_column_names(filenames, field_delim, use_quote_delim):
- """Infers column names from first rows of files."""
- csv_kwargs = {
- "delimiter": field_delim,
- "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
- }
- with file_io.FileIO(filenames[0], "r") as f:
- try:
- column_names = next(csv.reader(f, **csv_kwargs))
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
-
- for name in filenames[1:]:
- with file_io.FileIO(name, "r") as f:
- try:
- if next(csv.reader(f, **csv_kwargs)) != column_names:
- raise ValueError(
- "Files have different column names in the header row.")
- except StopIteration:
- raise ValueError(("Received StopIteration when reading the header line "
- "of %s. Empty file?") % filenames[0])
- return column_names
-
-
-def _get_sorted_col_indices(select_columns, column_names):
- """Transforms select_columns argument into sorted column indices."""
- names_to_indices = {n: i for i, n in enumerate(column_names)}
- num_cols = len(column_names)
- for i, v in enumerate(select_columns):
- if isinstance(v, int):
- if v < 0 or v >= num_cols:
- raise ValueError(
- "Column index %d specified in select_columns out of valid range." %
- v)
- continue
- if v not in names_to_indices:
- raise ValueError(
- "Value '%s' specified in select_columns not a valid column index or "
- "name." % v)
- select_columns[i] = names_to_indices[v]
-
- # Sort and ensure there are no duplicates
- result = sorted(set(select_columns))
- if len(result) != len(select_columns):
- raise ValueError("select_columns contains duplicate columns")
- return result
-
-
-def _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
- """Optionally shuffle and repeat dataset, as requested."""
- if num_epochs != 1 and shuffle:
- # Use shuffle_and_repeat for perf
- return dataset.apply(
- shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
- shuffle_seed))
- elif shuffle:
- return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
- elif num_epochs != 1:
- return dataset.repeat(num_epochs)
- return dataset
-
-
-def make_tf_record_dataset(file_pattern,
- batch_size,
- parser_fn=None,
- num_epochs=None,
- shuffle=True,
- shuffle_buffer_size=None,
- shuffle_seed=None,
- prefetch_buffer_size=optimization.AUTOTUNE,
- num_parallel_reads=None,
- num_parallel_parser_calls=None,
- drop_final_batch=False):
- """Reads and optionally parses TFRecord files into a dataset.
-
- Provides common functionality such as batching, optional parsing, shuffling,
- and performant defaults.
-
- Args:
- file_pattern: List of files or patterns of TFRecord file paths.
- See `tf.gfile.Glob` for pattern rules.
- batch_size: An int representing the number of records to combine
- in a single batch.
- parser_fn: (Optional.) A function accepting string input to parse
- and process the record contents. This function must map records
- to components of a fixed shape, so they may be batched. By
- default, uses the record contents unmodified.
- num_epochs: (Optional.) An int specifying the number of times this
- dataset is repeated. If None (the default), cycles through the
- dataset forever.
- shuffle: (Optional.) A bool that indicates whether the input
- should be shuffled. Defaults to `True`.
- shuffle_buffer_size: (Optional.) Buffer size to use for
- shuffling. A large buffer size ensures better shuffling, but
- increases memory usage and startup time.
- shuffle_seed: (Optional.) Randomization seed to use for shuffling.
- prefetch_buffer_size: (Optional.) An int specifying the number of
- feature batches to prefetch for performance improvement.
- Defaults to auto-tune. Set to 0 to disable prefetching.
- num_parallel_reads: (Optional.) Number of threads used to read
- records from files. By default or if set to a value >1, the
- results will be interleaved.
- num_parallel_parser_calls: (Optional.) Number of parallel
- records to parse in parallel. Defaults to an automatic selection.
- drop_final_batch: (Optional.) Whether the last batch should be
- dropped in case its size is smaller than `batch_size`; the
- default behavior is not to drop the smaller batch.
-
- Returns:
- A dataset, where each element matches the output of `parser_fn`
- except it will have an additional leading `batch-size` dimension,
- or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
- unspecified.
- """
- files = dataset_ops.Dataset.list_files(
- file_pattern, shuffle=shuffle, seed=shuffle_seed)
-
- if num_parallel_reads is None:
- # Note: We considered auto-tuning this value, but there is a concern
- # that this affects the mixing of records from different files, which
- # could affect training convergence/accuracy, so we are defaulting to
- # a constant for now.
- num_parallel_reads = 24
- dataset = core_readers.TFRecordDataset(
- files, num_parallel_reads=num_parallel_reads)
-
- if shuffle_buffer_size is None:
- # TODO(josh11b): Auto-tune this value when not specified
- shuffle_buffer_size = 10000
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- drop_final_batch = drop_final_batch or num_epochs is None
-
- if parser_fn is None:
- dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
- else:
- # TODO(josh11b): if num_parallel_parser_calls is None, use some function
- # of num cores instead of map_and_batch's default behavior of one batch.
- dataset = dataset.apply(batching.map_and_batch(
- parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
- drop_remainder=drop_final_batch))
-
- if prefetch_buffer_size == 0:
- return dataset
- else:
- return dataset.prefetch(buffer_size=prefetch_buffer_size)
-
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.make_csv_dataset(...)`.")
def make_csv_dataset(
file_pattern,
batch_size,
@@ -387,7 +112,6 @@ def make_csv_dataset(
prefetch_buffer_size: An int specifying the number of feature
batches to prefetch for performance improvement. Recommended value is the
number of batches consumed per training step. Defaults to auto-tune.
-
num_parallel_reads: Number of threads used to read CSV records from files.
If >1, the results will be interleaved.
sloppy: If `True`, reading performance will be improved at
@@ -411,106 +135,18 @@ def make_csv_dataset(
Raises:
ValueError: If any of the arguments is malformed.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Clean arguments; figure out column names and defaults
+ return readers.make_csv_dataset(
+ file_pattern, batch_size, column_names, column_defaults, label_name,
+ select_columns, field_delim, use_quote_delim, na_value, header,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, num_parallel_reads, sloppy, num_rows_for_inference,
+ compression_type)
- if column_names is None:
- if not header:
- raise ValueError("Cannot infer column names without a header line.")
- # If column names are not provided, infer from the header lines
- column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
- if len(column_names) != len(set(column_names)):
- raise ValueError("Cannot have duplicate column names.")
- if select_columns is not None:
- select_columns = _get_sorted_col_indices(select_columns, column_names)
-
- if column_defaults is not None:
- column_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in column_defaults
- ]
- else:
- # If column defaults are not provided, infer from records at graph
- # construction time
- column_defaults = _infer_column_defaults(
- filenames, len(column_names), field_delim, use_quote_delim, na_value,
- header, num_rows_for_inference, select_columns)
-
- if select_columns is not None and len(column_defaults) != len(select_columns):
- raise ValueError(
- "If specified, column_defaults and select_columns must have same "
- "length."
- )
- if select_columns is not None and len(column_names) > len(select_columns):
- # Pick the relevant subset of column names
- column_names = [column_names[i] for i in select_columns]
-
- if label_name is not None and label_name not in column_names:
- raise ValueError("`label_name` provided must be one of the columns.")
-
- def filename_to_dataset(filename):
- return CsvDataset(
- filename,
- record_defaults=column_defaults,
- field_delim=field_delim,
- use_quote_delim=use_quote_delim,
- na_value=na_value,
- select_cols=select_columns,
- header=header,
- compression_type=compression_type,
- )
-
- def map_fn(*columns):
- """Organizes columns into a features dictionary.
-
- Args:
- *columns: list of `Tensor`s corresponding to one csv record.
- Returns:
- An OrderedDict of feature names to values for that particular record. If
- label_name is provided, extracts the label feature to be returned as the
- second element of the tuple.
- """
- features = collections.OrderedDict(zip(column_names, columns))
- if label_name is not None:
- label = features.pop(label_name)
- return features, label
- return features
-
- # Read files sequentially (if num_parallel_reads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
-
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # Apply batch before map for perf, because map has high overhead relative
- # to the size of the computation in each map.
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(batch_size=batch_size,
- drop_remainder=num_epochs is None)
- dataset = dataset_ops.MapDataset(
- dataset, map_fn, use_inter_op_parallelism=False)
- dataset = dataset.prefetch(prefetch_buffer_size)
-
- return dataset
-
-
-_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
-
-
-class CsvDataset(dataset_ops.DatasetSource):
+class CsvDataset(readers.CsvDataset):
"""A Dataset comprising lines from one or more CSV files."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.CsvDataset(...)`.")
def __init__(self,
filenames,
record_defaults,
@@ -521,140 +157,13 @@ class CsvDataset(dataset_ops.DatasetSource):
use_quote_delim=True,
na_value="",
select_cols=None):
- """Creates a `CsvDataset` by reading and decoding CSV files.
-
- The elements of this dataset correspond to records from the file(s).
- RFC 4180 format is expected for CSV files
- (https://tools.ietf.org/html/rfc4180)
- Note that we allow leading and trailing spaces with int or float field.
-
-
- For example, suppose we have a file 'my_file0.csv' with four CSV columns of
- different data types:
- ```
- abcdefg,4.28E10,5.55E6,12
- hijklmn,-5.3E14,,2
- ```
-
- We can construct a CsvDataset from it as follows:
- ```python
- dataset = tf.contrib.data.CsvDataset(
- "my_file*.csv",
- [tf.float32, # Required field, use dtype or empty tensor
- tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
- tf.int32, # Required field, use dtype or empty tensor
- ],
- select_cols=[1,2,3] # Only parse last three columns
- )
- ```
-
- The expected output of its iterations is:
- ```python
- next_element = dataset.make_one_shot_iterator().get_next()
- with tf.Session() as sess:
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
-
- >> (4.28e10, 5.55e6, 12)
- >> (-5.3e14, 0.0, 2)
- ```
-
- Args:
- filenames: A `tf.string` tensor containing one or more filenames.
- record_defaults: A list of default values for the CSV fields. Each item in
- the list is either a valid CSV `DType` (float32, float64, int32, int64,
- string), or a `Tensor` object with one of the above types. One per
- column of CSV data, with either a scalar `Tensor` default value for the
- column if it is optional, or `DType` or empty `Tensor` if required. If
- both this and `select_columns` are specified, these must have the same
- lengths, and `column_defaults` is assumed to be sorted in order of
- increasing column index.
- compression_type: (Optional.) A `tf.string` scalar evaluating to one of
- `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
- compression.
- buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
- to buffer while reading files. Defaults to 4MB.
- header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
- have header line(s) that should be skipped when parsing. Defaults to
- `False`.
- field_delim: (Optional.) A `tf.string` scalar containing the delimiter
- character that separates fields in a record. Defaults to `","`.
- use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
- double quotation marks as regular characters inside of string fields
- (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
- na_value: (Optional.) A `tf.string` scalar indicating a value that will
- be treated as NA/NaN.
- select_cols: (Optional.) A sorted list of column indices to select from
- the input data. If specified, only this subset of columns will be
- parsed. Defaults to parsing all columns.
- """
- super(CsvDataset, self).__init__()
- self._filenames = ops.convert_to_tensor(
- filenames, dtype=dtypes.string, name="filenames")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
- record_defaults = [
- constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
- for x in record_defaults
- ]
- self._record_defaults = ops.convert_n_to_tensor(
- record_defaults, name="record_defaults")
- self._buffer_size = convert.optional_param_to_tensor(
- "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
- self._header = ops.convert_to_tensor(
- header, dtype=dtypes.bool, name="header")
- self._field_delim = ops.convert_to_tensor(
- field_delim, dtype=dtypes.string, name="field_delim")
- self._use_quote_delim = ops.convert_to_tensor(
- use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
- self._na_value = ops.convert_to_tensor(
- na_value, dtype=dtypes.string, name="na_value")
- self._select_cols = convert.optional_param_to_tensor(
- "select_cols",
- select_cols,
- argument_default=[],
- argument_dtype=dtypes.int64,
- )
- self._output_shapes = tuple(
- tensor_shape.scalar() for _ in range(len(record_defaults)))
- self._output_types = tuple(d.dtype for d in self._record_defaults)
- self._output_classes = tuple(
- ops.Tensor for _ in range(len(record_defaults)))
-
- def _as_variant_tensor(self):
- # Constructs graph node for the dataset op.
- return contrib_gen_dataset_ops.csv_dataset(
- filenames=self._filenames,
- record_defaults=self._record_defaults,
- buffer_size=self._buffer_size,
- header=self._header,
- output_shapes=self._output_shapes,
- field_delim=self._field_delim,
- use_quote_delim=self._use_quote_delim,
- na_value=self._na_value,
- select_cols=self._select_cols,
- compression_type=self._compression_type,
- )
-
- @property
- def output_types(self):
- return self._output_types
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_classes(self):
- return self._output_classes
+ super(CsvDataset, self).__init__(
+ filenames, record_defaults, compression_type, buffer_size, header,
+ field_delim, use_quote_delim, na_value, select_cols)
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`.")
def make_batched_features_dataset(file_pattern,
batch_size,
features,
@@ -759,57 +268,15 @@ def make_batched_features_dataset(file_pattern,
Raises:
ValueError: If `label_key` is not one of the `features` keys.
"""
- # Create dataset of all matching filenames
- filenames = _get_file_names(file_pattern, False)
- dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
- if shuffle:
- dataset = dataset.shuffle(len(filenames), shuffle_seed)
-
- # Read `Example` records from files as tensor objects.
- if reader_args is None:
- reader_args = []
+ return readers.make_batched_features_dataset(
+ file_pattern, batch_size, features, reader, label_key, reader_args,
+ num_epochs, shuffle, shuffle_buffer_size, shuffle_seed,
+ prefetch_buffer_size, reader_num_threads, parser_num_threads,
+ sloppy_ordering, drop_final_batch)
- # Read files sequentially (if reader_num_threads=1) or in parallel
- dataset = dataset.apply(
- interleave_ops.parallel_interleave(
- lambda filename: reader(filename, *reader_args),
- cycle_length=reader_num_threads,
- sloppy=sloppy_ordering))
- # Extract values if the `Example` tensors are stored as key-value tuples.
- if dataset.output_types == (dtypes.string, dtypes.string):
- dataset = dataset_ops.MapDataset(
- dataset, lambda _, v: v, use_inter_op_parallelism=False)
-
- # Apply dataset repeat and shuffle transformations.
- dataset = _maybe_shuffle_and_repeat(
- dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
-
- # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
- # improve the shape inference, because it makes the batch dimension static.
- # It is safe to do this because in that case we are repeating the input
- # indefinitely, and all batches will be full-sized.
- dataset = dataset.batch(
- batch_size, drop_remainder=drop_final_batch or num_epochs is None)
-
- # Parse `Example` tensors to a dictionary of `Feature` tensors.
- dataset = dataset.apply(
- parsing_ops.parse_example_dataset(
- features, num_parallel_calls=parser_num_threads))
-
- if label_key:
- if label_key not in features:
- raise ValueError(
- "The `label_key` provided (%r) must be one of the `features` keys." %
- label_key)
- dataset = dataset.map(lambda x: (x, x.pop(label_key)))
-
- dataset = dataset.prefetch(prefetch_buffer_size)
- return dataset
-
-
-@deprecation.deprecated(None,
- "Use `tf.contrib.data.make_batched_features_dataset`")
+@deprecation.deprecated(
+ None, "Use `tf.data.experimental.make_batched_features_dataset(...)`")
def read_batch_features(file_pattern,
batch_size,
features,
@@ -879,7 +346,7 @@ def read_batch_features(file_pattern,
Returns:
A dict from keys in features to `Tensor` or `SparseTensor` objects.
"""
- dataset = make_batched_features_dataset(
+ dataset = readers.make_batched_features_dataset(
file_pattern,
batch_size,
features,
@@ -893,96 +360,13 @@ def read_batch_features(file_pattern,
return outputs
-def _get_file_names(file_pattern, shuffle):
- """Parse list of file names from pattern, optionally shuffled.
-
- Args:
- file_pattern: File glob pattern, or list of glob patterns.
- shuffle: Whether to shuffle the order of file names.
-
- Returns:
- List of file names matching `file_pattern`.
-
- Raises:
- ValueError: If `file_pattern` is empty, or pattern matches no files.
- """
- if isinstance(file_pattern, list):
- if not file_pattern:
- raise ValueError("File pattern is empty.")
- file_names = []
- for entry in file_pattern:
- file_names.extend(gfile.Glob(entry))
- else:
- file_names = list(gfile.Glob(file_pattern))
-
- if not file_names:
- raise ValueError("No files match %s." % file_pattern)
-
- # Sort files so it will be deterministic for unit tests.
- if not shuffle:
- file_names = sorted(file_names)
- return file_names
-
-
-class SqlDataset(dataset_ops.DatasetSource):
+class SqlDataset(readers.SqlDataset):
"""A `Dataset` consisting of the results from a SQL query."""
+ @deprecation.deprecated(None, "Use `tf.data.experimental.SqlDataset(...)`.")
def __init__(self, driver_name, data_source_name, query, output_types):
- """Creates a `SqlDataset`.
-
- `SqlDataset` allows a user to read data from the result set of a SQL query.
- For example:
-
- ```python
- dataset = tf.contrib.data.SqlDataset("sqlite", "/foo/bar.sqlite3",
- "SELECT name, age FROM people",
- (tf.string, tf.int32))
- iterator = dataset.make_one_shot_iterator()
- next_element = iterator.get_next()
- # Prints the rows of the result set of the above query.
- while True:
- try:
- print(sess.run(next_element))
- except tf.errors.OutOfRangeError:
- break
- ```
-
- Args:
- driver_name: A 0-D `tf.string` tensor containing the database type.
- Currently, the only supported value is 'sqlite'.
- data_source_name: A 0-D `tf.string` tensor containing a connection string
- to connect to the database.
- query: A 0-D `tf.string` tensor containing the SQL query to execute.
- output_types: A tuple of `tf.DType` objects representing the types of the
- columns returned by `query`.
- """
- super(SqlDataset, self).__init__()
- self._driver_name = ops.convert_to_tensor(
- driver_name, dtype=dtypes.string, name="driver_name")
- self._data_source_name = ops.convert_to_tensor(
- data_source_name, dtype=dtypes.string, name="data_source_name")
- self._query = ops.convert_to_tensor(
- query, dtype=dtypes.string, name="query")
- self._output_types = output_types
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.sql_dataset(self._driver_name,
- self._data_source_name, self._query,
- nest.flatten(self.output_types),
- nest.flatten(self.output_shapes))
-
- @property
- def output_classes(self):
- return nest.map_structure(lambda _: ops.Tensor, self._output_types)
-
- @property
- def output_shapes(self):
- return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
- self._output_types)
-
- @property
- def output_types(self):
- return self._output_types
+ super(SqlDataset, self).__init__(
+ driver_name, data_source_name, query, output_types)
class LMDBDataset(dataset_ops.DatasetSource):
@@ -1013,7 +397,7 @@ class LMDBDataset(dataset_ops.DatasetSource):
filenames, dtype=dtypes.string, name="filenames")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.lmdb_dataset(
+ return gen_experimental_dataset_ops.experimental_lmdb_dataset(
self._filenames,
output_types=nest.flatten(self.output_types),
output_shapes=nest.flatten(self.output_shapes))
diff --git a/tensorflow/contrib/data/python/ops/resampling.py b/tensorflow/contrib/data/python/ops/resampling.py
index 75642f143e..29d77528d9 100644
--- a/tensorflow/contrib/data/python/ops/resampling.py
+++ b/tensorflow/contrib/data/python/ops/resampling.py
@@ -17,22 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import numpy as np
-
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
-from tensorflow.contrib.data.python.ops import scan_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_util
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import logging_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import random_ops
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.rejection_resample(...)`.")
def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
"""A transformation that resamples a dataset to achieve a target distribution.
@@ -52,243 +42,5 @@ def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- """Function from `Dataset` to `Dataset` that applies the transformation."""
- target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
- class_values_ds = dataset.map(class_func)
-
- # Get initial distribution.
- if initial_dist is not None:
- initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
- acceptance_dist, prob_of_original = (
- _calculate_acceptance_probs_with_mixing(initial_dist_t,
- target_dist_t))
- initial_dist_ds = dataset_ops.Dataset.from_tensors(
- initial_dist_t).repeat()
- acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
- acceptance_dist).repeat()
- prob_of_original_ds = dataset_ops.Dataset.from_tensors(
- prob_of_original).repeat()
- else:
- initial_dist_ds = _estimate_initial_dist_ds(
- target_dist_t, class_values_ds)
- acceptance_and_original_prob_ds = initial_dist_ds.map(
- lambda initial: _calculate_acceptance_probs_with_mixing(
- initial, target_dist_t))
- acceptance_dist_ds = acceptance_and_original_prob_ds.map(
- lambda accept_prob, _: accept_prob)
- prob_of_original_ds = acceptance_and_original_prob_ds.map(
- lambda _, prob_original: prob_original)
- filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
- class_values_ds, seed)
- # Prefetch filtered dataset for speed.
- filtered_ds = filtered_ds.prefetch(3)
-
- prob_original_static = _get_prob_original_static(
- initial_dist_t, target_dist_t) if initial_dist is not None else None
- if prob_original_static == 1:
- return dataset_ops.Dataset.zip((class_values_ds, dataset))
- elif prob_original_static == 0:
- return filtered_ds
- else:
- return interleave_ops.sample_from_datasets(
- [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
- weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
- seed=seed)
-
- return _apply_fn
-
-
-def _get_prob_original_static(initial_dist_t, target_dist_t):
- """Returns the static probability of sampling from the original.
-
- `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
- an Op that it isn't defined for. We have some custom logic to avoid this.
-
- Args:
- initial_dist_t: A tensor of the initial distribution.
- target_dist_t: A tensor of the target distribution.
-
- Returns:
- The probability of sampling from the original distribution as a constant,
- if it is a constant, or `None`.
- """
- init_static = tensor_util.constant_value(initial_dist_t)
- target_static = tensor_util.constant_value(target_dist_t)
-
- if init_static is None or target_static is None:
- return None
- else:
- return np.min(target_static / init_static)
-
-
-def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
- seed):
- """Filters a dataset based on per-class acceptance probabilities.
-
- Args:
- dataset: The dataset to be filtered.
- acceptance_dist_ds: A dataset of acceptance probabilities.
- initial_dist_ds: A dataset of the initial probability distribution, given or
- estimated.
- class_values_ds: A dataset of the corresponding classes.
- seed: (Optional.) Python integer seed for the resampler.
-
- Returns:
- A dataset of (class value, data) after filtering.
- """
- def maybe_warn_on_large_rejection(accept_dist, initial_dist):
- proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
- return control_flow_ops.cond(
- math_ops.less(proportion_rejected, .5),
- lambda: accept_dist,
- lambda: logging_ops.Print( # pylint: disable=g-long-lambda
- accept_dist, [proportion_rejected, initial_dist, accept_dist],
- message="Proportion of examples rejected by sampler is high: ",
- summarize=100,
- first_n=10))
-
- acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
- initial_dist_ds))
- .map(maybe_warn_on_large_rejection))
-
- def _gather_and_copy(class_val, acceptance_prob, data):
- return class_val, array_ops.gather(acceptance_prob, class_val), data
-
- current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
- (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
- filtered_ds = (
- current_probabilities_and_class_and_data_ds
- .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
- return filtered_ds.map(lambda class_value, _, data: (class_value, data))
-
-
-def _estimate_initial_dist_ds(
- target_dist_t, class_values_ds, dist_estimation_batch_size=32,
- smoothing_constant=10):
- num_classes = (target_dist_t.shape[0].value or
- array_ops.shape(target_dist_t)[0])
- initial_examples_per_class_seen = array_ops.fill(
- [num_classes], np.int64(smoothing_constant))
-
- def update_estimate_and_tile(num_examples_per_class_seen, c):
- updated_examples_per_class_seen, dist = _estimate_data_distribution(
- c, num_examples_per_class_seen)
- tiled_dist = array_ops.tile(
- array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
- return updated_examples_per_class_seen, tiled_dist
-
- initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
- .apply(scan_ops.scan(initial_examples_per_class_seen,
- update_estimate_and_tile))
- .apply(batching.unbatch()))
-
- return initial_dist_ds
-
-
-def _get_target_to_initial_ratio(initial_probs, target_probs):
- # Add tiny to initial_probs to avoid divide by zero.
- denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
- return target_probs / denom
-
-
-def _estimate_data_distribution(c, num_examples_per_class_seen):
- """Estimate data distribution as labels are seen.
-
- Args:
- c: The class labels. Type `int32`, shape `[batch_size]`.
- num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
- containing counts.
-
- Returns:
- num_examples_per_lass_seen: Updated counts. Type `int64`, shape
- `[num_classes]`.
- dist: The updated distribution. Type `float32`, shape `[num_classes]`.
- """
- num_classes = num_examples_per_class_seen.get_shape()[0].value
- # Update the class-count based on what labels are seen in batch.
- num_examples_per_class_seen = math_ops.add(
- num_examples_per_class_seen, math_ops.reduce_sum(
- array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
- init_prob_estimate = math_ops.truediv(
- num_examples_per_class_seen,
- math_ops.reduce_sum(num_examples_per_class_seen))
- dist = math_ops.cast(init_prob_estimate, dtypes.float32)
- return num_examples_per_class_seen, dist
-
-
-def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
- """Calculates the acceptance probabilities and mixing ratio.
-
- In this case, we assume that we can *either* sample from the original data
- distribution with probability `m`, or sample from a reshaped distribution
- that comes from rejection sampling on the original distribution. This
- rejection sampling is done on a per-class basis, with `a_i` representing the
- probability of accepting data from class `i`.
-
- This method is based on solving the following analysis for the reshaped
- distribution:
-
- Let F be the probability of a rejection (on any example).
- Let p_i be the proportion of examples in the data in class i (init_probs)
- Let a_i is the rate the rejection sampler should *accept* class i
- Let t_i is the target proportion in the minibatches for class i (target_probs)
-
- ```
- F = sum_i(p_i * (1-a_i))
- = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
- ```
-
- An example with class `i` will be accepted if `k` rejections occur, then an
- example with class `i` is seen by the rejector, and it is accepted. This can
- be written as follows:
-
- ```
- t_i = sum_k=0^inf(F^k * p_i * a_i)
- = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1
- = p_i * a_i / sum_j(p_j * a_j) using F from above
- ```
-
- Note that the following constraints hold:
- ```
- 0 <= p_i <= 1, sum_i(p_i) = 1
- 0 <= a_i <= 1
- 0 <= t_i <= 1, sum_i(t_i) = 1
- ```
-
- A solution for a_i in terms of the other variables is the following:
- ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
-
- If we try to minimize the amount of data rejected, we get the following:
-
- M_max = max_i [ t_i / p_i ]
- M_min = min_i [ t_i / p_i ]
-
- The desired probability of accepting data if it comes from class `i`:
-
- a_i = (t_i/p_i - m) / (M_max - m)
-
- The desired probability of pulling a data element from the original dataset,
- rather than the filtered one:
-
- m = M_min
-
- Args:
- initial_probs: A Tensor of the initial probability distribution, given or
- estimated.
- target_probs: A Tensor of the corresponding classes.
-
- Returns:
- (A 1D Tensor with the per-class acceptance probabilities, the desired
- probability of pull from the original distribution.)
- """
- ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
- max_ratio = math_ops.reduce_max(ratio_l)
- min_ratio = math_ops.reduce_min(ratio_l)
-
- # Target prob to sample from original distribution.
- m = min_ratio
-
- # TODO(joelshor): Simplify fraction, if possible.
- a_i = (ratio_l - m) / (max_ratio - m)
- return a_i, m
+ return resampling.rejection_resample(class_func, target_dist, initial_dist,
+ seed)
diff --git a/tensorflow/contrib/data/python/ops/scan_ops.py b/tensorflow/contrib/data/python/ops/scan_ops.py
index c52582cd35..0ca9fddb23 100644
--- a/tensorflow/contrib/data/python/ops/scan_ops.py
+++ b/tensorflow/contrib/data/python/ops/scan_ops.py
@@ -17,137 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import collections
-
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ScanDataset(dataset_ops.UnaryDataset):
- """A dataset that scans a function across its input."""
-
- def __init__(self, input_dataset, initial_state, scan_func):
- """See `scan()` for details."""
- super(_ScanDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- with ops.name_scope("initial_state"):
- # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
- # values to tensors.
- self._initial_state = nest.pack_sequence_as(initial_state, [
- sparse_tensor.SparseTensor.from_value(t)
- if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
- t, name="component_%d" % i)
- for i, t in enumerate(nest.flatten(initial_state))
- ])
-
- # Compute initial values for the state classes, shapes and types based on
- # the initial state. The shapes may be refined by running `tf_scan_func` one
- # or more times below.
- self._state_classes = sparse.get_classes(self._initial_state)
- self._state_shapes = nest.pack_sequence_as(
- self._initial_state,
- [t.get_shape() for t in nest.flatten(self._initial_state)])
- self._state_types = nest.pack_sequence_as(
- self._initial_state,
- [t.dtype for t in nest.flatten(self._initial_state)])
-
- # Will be populated by calling `tf_scan_func`.
- self._output_classes = None
- self._output_shapes = None
- self._output_types = None
-
- # Iteratively rerun the scan function until reaching a fixed point on
- # `self._state_shapes`.
- need_to_rerun = True
- while need_to_rerun:
-
- wrapped_func = dataset_ops.StructuredFunctionWrapper(
- scan_func, "tf.contrib.data.scan()",
- input_classes=(self._state_classes, input_dataset.output_classes),
- input_shapes=(self._state_shapes, input_dataset.output_shapes),
- input_types=(self._state_types, input_dataset.output_types),
- add_to_graph=False)
- if not (
- isinstance(wrapped_func.output_types, collections.Sequence) and
- len(wrapped_func.output_types) == 2):
- raise TypeError("The scan function must return a pair comprising the "
- "new state and the output value.")
-
- new_state_classes, self._output_classes = wrapped_func.output_classes
-
- # Extract and validate class information from the returned values.
- for new_state_class, state_class in zip(
- nest.flatten(new_state_classes),
- nest.flatten(self._state_classes)):
- if not issubclass(new_state_class, state_class):
- raise TypeError(
- "The element classes for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_classes, new_state_classes))
-
- # Extract and validate type information from the returned values.
- new_state_types, self._output_types = wrapped_func.output_types
- for new_state_type, state_type in zip(
- nest.flatten(new_state_types), nest.flatten(self._state_types)):
- if new_state_type != state_type:
- raise TypeError(
- "The element types for the new state must match the initial "
- "state. Expected %s; got %s." %
- (self._state_types, new_state_types))
-
- # Extract shape information from the returned values.
- new_state_shapes, self._output_shapes = wrapped_func.output_shapes
-
- flat_state_shapes = nest.flatten(self._state_shapes)
- flat_new_state_shapes = nest.flatten(new_state_shapes)
- weakened_state_shapes = [
- original.most_specific_compatible_shape(new)
- for original, new in zip(flat_state_shapes, flat_new_state_shapes)
- ]
-
- need_to_rerun = False
- for original_shape, weakened_shape in zip(flat_state_shapes,
- weakened_state_shapes):
- if original_shape.ndims is not None and (
- weakened_shape.ndims is None or
- original_shape.as_list() != weakened_shape.as_list()):
- need_to_rerun = True
- break
-
- if need_to_rerun:
- self._state_shapes = nest.pack_sequence_as(self._state_shapes,
- weakened_state_shapes)
-
- self._scan_func = wrapped_func.function
- self._scan_func.add_to_graph(ops.get_default_graph())
-
- def _as_variant_tensor(self):
- input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
- return gen_dataset_ops.scan_dataset(
- input_t,
- nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
- self._scan_func.captured_inputs,
- f=self._scan_func,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._output_classes
-
- @property
- def output_shapes(self):
- return self._output_shapes
-
- @property
- def output_types(self):
- return self._output_types
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.scan(...)`.")
def scan(initial_state, scan_func):
"""A transformation that scans a function across an input dataset.
@@ -168,7 +42,4 @@ def scan(initial_state, scan_func):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
- def _apply_fn(dataset):
- return _ScanDataset(dataset, initial_state, scan_func)
-
- return _apply_fn
+ return scan_ops.scan(initial_state, scan_func)
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
index 985d1d87d0..329b34fdfe 100644
--- a/tensorflow/contrib/data/python/ops/shuffle_ops.py
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -17,54 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import random_seed
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
-
-
-class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that fuses `shuffle` and `repeat`."""
-
- def __init__(self, input_dataset, buffer_size, count=None, seed=None):
- super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._buffer_size = ops.convert_to_tensor(
- buffer_size, dtype=dtypes.int64, name="buffer_size")
- if count is None:
- self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
- else:
- self._count = ops.convert_to_tensor(
- count, dtype=dtypes.int64, name="count")
- self._seed, self._seed2 = random_seed.get_seed(seed)
-
- def _as_variant_tensor(self):
- # pylint: disable=protected-access
- input_resource = self._input_dataset._as_variant_tensor()
- return gen_dataset_ops.shuffle_and_repeat_dataset(
- input_resource,
- buffer_size=self._buffer_size,
- count=self._count,
- seed=self._seed,
- seed2=self._seed2,
- **dataset_ops.flat_structure(self))
- # pylint: enable=protected-access
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None,
+ "Use `tf.data.experimental.shuffle_and_repeat(...)`.")
def shuffle_and_repeat(buffer_size, count=None, seed=None):
"""Shuffles and repeats a Dataset returning a new permutation for each epoch.
@@ -93,8 +51,4 @@ def shuffle_and_repeat(buffer_size, count=None, seed=None):
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset): # pylint: disable=missing-docstring
- return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
-
- return _apply_fn
+ return shuffle_ops.shuffle_and_repeat(buffer_size, count, seed)
diff --git a/tensorflow/contrib/data/python/ops/threadpool.py b/tensorflow/contrib/data/python/ops/threadpool.py
index 9d165ad52a..20cceb4647 100644
--- a/tensorflow/contrib/data/python/ops/threadpool.py
+++ b/tensorflow/contrib/data/python/ops/threadpool.py
@@ -17,89 +17,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import threading
-
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.eager import context
-from tensorflow.python.ops import resource_variable_ops
-
-_uid_counter = 0
-_uid_lock = threading.Lock()
-
-
-def _generate_shared_name(prefix):
- with _uid_lock:
- global _uid_counter
- uid = _uid_counter
- _uid_counter += 1
- return "{}{}".format(prefix, uid)
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-class PrivateThreadPool(object):
- """A stateful resource that represents a private thread pool."""
-
- def __init__(self, num_threads, display_name=None,
- max_intra_op_parallelism=1):
- """Creates a `PrivateThreadPool` with the given number of threads."""
- if context.executing_eagerly():
- shared_name = _generate_shared_name("privatethreadpool")
- self._resource = gen_dataset_ops.thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name,
- shared_name=shared_name)
- self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
- handle=self._resource, handle_device=context.context().device_name)
- else:
- self._resource = gen_dataset_ops.thread_pool_handle(
- num_threads=num_threads,
- max_intra_op_parallelism=max_intra_op_parallelism,
- display_name=display_name)
-
-
-class _ThreadPoolDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and sets a custom threadpool."""
-
- def __init__(self, input_dataset, thread_pool):
- super(_ThreadPoolDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- self._thread_pool = thread_pool
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.thread_pool_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._thread_pool._resource, # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
-
-# TODO(b/73383364): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
-def override_threadpool(dataset, thread_pool):
- """Returns a new dataset that uses the given thread pool for its operations.
-
- Args:
- dataset: A `tf.data.Dataset` object.
- thread_pool: A `PrivateThreadPool` object.
-
- Returns:
- A dataset containing the same values as `dataset`, but which uses
- `thread_pool` to compute any of its parallel operations (such as
- `tf.data.Dataset.map`).
- """
- return _ThreadPoolDataset(dataset, thread_pool)
+# pylint: disable=unused-import
+from tensorflow.python.data.experimental.ops.threadpool import override_threadpool
+from tensorflow.python.data.experimental.ops.threadpool import PrivateThreadPool
diff --git a/tensorflow/contrib/data/python/ops/unique.py b/tensorflow/contrib/data/python/ops/unique.py
index bad67a580d..909d06c677 100644
--- a/tensorflow/contrib/data/python/ops/unique.py
+++ b/tensorflow/contrib/data/python/ops/unique.py
@@ -17,12 +17,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import dtypes
+from tensorflow.python.data.experimental.ops import unique as experimental_unique
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(None, "Use `tf.data.experimental.unique()`.")
def unique():
"""Creates a `Dataset` from another `Dataset`, discarding duplicates.
@@ -40,39 +39,4 @@ def unique():
A `Dataset` transformation function, which can be passed to
`tf.data.Dataset.apply`.
"""
-
- def _apply_fn(dataset):
- return _UniqueDataset(dataset)
-
- return _apply_fn
-
-
-class _UniqueDataset(dataset_ops.UnaryDataset):
- """A `Dataset` contains the unique elements from its input."""
-
- def __init__(self, input_dataset):
- """See `unique()` for details."""
- super(_UniqueDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
- dtypes.string):
- raise TypeError(
- "`tf.contrib.data.unique()` only supports inputs with a single "
- "`tf.int32`, `tf.int64`, or `tf.string` component.")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.unique_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
+ return experimental_unique.unique()
diff --git a/tensorflow/contrib/data/python/ops/writers.py b/tensorflow/contrib/data/python/ops/writers.py
index c455fdcba6..42fb69bf07 100644
--- a/tensorflow/contrib/data/python/ops/writers.py
+++ b/tensorflow/contrib/data/python/ops/writers.py
@@ -17,42 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.util import convert
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import tensor_shape
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.util import deprecation
-class TFRecordWriter(object):
+class TFRecordWriter(writers.TFRecordWriter):
"""Writes data to a TFRecord file."""
+ @deprecation.deprecated(
+ None, "Use `tf.data.experimental.TFRecordWriter(...)`.")
def __init__(self, filename, compression_type=None):
- self._filename = ops.convert_to_tensor(
- filename, dtypes.string, name="filename")
- self._compression_type = convert.optional_param_to_tensor(
- "compression_type",
- compression_type,
- argument_default="",
- argument_dtype=dtypes.string)
-
- def write(self, dataset):
- """Returns a `tf.Operation` to write a dataset to a file.
-
- Args:
- dataset: a `tf.data.Dataset` whose elements are to be written to a file
-
- Returns:
- A `tf.Operation` that, when run, writes contents of `dataset` to a file.
- """
- if not isinstance(dataset, dataset_ops.Dataset):
- raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
- if (dataset.output_types != dtypes.string or
- dataset.output_shapes != tensor_shape.scalar()):
- raise TypeError(
- "`dataset` must produce scalar `DT_STRING` tensors whereas it "
- "produces shape {0} and types {1}".format(dataset.output_shapes,
- dataset.output_types))
- return gen_dataset_ops.dataset_to_tf_record(
- dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
+ super(TFRecordWriter, self).__init__(filename, compression_type)
diff --git a/tensorflow/contrib/decision_trees/proto/BUILD b/tensorflow/contrib/decision_trees/proto/BUILD
index 3b50a48336..06940a90d5 100644
--- a/tensorflow/contrib/decision_trees/proto/BUILD
+++ b/tensorflow/contrib/decision_trees/proto/BUILD
@@ -17,7 +17,6 @@ tf_proto_library(
name = "generic_tree_model",
srcs = ["generic_tree_model.proto"],
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/distribute/README.md b/tensorflow/contrib/distribute/README.md
index 91a27f97b7..2e025765e4 100644
--- a/tensorflow/contrib/distribute/README.md
+++ b/tensorflow/contrib/distribute/README.md
@@ -231,7 +231,8 @@ The same `input_fn` will be used for all workers if you use
important to shuffle your dataset in your `input_fn`.
`MirroredStrategy` will insert a `tf.dataset.Dataset.shard` call in you
-`input_fn`. As a result, each worker gets a fraction of your input data.
+`input_fn` if `auto_shard_dataset` is set to `True`. As a result, each worker
+gets a fraction of your input data.
### Performance Tips
diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index e329b964c4..defa82f98a 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,14 +22,15 @@ py_library(
visibility = ["//tensorflow:internal"],
deps = [
":input_ops",
+ ":prefetching_ops_v2",
"//tensorflow/python:array_ops",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:device_util",
"//tensorflow/python:distribute",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:resource_variable_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
- "//tensorflow/python/data/ops:multi_device_iterator_ops",
"//tensorflow/python/eager:context",
"//tensorflow/python/training/checkpointable:base",
"@six_archive//:six",
@@ -648,6 +649,32 @@ cuda_py_test(
)
py_library(
+ name = "prefetching_ops_v2",
+ srcs = ["prefetching_ops_v2.py"],
+ deps = [
+ "//tensorflow/contrib/data/python/ops:prefetching_ops",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_v2_test",
+ srcs = ["prefetching_ops_v2_test.py"],
+ additional_deps = [
+ ":prefetching_ops_v2",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+)
+
+py_library(
name = "input_ops",
srcs = ["input_ops.py"],
visibility = ["//tensorflow:internal"],
@@ -701,6 +728,7 @@ cuda_py_test(
additional_deps = [
":keras_test_lib",
],
+ shard_count = 16,
tags = [
"multi_and_single_gpu",
"no_pip",
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
index c900b41e14..9809204f8f 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy.py
@@ -216,7 +216,7 @@ class CollectiveAllReduceStrategy(mirrored_strategy.MirroredStrategy):
"""Configures the object.
Args:
- session_config: a @{tf.ConfigProto}
+ session_config: a `tf.ConfigProto`
cluster_spec: a dict, ClusterDef or ClusterSpec object specifying the
cluster configurations.
task_type: the current task type, such as "worker".
diff --git a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
index 33ffbf6abe..6796a23d46 100644
--- a/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/collective_all_reduce_strategy_test.py
@@ -128,7 +128,8 @@ class CollectiveAllReduceStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/keras_test.py b/tensorflow/contrib/distribute/python/keras_test.py
index 3aab2c521f..993cb2bac3 100644
--- a/tensorflow/contrib/distribute/python/keras_test.py
+++ b/tensorflow/contrib/distribute/python/keras_test.py
@@ -189,6 +189,14 @@ def get_dataset(distribution):
return dataset
+def get_predict_dataset(distribution):
+ inputs = np.zeros((10, 3), dtype=np.float32)
+ dataset = dataset_ops.Dataset.from_tensor_slices(inputs)
+ dataset = dataset.repeat(100)
+ dataset = batch_wrapper(dataset, 10, distribution)
+ return dataset
+
+
strategies = [combinations.default_strategy,
combinations.one_device_strategy,
combinations.mirrored_strategy_with_gpu_and_cpu,
@@ -387,16 +395,26 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
distributed_training_utils.validate_distributed_dataset_inputs(
strategy, x, y)
- def test_calling_model_with_numpy_arrays(self):
+ # TODO(anjalisridhar): Move this test along with other numpy related tests to
+ # its own class.
+ @combinations.generate(strategy_combinations())
+ def test_creating_var_with_numpy_arrays(self, distribution):
+ with self.cached_session():
+ x = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ var_x = distributed_training_utils.get_var_for_numpy(distribution, x)
+ val = self.evaluate(var_x.value())
+ # Verify that the numpy value is copied to the variable.
+ self.assertAllEqual(x, val)
+
+ @combinations.generate(strategy_combinations())
+ def test_calling_model_with_numpy_arrays(self, distribution):
with self.cached_session():
model = get_model()
optimizer = gradient_descent.GradientDescentOptimizer(0.001)
loss = 'mse'
- metrics = ['mae', keras.metrics.CategoricalAccuracy()]
- strategy = mirrored_strategy.MirroredStrategy(['/device:GPU:1',
- '/device:GPU:0'])
- model.compile(optimizer, loss, metrics=metrics, distribute=strategy)
+ metrics = ['mae']
+ model.compile(optimizer, loss, metrics=metrics, distribute=distribution)
inputs = np.zeros((64, 3), dtype=np.float32)
targets = np.zeros((64, 4), dtype=np.float32)
@@ -420,6 +438,48 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.predict(inputs, batch_size=8)
@combinations.generate(strategy_combinations())
+ def test_calling_model_with_nested_numpy_arrays(self, distribution):
+ with self.cached_session():
+ a = keras.layers.Input(shape=(3,), name='input_a')
+ b = keras.layers.Input(shape=(3,), name='input_b')
+
+ dense = keras.layers.Dense(4, name='dense')
+ c = dense(a)
+ d = dense(b)
+ e = keras.layers.Dropout(0.5, name='dropout')(c)
+
+ model = keras.models.Model([a, b], [d, e])
+
+ optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.001)
+ loss = 'mse'
+ model.compile(optimizer, loss, distribute=distribution)
+
+ input_a_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ input_b_np = np.asarray(np.random.random((64, 3)), dtype=np.float32)
+ inputs = [input_a_np, input_b_np]
+
+ output_d_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ output_e_np = np.asarray(np.random.random((64, 4)), dtype=np.float32)
+ targets = [output_d_np, output_e_np]
+
+ # Call fit with validation data
+ model.fit(inputs, targets, epochs=1, batch_size=8, verbose=0)
+
+ # TODO(anjalisridhar): We need tests for when the batch size and steps are
+ # smaller and results in a 0 batch_size and steps value.
+ model.evaluate(inputs, targets)
+ # with steps
+ model.evaluate(inputs, targets, steps=2)
+ # with batch_size
+ model.evaluate(inputs, targets, batch_size=8)
+
+ model.predict(inputs)
+ # with steps
+ model.predict(inputs, steps=2)
+ # with batch_size
+ model.predict(inputs, batch_size=8)
+
+ @combinations.generate(strategy_combinations())
def test_calling_model_on_same_dataset(self, distribution):
with self.cached_session():
model = get_model()
@@ -436,7 +496,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
validation_data=dataset, validation_steps=2)
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
validation_data=dataset, validation_steps=2)
- model.predict(dataset, steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
# TODO(priyag): Enable this test for TPU. Currently tuples/dict don't work
# as clone_model's input_tensors argument only seems to accept list and not
@@ -496,10 +556,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
- # Test with validation data
- model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=0,
- validation_data=dataset, validation_steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
@combinations.generate(strategy_and_optimizer_combinations())
def test_fit_eval_and_predict_with_optimizer(self, distribution, optimizer):
@@ -513,7 +570,7 @@ class TestWithDistributionStrategy(test.TestCase, parameterized.TestCase):
model.fit(dataset, epochs=1, steps_per_epoch=2, verbose=1)
model.evaluate(dataset, steps=2, verbose=1)
- model.predict(dataset, steps=2)
+ model.predict(get_predict_dataset(distribution), steps=2)
def test_unsupported_features(self):
with self.cached_session():
@@ -726,8 +783,12 @@ class NormalizationLayerWithDistributionStrategyTest(
dataset = dataset.repeat(100)
dataset = batch_wrapper(dataset, 32, distribution)
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x)
+ predict_dataset = predict_dataset.repeat(100)
+ predict_dataset = batch_wrapper(predict_dataset, 32, distribution)
+
model.fit(dataset, epochs=4, verbose=0, steps_per_epoch=10)
- out = model.predict(dataset, steps=2)
+ out = model.predict(predict_dataset, steps=2)
out -= keras.backend.eval(norm.beta)
out /= keras.backend.eval(norm.gamma)
np.testing.assert_allclose(out.mean(), 0.0, atol=1e-1)
@@ -811,8 +872,7 @@ class CorrectnessWithDistributionStrategyTest(test.TestCase,
predict_batch_size = 4
if with_distribution:
predict_batch_size //= with_distribution.num_towers
- predict_dataset = dataset_ops.Dataset.from_tensor_slices((x_predict,
- x_predict))
+ predict_dataset = dataset_ops.Dataset.from_tensor_slices(x_predict)
predict_dataset = batch_wrapper(predict_dataset,
predict_batch_size, distribution)
predict_result = model.predict(predict_dataset, steps=1)
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index f7773aff4f..8163494c8e 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -86,11 +86,10 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
with ops.Graph().as_default(), distribution.scope():
iterator = distribution.distribute_dataset(
- dataset_fn).make_initializable_iterator()
+ dataset_fn).make_one_shot_iterator()
value, update = distribution.call_for_each_tower(
metric_fn, iterator.get_next())
update = distribution.group(update)
- self.evaluate(iterator.initializer)
self.evaluate(variables.local_variables_initializer())
# TODO(josh11b): Once we switch to using a global batch size for input,
# replace "distribution.num_towers" with "1".
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index d082d5c419..ba147e7824 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,14 +41,6 @@ from tensorflow.python.ops.losses import losses_impl
class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
- def _get_iterator(self, ds):
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate(iterator.initializer)
- return iterator
-
@combinations.generate(
combinations.times(
combinations.distributions_and_v1_optimizers(),
@@ -70,7 +62,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -106,7 +99,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.group(
@@ -165,7 +159,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, *inputs, run_concurrently=layer.built))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -249,7 +244,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
return control_flow_ops.group(fetches)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -342,7 +338,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
distribution.call_for_each_tower(
model_fn, x, y, run_concurrently=False))
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return distribution.run_steps_on_dataset(
@@ -435,7 +432,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
output=loss)
return distribution.group(train_op)
- iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy.py b/tensorflow/contrib/distribute/python/mirrored_strategy.py
index 504f45a695..6bd380a22d 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy.py
@@ -347,6 +347,8 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
set, the `configure` method will try to find the best one.
prefetch_on_device: optional boolean to specify whether to prefetch input
data to devices.
+ auto_shard_dataset: whether to auto-shard the dataset when there are
+ multiple workers.
"""
def __init__(self,
@@ -354,11 +356,13 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
num_gpus=None,
num_gpus_per_worker=None,
cross_tower_ops=None,
- prefetch_on_device=None):
+ prefetch_on_device=None,
+ auto_shard_dataset=False):
super(MirroredStrategy, self).__init__()
self._cross_tower_ops = cross_tower_ops
self._prefetch_on_device = prefetch_on_device
+ self._auto_shard_dataset = auto_shard_dataset
# Rememeber num GPUs which might be needed by `configure` method.
if num_gpus is not None and num_gpus_per_worker is not None:
raise ValueError(
@@ -477,11 +481,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
if self._cluster_spec:
return values.MultiWorkerDataset(
partial(self._call_dataset_fn, dataset_fn), self._worker_device_map,
- self._prefetch_on_device)
+ self._prefetch_on_device, self._auto_shard_dataset)
else:
return values.PerDeviceDataset(
- self._call_dataset_fn(dataset_fn),
- self._devices,
+ self._call_dataset_fn(dataset_fn), self._devices,
self._prefetch_on_device)
# TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed.
@@ -624,9 +627,11 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
return self._get_cross_tower_ops().batch_reduce(aggregation,
value_destination_pairs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
# TODO(josh11b): In eager mode, use one thread per device.
assert isinstance(var, values.DistributedVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
updates = {}
for d, v in var._index.items(): # pylint: disable=protected-access
name = "update_%d" % self._device_index.get(d)
@@ -635,10 +640,12 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
updates[d] = fn(v,
*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
assert isinstance(colocate_with, list)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): In eager mode, use one thread per device.
updates = {}
for d in colocate_with:
@@ -646,7 +653,7 @@ class MirroredStrategy(distribute_lib.DistributionStrategy):
with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
updates[d] = fn(*values.select_device_mirrored(d, args),
**values.select_device_mirrored(d, kwargs))
- return values.regroup(updates, values.Mirrored)
+ return values.update_regroup(self, updates, should_group)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index 04c712ce1d..eeac528329 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -300,15 +300,9 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
dist = mirrored_strategy.MirroredStrategy(
["/device:GPU:0", "/device:CPU:0"])
- ds = dist.distribute_dataset(
- lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
- self.evaluate([iterator.initializer])
-
- features = iterator.get_next()
+ features = dist.distribute_dataset(
+ lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
+ ).make_one_shot_iterator().get_next()
with dist.scope():
result = dist.call_for_each_tower(
@@ -832,7 +826,7 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
with dist.scope():
ret_v_sum = dist.call_for_each_tower(model_fn, run_concurrently=False)
- update_ops = dist.unwrap(dist.update(ret_v_sum, update, 5.0))
+ update_ops = dist.update(ret_v_sum, update, 5.0, grouped=False)
# Initialize variables.
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 17b7ab74f6..7644acedc9 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,7 +51,6 @@ class Monitor(object):
else:
if session is None:
raise ValueError("Should provide a `session` in Graph mode.")
- session.run(step_callable._iterator.initializer) # pylint: disable=protected-access
self._run_step = session.make_callable(step_callable())
session.run(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/one_device_strategy.py b/tensorflow/contrib/distribute/python/one_device_strategy.py
index 23b220f64b..f525919048 100644
--- a/tensorflow/contrib/distribute/python/one_device_strategy.py
+++ b/tensorflow/contrib/distribute/python/one_device_strategy.py
@@ -141,14 +141,21 @@ class OneDeviceStrategy(distribute_lib.DistributionStrategy):
else:
assert False
- def _update(self, var, fn, *args, **kwargs):
- with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
del colocate_with
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(self._device), distribute_lib.UpdateContext(self._device):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
"""Read the aggregate value of a tower-local variable."""
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 3064433129..6e9ba37a19 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,11 +42,8 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
model_fn, dataset_fn, layer = minimize_loss_example(
optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
- ds = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- iterator = ds.make_one_shot_iterator()
- else:
- iterator = ds.make_initializable_iterator()
+ iterator = distribution.distribute_dataset(
+ dataset_fn).make_one_shot_iterator()
def run_step():
return control_flow_ops.group(distribution.unwrap(
@@ -55,7 +52,6 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
if not context.executing_eagerly():
with self.cached_session() as sess:
- sess.run(iterator.initializer)
run_step = sess.make_callable(run_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy.py b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
index 1125d027f6..6ddd91507b 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy.py
@@ -343,21 +343,33 @@ class ParameterServerStrategy(distribute_lib.DistributionStrategy):
return nest.map_structure(_select_fn, structured)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
if isinstance(var, values.AggregatingVariable):
var = var.get()
if not isinstance(var, resource_variable_ops.ResourceVariable):
raise ValueError(
"You can not update `var` %r. It must be a Variable." % var)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.colocate_with(var), distribute_lib.UpdateContext(var.device):
- return fn(var, *self._select_single_value(args),
- **self._select_single_value(kwargs))
+ result = fn(var, *self._select_single_value(args),
+ **self._select_single_value(kwargs))
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
# TODO(yuefengz): does it need to call _select_single_value?
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
with ops.device(
colocate_with.device), distribute_lib.UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def _unwrap(self, val):
if isinstance(val, values.DistributedValues):
diff --git a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
index 12789e0bc9..353d11a583 100644
--- a/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
+++ b/tensorflow/contrib/distribute/python/parameter_server_strategy_test.py
@@ -395,7 +395,8 @@ class ParameterServerStrategyTestBase(
# TODO(yuefengz): support non-Mirrored variable as destinations.
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(
+ d.update(v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
new file mode 100644
index 0000000000..d48aa9c89b
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
@@ -0,0 +1,232 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Extension of prefetching_ops to support more than one device."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest as data_nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.util import nest
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset.
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ devices: Devices on which to prefetch.
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be shared
+ under the given name across multiple sessions that share the same devices
+ (e.g. when using a remote server). Only used if one_shot is False.
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ devices,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+ self._devices = devices
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ target_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+ self._buffering_resources = []
+ for device in nest.flatten(self._devices):
+ with ops.device(device):
+ buffer_resource_handle = prefetching_ops.function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)),
+ target_device=target_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name)
+ self._buffering_resources.append(buffer_resource_handle)
+
+ if not self._one_shot:
+ reset_ops = []
+ for buffer_resource in self._buffering_resources:
+ reset_ops.append(
+ ged_ops.experimental_function_buffering_resource_reset(
+ buffer_resource))
+ with ops.control_dependencies(reset_ops):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_result = []
+ # TODO(priyag): This will fail if the input size (typically number of
+ # batches) is not divisible by number of devices.
+ # How do we handle that more gracefully / let the user know?
+ for buffer_resource in self._buffering_resources:
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ buffer_resource,
+ output_types=data_nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ data_nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+ flat_result.append(ret)
+
+ return nest.pack_sequence_as(self._devices, flat_result)
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to other device(s)."""
+
+ def __init__(self, input_dataset, devices, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._devices = devices
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ def make_one_shot_iterator(self):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=True,
+ devices=self._devices,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ if context.executing_eagerly():
+ raise RuntimeError(
+ "make_initializable_iterator is not supported when eager "
+ "execution is enabled.")
+
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ devices=self._devices,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called.
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_devices()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ # TODO(priyag): Fix the output types, shapes and classes to match the result
+ # of get_next (which has the additional nesting layer of devices now).
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+def prefetch_to_devices(devices, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `devices`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ devices: A nested structure of devices on which to prefetch the data. It can
+ be a single device name, or a tuple or list of device names.
+ buffer_size: (Optional.) The number of elements to buffer on each device.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
+
+ return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
new file mode 100644
index 0000000000..16799104e8
--- /dev/null
+++ b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
@@ -0,0 +1,90 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for prefetching_ops_v2."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import test_util
+from tensorflow.python.platform import test
+
+
+class PrefetchingOpsV2Test(test.TestCase):
+
+ def testPrefetchToOneDevice(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ with self.cached_session() as sess:
+ for i in range(10):
+ self.assertEqual(i, sess.run(next_element))
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(next_element)
+
+ def testPrefetchToTwoDevicesInAList(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+
+ output = []
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ for _ in range(4):
+ result = sess.run(next_element)
+ self.assertEqual(2, len(result))
+ output.extend(result)
+ self.assertEquals(set(range(8)), set(output))
+
+ def testPrefetchToTwoDevicesWithReinit(self):
+ if not test_util.is_gpu_available():
+ self.skipTest("No GPU available")
+
+ host_dataset = dataset_ops.Dataset.range(10)
+ device_dataset = host_dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
+
+ iterator = device_dataset.make_initializable_iterator()
+ next_element = iterator.get_next()
+
+ # TODO(rohanj): Modify test to go till the end of the dataset when we
+ # switch to MultiDeviceIterator.
+ with self.cached_session() as sess:
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+ sess.run(iterator.initializer)
+ for _ in range(4):
+ sess.run(next_element)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 23bf36184f..1b5a4f64e5 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.eager import backprop
-from tensorflow.python.eager import context
from tensorflow.python.training import optimizer as optimizer_lib
@@ -51,11 +50,7 @@ class StandardInputStep(Step):
def __init__(self, dataset_fn, distribution):
super(StandardInputStep, self).__init__(distribution)
self._distributed_input = distribution.distribute_dataset(dataset_fn)
- if context.executing_eagerly():
- self._iterator = self._distributed_input.make_one_shot_iterator()
- else:
- # TODO(priyag): Expose initializer via some initializer property.
- self._iterator = self._distributed_input.make_initializable_iterator()
+ self._iterator = self._distributed_input.make_one_shot_iterator()
class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index 1ff9b9ceec..f1ada49fa3 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,7 +50,6 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
run_step = single_loss_step
else:
with self.cached_session() as sess:
- sess.run(single_loss_step._iterator.initializer)
run_step = sess.make_callable(single_loss_step())
self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/strategy_test_lib.py b/tensorflow/contrib/distribute/python/strategy_test_lib.py
index 5d498fb629..fd280f5754 100644
--- a/tensorflow/contrib/distribute/python/strategy_test_lib.py
+++ b/tensorflow/contrib/distribute/python/strategy_test_lib.py
@@ -115,7 +115,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
@@ -169,7 +170,8 @@ class DistributionTestBase(test.TestCase):
with ops.control_dependencies([fetched]):
g = d.reduce(
variable_scope.VariableAggregation.SUM, g, destinations=v)
- with ops.control_dependencies(d.unwrap(d.update(v, update, g))):
+ with ops.control_dependencies(d.update(
+ v, update, g, grouped=False)):
after_list.append(d.read_var(v))
return before_list, after_list
diff --git a/tensorflow/contrib/distribute/python/tpu_strategy.py b/tensorflow/contrib/distribute/python/tpu_strategy.py
index a6762e5e87..c3c7df3cd8 100644
--- a/tensorflow/contrib/distribute/python/tpu_strategy.py
+++ b/tensorflow/contrib/distribute/python/tpu_strategy.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.tpu.python.tpu import tpu
from tensorflow.contrib.tpu.python.tpu import tpu_system_metadata as tpu_system_metadata_lib
from tensorflow.contrib.tpu.python.tpu import training_loop
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
@@ -37,9 +38,13 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.training import device_util
+from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
+_TPU_INITIALIZE_SYSTEM_COLLECTION = "TPU_STRATEGY_INITIALIZE"
+
+
def get_tpu_system_metadata(tpu_cluster_resolver):
"""Retrieves TPU system metadata given a TPUClusterResolver."""
master = tpu_cluster_resolver.master()
@@ -56,6 +61,58 @@ def get_tpu_system_metadata(tpu_cluster_resolver):
return tpu_system_metadata
+# TODO(jhseu): Deduplicate with MirroredStrategy?
+def _create_tpu_mirrored_variable(devices, real_mirrored_creator, *args,
+ **kwargs): # pylint: disable=g-missing-docstring
+ # Figure out what collections this variable should be added to.
+ # We'll add the TPUMirroredVariable to those collections instead.
+ collections = kwargs.pop("collections", None)
+ if collections is None:
+ collections = [ops.GraphKeys.GLOBAL_VARIABLES]
+ kwargs["collections"] = []
+
+ # TODO(jhseu): Should we have different behavior for different
+ # synchronization settings?
+
+ # Get aggregation value
+ # TODO(jhseu): Support aggregation in a tower context.
+ aggregation = kwargs.pop("aggregation", vs.VariableAggregation.NONE)
+ if aggregation not in [
+ vs.VariableAggregation.NONE,
+ vs.VariableAggregation.SUM,
+ vs.VariableAggregation.MEAN,
+ vs.VariableAggregation.ONLY_FIRST_TOWER,
+ ]:
+ raise ValueError("Invalid variable aggregation mode: {} for variable: {}"
+ .format(aggregation, kwargs["name"]))
+
+ # Ignore user-specified caching device, not needed for mirrored variables.
+ kwargs.pop("caching_device", None)
+
+ # TODO(josh11b,apassos): It would be better if variable initialization
+ # was never recorded on the tape instead of having to do this manually
+ # here.
+ with tape.stop_recording():
+ index = real_mirrored_creator(devices, *args, **kwargs)
+ result = values.TPUMirroredVariable(index, index[devices[0]], aggregation)
+
+ if not context.executing_eagerly():
+ g = ops.get_default_graph()
+ # If "trainable" is True, next_creator() will add the member variables
+ # to the TRAINABLE_VARIABLES collection, so we manually remove
+ # them and replace with the MirroredVariable. We can't set
+ # "trainable" to False for next_creator() since that causes functions
+ # like implicit_gradients to skip those variables.
+ if kwargs.get("trainable", True):
+ collections.append(ops.GraphKeys.TRAINABLE_VARIABLES)
+ l = g.get_collection_ref(ops.GraphKeys.TRAINABLE_VARIABLES)
+ for v in index.values():
+ l.remove(v)
+ g.add_to_collections(collections, result)
+ return result
+
+
+# TODO(jhseu): Stop inheriting from OneDeviceStrategy.
class TPUStrategy(one_device_strategy.OneDeviceStrategy):
"""Experimental TPU distribution strategy implementation."""
@@ -82,6 +139,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(sourabhbajaj): Change this from num_cores to metadata_override
self._num_cores_override = num_cores
+ # TODO(jhseu): Switch to DeviceAssignment to support pods and model
+ # parallelism.
+ device_map = {d.name: i for i, d in enumerate(self._tpu_metadata.devices)
+ if "device:TPU:" in d.name}
+ self._device_index = values.PerDevice(device_map)
+ self._tpu_devices = sorted(device_map.keys())
+ # Only create variables for the number of towers we're running.
+ self._tpu_devices = self._tpu_devices[:self.num_towers]
+
# TODO(sourabhbajaj): Remove this once performance of running one step
# at a time is comparable to multiple steps.
self.steps_per_run = steps_per_run
@@ -231,6 +297,7 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# For outputs that have already been aggregated, take the first value
# from the list as each value should be the same. Else return the full
# list of values.
+ # TODO(josh11b): If aggregation is NONE, we should return a PerDevice value.
if aggregation is not variables_lib.VariableAggregation.NONE:
# TODO(priyag): Should this return the element or a list with 1 element
last_step_tensor_outputs_dict[name] = output[0]
@@ -239,6 +306,8 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return ctx
def _call_for_each_tower(self, fn, *args, **kwargs):
+ # TODO(jhseu): Consider making it so call_for_each_tower implies that we're
+ # in a tpu.rewrite(), and update TPUMirroredVariable accordingly.
kwargs.pop('run_concurrently', None)
with one_device_strategy._OneDeviceTowerContext(self): # pylint: disable=protected-access
return fn(*args, **kwargs)
@@ -248,7 +317,15 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
# TODO(priyag): Add appopriate call here when eager is supported for TPUs.
raise NotImplementedError('Eager mode not supported in TPUStrategy.')
else:
- return [tpu.initialize_system()]
+ # TODO(jhseu): We need this hack because DistributionStrategies must be
+ # pickleable for copy.deepcopy(). Remove when initialize_system goes away.
+ graph = ops.get_default_graph()
+ tpu_init = graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
+ if tpu_init:
+ return tpu_init
+ graph.add_to_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION,
+ tpu.initialize_system())
+ return graph.get_collection(_TPU_INITIALIZE_SYSTEM_COLLECTION)
def finalize(self):
if context.executing_eagerly():
@@ -257,21 +334,53 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
else:
return [tpu.shutdown_system()]
+ def _get_devices_from(self, colocate_with=None):
+ # TODO(jhseu): Change this when we support model parallelism.
+ return self._tpu_devices
+
+ def _create_variable(self, next_creator, *args, **kwargs):
+ """Create a TPUMirroredVariable. See `DistributionStrategy.scope`."""
+ colocate_with = kwargs.pop("colocate_with", None)
+ devices = self._get_devices_from(colocate_with)
+
+ def _real_mirrored_creator(devices, *args, **kwargs): # pylint: disable=g-missing-docstring
+ index = {}
+ for i, d in enumerate(devices):
+ with ops.device(d):
+ if i > 0:
+ # Give replicas meaningful distinct names:
+ var0name = index[devices[0]].name.split(":")[0]
+ # We append a / to variable names created on towers with id > 0 to
+ # ensure that we ignore the name scope and instead use the given
+ # name as the absolute name of the variable.
+ kwargs["name"] = "%s/replica_%d/" % (var0name, i)
+ # Initialize replicas with the same value:
+ if context.executing_eagerly():
+ kwargs["initial_value"] = array_ops.identity(
+ index[devices[0]].value())
+ else:
+ def initial_value_fn(device=d):
+ with ops.device(device):
+ return array_ops.identity(index[devices[0]].initial_value)
+ kwargs["initial_value"] = initial_value_fn
+ with context.context().device_policy(context.DEVICE_PLACEMENT_SILENT):
+ v = next_creator(*args, **kwargs)
+ assert not isinstance(v, values.TPUMirroredVariable)
+ index[d] = v
+ return index
+
+ return _create_tpu_mirrored_variable(devices, _real_mirrored_creator, *args,
+ **kwargs)
+
def _reduce(self, aggregation, value, destinations):
- graph = ops.get_default_graph()
- cf_context = graph._get_control_flow_context() # pylint: disable=protected-access
- # If we're inside the ReplicateContext, reduction should be done using
- # CrossReplicaSum while outside we can directly use an add_n op.
- while cf_context:
- if isinstance(cf_context, tpu.TPUReplicateContext):
- if aggregation == vs.VariableAggregation.MEAN:
- # TODO(jhseu): Revisit once we support model-parallelism.
- value *= (1. / self.num_towers)
- elif aggregation != vs.VariableAggregation.SUM:
- raise NotImplementedError(
- 'Currently only support sum & mean in TPUStrategy.')
- return tpu_ops.cross_replica_sum(value)
- cf_context = cf_context.outer_context
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ if aggregation == vs.VariableAggregation.MEAN:
+ # TODO(jhseu): Revisit once we support model-parallelism.
+ value *= (1. / self.num_towers)
+ elif aggregation != vs.VariableAggregation.SUM:
+ raise NotImplementedError(
+ "Currently only support sum & mean in TPUStrategy.")
+ return tpu_ops.cross_replica_sum(value)
# Validate that the destination is same as the host device
# Note we don't do this when in replicate context as the reduction is
@@ -290,10 +399,46 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
return output * (1. / len(value))
return output
- def _unwrap(self, value):
- if isinstance(value, list):
- return value
- return [value]
+ def _update(self, var, options, fn, *args, **kwargs):
+ assert isinstance(var, values.TPUMirroredVariable)
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
+
+ if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access
+ if should_group:
+ return fn(var, *args, **kwargs)
+ else:
+ return [fn(var, *args, **kwargs)]
+
+ # Otherwise, we revert to MirroredStrategy behavior and update each variable
+ # directly.
+ updates = {}
+ for d, v in var._index.items(): # pylint: disable=protected-access
+ name = "update_%d" % self._device_index.get(d)
+ with ops.device(d), distribute_lib.UpdateContext(d), ops.name_scope(name):
+ # If args and kwargs are not mirrored, the value is returned as is.
+ updates[d] = fn(v,
+ *values.select_device_mirrored(d, args),
+ **values.select_device_mirrored(d, kwargs))
+ return values.update_regroup(self, updates, should_group)
+
+ # TODO(josh11b): Need to implement _update_non_slot()!
+
+ def read_var(self, var):
+ assert isinstance(var, values.TPUMirroredVariable)
+ return var.read_value()
+
+ def _unwrap(self, val):
+ if isinstance(val, values.DistributedValues):
+ # Return in a deterministic order.
+ return [val.get(device=d) for d in sorted(val.devices)]
+ elif isinstance(val, list):
+ # TODO(josh11b): We need to remove this case; per device values should
+ # be represented using a PerDevice wrapper instead of a list with
+ # one entry per device.
+ return val
+ return [val]
+
@property
def num_towers(self):
@@ -323,6 +468,14 @@ class TPUStrategy(one_device_strategy.OneDeviceStrategy):
def should_save_summary(self):
return True
+ @property
+ def worker_devices(self):
+ return self._tpu_devices
+
+ @property
+ def parameter_devices(self):
+ return self._tpu_devices
+
def get_host_cpu_device(self, host_id):
if self._tpu_cluster_resolver.get_master() in ('', 'local'):
return '/replica:0/task:0/device:CPU:0'
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index cce41e7717..18ceba42c2 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -22,17 +22,20 @@ from __future__ import division
from __future__ import print_function
import collections
+import contextlib
import weakref
import six
from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.python.data.ops import multi_device_iterator_ops
+from tensorflow.contrib.distribute.python import prefetching_ops_v2
from tensorflow.python.eager import context
+from tensorflow.python.eager import tape
from tensorflow.python.framework import device as tf_device
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_resource_variable_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as variables_lib
@@ -363,18 +366,7 @@ class MirroredVariable(DistributedVariable, Mirrored,
# We are calling assign on the mirrored variable in cross tower context,
# use update to update the variable.
strategy = distribution_strategy_context.get_distribution_strategy()
- updates = strategy.update(self, f, *args, **kwargs)
- grouped = strategy.group(updates)
- if isinstance(updates, DistributedValues) and updates.is_tensor_like:
- # Make sure we run all updates. Without this, something like
- # session.run(mirrored_var.assign*(...)) may only update one tower.
- index = {}
- for d in updates.devices:
- with ops.device(d), ops.control_dependencies([grouped]):
- index[d] = array_ops.identity(updates.get(d))
- return Mirrored(index)
- else:
- return grouped
+ return strategy.update(self, f, *args, **kwargs)
else:
_assert_tower_context()
# We are calling an assign function on the mirrored variable in tower
@@ -453,6 +445,384 @@ ops.register_tensor_conversion_function(MirroredVariable,
_tensor_conversion_mirrored)
+def _enclosing_tpu_context():
+ # pylint: disable=protected-access
+ tpu_context = ops.get_default_graph()._get_control_flow_context()
+ # pylint: enable=protected-access
+ while tpu_context is not None and not isinstance(
+ tpu_context, control_flow_ops.XLAControlFlowContext):
+ tpu_context = tpu_context.outer_context
+ return tpu_context
+
+
+# TODO(jhseu): Deduplicate code. We copy code because we don't want to
+# inherit from DistributedDelegate. DistributedDelegate will not work in a
+# tpu.replicate() because it assumes that you're in a device context where you
+# can operate on a single version of the variable, but a tpu.replicate()
+# operates on all variables and is replicated during a rewrite pass.
+class TPUMirroredVariable(checkpointable.CheckpointableBase):
+ """Holds a map from device to TPU variables whose values are kept in sync."""
+
+ def __init__(self, index, primary_var, aggregation):
+ # Use a weakref to make it easy to map from the contained values
+ # to the container without introducing a reference cycle.
+ for v in six.itervalues(index):
+ v._mirrored_container = weakref.ref(self) # pylint: disable=protected-access
+ self._index = {device_util.canonicalize(key): value
+ for key, value in six.iteritems(index)}
+ self._primary_var = primary_var
+ self._common_name = self._primary_var.name.split(":")[0]
+ self._aggregation = aggregation
+ # Needed for GradientTape
+ self._trainable = self._primary_var.trainable
+
+ def _get(self, device=None):
+ """Returns the value for the current device or raises a ValueError."""
+ if device is None:
+ tower_context = distribution_strategy_context.get_tower_context()
+ if tower_context:
+ device = tower_context.device
+ else:
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._get_cross_tower()
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device]
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # pylint: disable=multiple-statements
+ def __add__(self, o): return self.read_value() + o
+ def __radd__(self, o): return o + self.read_value()
+ def __sub__(self, o): return self.read_value() - o
+ def __rsub__(self, o): return o - self.read_value()
+ def __mul__(self, o): return self.read_value() * o
+ def __rmul__(self, o): return o * self.read_value()
+ def __truediv__(self, o): return self.read_value() / o
+ def __rtruediv__(self, o): return o / self.read_value()
+ def __floordiv__(self, o): return self.read_value() // o
+ def __rfloordiv__(self, o): return o // self.read_value()
+ def __mod__(self, o): return self.read_value() % o
+ def __rmod__(self, o): return o % self.read_value()
+ def __lt__(self, o): return self.read_value() < o
+ def __le__(self, o): return self.read_value() <= o
+ def __gt__(self, o): return self.read_value() > o
+ def __ge__(self, o): return self.read_value() >= o
+ def __and__(self, o): return self.read_value() & o
+ def __rand__(self, o): return o & self.read_value()
+ def __or__(self, o): return self.read_value() | o
+ def __ror__(self, o): return o | self.read_value()
+ def __xor__(self, o): return self.read_value() ^ o
+ def __rxor__(self, o): return o ^ self.read_value()
+ def __getitem__(self, o): return self.read_value()[o]
+ def __pow__(self, o, modulo=None): return pow(self.read_value(), o, modulo)
+ def __rpow__(self, o): return pow(o, self.read_value())
+ def __invert__(self): return ~self.read_value()
+ def __neg__(self): return -self.read_value()
+ def __abs__(self): return abs(self.read_value())
+
+ def __div__(self, o):
+ try:
+ return self.read_value().__div__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rdiv__(self, o):
+ try:
+ return self.read_value().__rdiv__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __matmul__(self, o):
+ try:
+ return self.read_value().__matmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ def __rmatmul__(self, o):
+ try:
+ return self.read_value().__rmatmul__(o)
+ except AttributeError:
+ # See https://docs.python.org/3/library/constants.html#NotImplemented
+ return NotImplemented
+
+ @property
+ def handle(self):
+ # If we're in a tpu.rewrite(), return the replicated handle.
+ tpu_context = _enclosing_tpu_context()
+ if tpu_context is not None:
+ return tpu_context.get_replicated_var_handle(
+ self._common_name, nest.flatten(self._index))
+
+ device = distribute_lib.get_update_device()
+ if device is None:
+ return self._primary_var.handle
+ device = device_util.canonicalize(device)
+ try:
+ return self._index[device].handle
+ except KeyError as e:
+ six.raise_from(
+ ValueError("Device %s not found in %s (current device %s)" %
+ (device, self._index.keys(), device_util.current())), e)
+
+ # The arguments to update() are automatically unwrapped so the update()
+ # function would normally see regular variables, not MirroredVariables.
+ # However, the update function can still operate on wrapped MirroredVariables
+ # through object members, captured arguments, etc. This is more likely in an
+ # update_non_slot() function (like OptimizerV2._finish), which can
+ # update several non-slot variables in one call.
+ def _assign_func(self, *args, **kwargs):
+ if distribution_strategy_context.get_distribution_strategy().__class__.__name__ != "TPUStrategy":
+ raise ValueError("You may only assign to a TPUMirroredVariable within a "
+ "TPUStrategy.")
+ f = kwargs.pop("f")
+ if distribution_strategy_context.get_cross_tower_context():
+ if _enclosing_tpu_context() is not None:
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+
+ update_device = distribute_lib.get_update_device()
+ # We are calling update on the mirrored variable in cross tower context.
+ if update_device is not None:
+ # We are calling an assign function on the mirrored variable in cross
+ # tower context.
+ v = self._get(device=update_device)
+ return f(v, *args, **kwargs)
+
+ return distribution_strategy_context.get_distribution_strategy().update(
+ self, f, *args, **kwargs)
+ else:
+ _assert_tower_context()
+ # We are calling an assign function on the mirrored variable in tower
+ # context.
+ # We reduce the value we want to assign/add/sub. More details about how we
+ # handle the different use cases can be found in the _reduce method.
+ # We call the function on each of the mirrored variables with the reduced
+ # value.
+ if self._aggregation == vs.VariableAggregation.NONE:
+ raise ValueError("You must specify an aggregation method to update a "
+ "TPUMirroredVariable in Tower Context.")
+
+ def merge_fn(strategy, value, *other_args, **other_kwargs):
+ return strategy.update(
+ self, f,
+ strategy.reduce(
+ aggregation=self._aggregation, value=value, destinations=self),
+ *other_args, **other_kwargs)
+
+ return distribution_strategy_context.get_tower_context().merge_call(
+ merge_fn, *args, **kwargs)
+
+ @contextlib.contextmanager
+ def _handle_graph(self, handle):
+ # Note: might have an eager tensor but not be executing eagerly when
+ # building functions.
+ if (context.executing_eagerly() or isinstance(handle, ops.EagerTensor)
+ or ops.has_default_graph()):
+ yield
+ else:
+ with handle.graph.as_default():
+ yield
+
+ @property
+ def trainable(self):
+ return self._trainable
+
+ def _read_variable_op(self, parent_op=None):
+ if self.trainable:
+ tape.variable_accessed(self)
+ if parent_op is not None:
+ with ops.control_dependencies([parent_op]):
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ return gen_resource_variable_ops.read_variable_op(
+ self.handle, self.dtype)
+
+ def read_value(self):
+ return self._read_variable_op()
+
+ def assign_sub(self, *args, **kwargs):
+ def assign_sub_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_sub_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_sub_fn, *args, **kwargs)
+
+ def assign_add(self, *args, **kwargs):
+ def assign_add_fn(var, delta, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_add_variable_op(
+ var.handle, ops.convert_to_tensor(delta, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_add_fn, *args, **kwargs)
+
+ def assign(self, *args, **kwargs):
+ def assign_fn(var, value, **kw):
+ name = kw.pop("name", None)
+ read_value = kw.pop("read_value", True)
+ with self._handle_graph(var.handle):
+ op = gen_resource_variable_ops.assign_variable_op(
+ var.handle, ops.convert_to_tensor(value, dtype=self.dtype),
+ name=name)
+ if read_value:
+ return self._read_variable_op(parent_op=op)
+ return op
+
+ return self._assign_func(f=assign_fn, *args, **kwargs)
+
+ @property
+ def aggregation(self):
+ return self._aggregation
+
+ @property
+ def constraint(self):
+ return None
+
+ @property
+ def initializer(self):
+ return control_flow_ops.group(
+ [v.initializer for v in nest.flatten(self._index)])
+
+ @property
+ def graph(self):
+ return self._primary_var.graph
+
+ @property
+ def _shared_name(self):
+ return self._common_name
+
+ @property
+ def _unique_id(self):
+ return self._primary_var._unique_id # pylint: disable=protected-access
+
+ @property
+ def name(self):
+ return self._primary_var.name
+
+ @property
+ def dtype(self):
+ return self._primary_var.dtype
+
+ @property
+ def shape(self):
+ return self._primary_var.shape
+
+ def get_shape(self):
+ return self._primary_var.get_shape()
+
+ def to_proto(self, export_scope=None):
+ return self._primary_var.to_proto(export_scope=export_scope)
+
+ def _get_cross_tower(self):
+ device = device_util.canonicalize(device_util.current())
+ if device in self._index:
+ return self._index[device]
+ return self._primary_var
+
+ def _as_graph_element(self):
+ # pylint: disable=protected-access
+ if distribution_strategy_context.get_cross_tower_context():
+ return self._primary_var._as_graph_element()
+ return self._read_variable_op()
+
+ def _gather_saveables_for_checkpoint(self):
+ """Overrides CheckpointableBase method.
+
+ This allows both name-based and object-based save and restore of
+ MirroredVariables.
+
+ Returns:
+ A dictionary mapping attribute names to `SaveableObject` factories.
+ """
+ def _saveable_factory(name=self._common_name):
+ return _MirroredSaveable(self, self._primary_var, name)
+ return {checkpointable.VARIABLE_VALUE_KEY: _saveable_factory}
+
+ def _should_act_as_resource_variable(self):
+ """Pass resource_variable_ops.is_resource_variable check."""
+ pass
+
+ # Needed to pass ResourceVariable checks.
+ @property
+ def op(self):
+ return self._primary_var.op
+
+ @property
+ def _in_graph_mode(self):
+ return self._primary_var._in_graph_mode # pylint: disable=protected-access
+
+ def _dense_var_to_tensor(self, dtype=None, name=None, as_ref=False):
+ """Converts a variable to a tensor."""
+ # pylint: disable=protected-access
+ if _enclosing_tpu_context() is None:
+ return self._get()._dense_var_to_tensor(dtype, name, as_ref)
+ # pylint: enable=protected-access
+ if dtype is not None and dtype != self.dtype:
+ raise NotImplementedError
+ if as_ref:
+ return self.handle
+ else:
+ return self.read_value()
+
+ def is_initialized(self, name=None):
+ """Identifies if all the component variables are initialized.
+
+ Args:
+ name: Name of the final `logical_and` op.
+
+ Returns:
+ The op that evaluates to True or False depending on if all the
+ component variables are initialized.
+ """
+ # TODO(jhseu): Do we need TPU context implementation?
+
+ # We have to cast the self._index.values() to a `list` because when we
+ # use `model_to_estimator` to run tf.keras models, self._index.values() is
+ # of type `dict_values` and not `list`.
+ values_list = nest.flatten(self._index)
+ result = values_list[0].is_initialized()
+ # We iterate through the list of values except the last one to allow us to
+ # name the final `logical_and` op the same name that is passed by the user
+ # to the `is_initialized` op. For distributed variables, the
+ # `is_initialized` op is a `logical_and` op.
+ for v in values_list[1:-1]:
+ result = math_ops.logical_and(result, v.is_initialized())
+ result = math_ops.logical_and(result, values_list[-1].is_initialized(),
+ name=name)
+ return result
+
+
+# Register a conversion function which reads the value of the variable,
+# allowing instances of the class to be used as tensors.
+def _tensor_conversion_tpu_mirrored(var, dtype=None, name=None, as_ref=False):
+ return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref) # pylint: disable=protected-access
+
+
+ops.register_tensor_conversion_function(TPUMirroredVariable,
+ _tensor_conversion_tpu_mirrored)
+ops.register_dense_tensor_like_type(TPUMirroredVariable)
+
+
class _TowerLocalSaveable(saver.BaseSaverBuilder.SaveableObject):
"""Class for defining how to restore a TowerLocalVariable."""
@@ -668,6 +1038,29 @@ def select_device_mirrored(device, structured):
return nest.map_structure(_get_mirrored, structured)
+def update_regroup(strategy, updates, should_group):
+ """Regroup for an update, with dependencies to ensure all updates execute."""
+ regrouped = regroup(updates, Mirrored)
+ if not should_group:
+ return nest.map_structure(strategy.unwrap, regrouped)
+ grouped_flat = []
+ for u in nest.flatten(regrouped):
+ if isinstance(u, DistributedValues):
+ g = strategy.group(u)
+ if u.is_tensor_like:
+ # Make sure we run all updates. Without this, something like
+ # session.run(strategy.update(...)) may only update one tower.
+ index = {}
+ for d in u.devices:
+ with ops.device(d), ops.control_dependencies([g]):
+ index[d] = array_ops.identity(u.get(d))
+ g = Mirrored(index)
+ else:
+ g = u
+ grouped_flat.append(g)
+ return nest.pack_sequence_as(regrouped, grouped_flat)
+
+
class PerDeviceDataIterator(object):
"""An iterator (like `tf.data.Iterator`) into a `PerDeviceDataset`."""
@@ -683,7 +1076,7 @@ class PerDeviceDataIterator(object):
def get_next(self, name=None):
"""Scatter the input across devices."""
if self._prefetch_on_device:
- data_list = self._iterator.get_next()
+ data_list = self._iterator.get_next(name=name)
index = dict(zip(self._devices, data_list))
else:
batch = self._iterator.get_next(name=name)
@@ -703,24 +1096,21 @@ class PerDeviceDataIterator(object):
class PerDeviceDataset(object):
"""Like `tf.data.Dataset` split devices, producing `PerDevice` data."""
- def __init__(
- self,
- dataset,
- devices,
- prefetch_on_device=None,
- ):
+ def __init__(self, dataset, devices, prefetch_on_device=None):
self._devices = devices
# Default to using prefetching in graph mode, unless specified.
- # TODO(rohanj): Enable prefetching in eager mode.
+ # TODO(priyag): Enable prefetching in eager mode.
self._prefetch_on_device = prefetch_on_device
if self._prefetch_on_device is None:
self._prefetch_on_device = not context.executing_eagerly()
assert not (self._prefetch_on_device and context.executing_eagerly()), (
"Prefetching is only supported in graph mode currently")
- self._dataset = dataset
- if not self._prefetch_on_device:
+ if self._prefetch_on_device:
+ self._dataset = dataset.apply(
+ prefetching_ops_v2.prefetch_to_devices(self._devices))
+ else:
# TODO(priyag): If dropping remainder is not appropriate, find another
# approach to distributing the dataset when not possible to divide evenly.
# Possibly not an issue when we start using PartitionedDataset.
@@ -728,33 +1118,15 @@ class PerDeviceDataset(object):
def make_one_shot_iterator(self):
"""Get a one time use iterator for the distributed PerDeviceDataset."""
- # Graph mode prefetching with one shot iterator is disabled.
- if not context.executing_eagerly():
- raise ValueError("Cannot create a one shot iterator. Please use "
- "`make_initializable_iterator()` instead.")
- # Eager mode prefetching would error out in constructor. Only remaining
- # cases are non-prefetching eager / graph mode. We delegate to
- # PerDeviceDataIterator to handle them.
dataset_iterator = self._dataset.make_one_shot_iterator()
- return PerDeviceDataIterator(
- dataset_iterator, self._devices, prefetch_on_device=False)
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
def make_initializable_iterator(self):
"""Get an initializable iterator for the distributed PerDeviceDataset."""
- # Eager mode generates already initialized iterators. Hence we cannot create
- # an initializable iterator.
- if context.executing_eagerly():
- raise ValueError("Cannot create initializable iterator in Eager mode. "
- "Please use `make_one_shot_iterator` instead.")
- if self._prefetch_on_device:
- dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
- self._dataset, self._devices)
- else:
- dataset_iterator = self._dataset.make_initializable_iterator()
- return PerDeviceDataIterator(
- dataset_iterator,
- self._devices,
- prefetch_on_device=self._prefetch_on_device)
+ dataset_iterator = self._dataset.make_initializable_iterator()
+ return PerDeviceDataIterator(dataset_iterator, self._devices,
+ self._prefetch_on_device)
class MultiWorkerDataIterator(object):
@@ -814,7 +1186,8 @@ class MultiWorkerDataset(object):
eager mode.
"""
- def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None):
+ def __init__(self, dataset_fn, worker_device_map, prefetch_on_device=None,
+ auto_shard=False):
"""Initialize the MultiWorkerDataset object.
Args:
@@ -822,6 +1195,7 @@ class MultiWorkerDataset(object):
worker_device_map: a dict mapping from each worker to a list of devices
that belong to this worker.
prefetch_on_device: whether to prefetch to devices.
+ auto_shard: whether to auto-shard the dataset.
"""
self._worker_device_map = worker_device_map
self._datasets = {}
@@ -831,12 +1205,11 @@ class MultiWorkerDataset(object):
six.iteritems(worker_device_map)):
with ops.device(worker):
worker_input = dataset_fn()
- worker_input = input_ops.auto_shard_dataset(
- worker_input, len(worker_device_map), i)
+ if auto_shard:
+ worker_input = input_ops.auto_shard_dataset(
+ worker_input, len(worker_device_map), i)
self._datasets[worker] = PerDeviceDataset(
- worker_input,
- worker_devices,
- prefetch_on_device=prefetch_on_device)
+ worker_input, worker_devices, prefetch_on_device=prefetch_on_device)
def make_one_shot_iterator(self):
iterators = {}
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 002d61f46e..121d2fbb3f 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,11 +349,7 @@ class PerDeviceDatasetTest(test.TestCase):
def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=False)
- if context.executing_eagerly():
- iterator = per_device_dataset.make_one_shot_iterator()
- else:
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
for expected_value in expected_values:
next_element = iterator.get_next()
@@ -370,14 +366,21 @@ class PerDeviceDatasetTest(test.TestCase):
if not context.executing_eagerly():
per_device_dataset = values.PerDeviceDataset(
dataset, devices, prefetch_on_device=True)
- iterator = per_device_dataset.make_initializable_iterator()
- self.evaluate([iterator.initializer])
+ iterator = per_device_dataset.make_one_shot_iterator()
+ # With prefetching, we cannot guarantee which input ends up on which
+ # device, so we verify that the complete set seen on all devices is
+ # correct, and equal numbers are distributed to each device.
+ combined_actual = []
+ combined_expected = []
for expected_value in expected_values:
next_element = iterator.get_next()
- computed_value = self.evaluate(
- [values.select_device(d, next_element) for d in devices])
- self.assertEqual(expected_value, computed_value)
+ combined_actual.extend(
+ self.evaluate(
+ [values.select_device(d, next_element) for d in devices]))
+ combined_expected.extend(expected_value)
+
+ self.assertEqual(set(combined_expected), set(combined_actual))
with self.assertRaises(errors.OutOfRangeError):
next_element = iterator.get_next()
@@ -638,7 +641,7 @@ class MirroredVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, devices, mirrored = _make_mirrored()
# Overwrite the initial values.
@@ -741,7 +744,7 @@ class MirroredVariableTest(test.TestCase):
if context.num_gpus() < 1 or context.executing_eagerly():
self.skipTest("A GPU is not available for this test or it's eager mode.")
- with self.test_session(
+ with self.session(
graph=ops.Graph()) as sess, mirrored_strategy.MirroredStrategy(
["/device:GPU:0"]).scope():
with ops.device("/device:GPU:0"):
@@ -824,7 +827,7 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, tower_local = _make_tower_local(variable_scope.VariableAggregation.SUM)
# Overwrite the initial values.
@@ -847,7 +850,7 @@ class TowerLocalVariableTest(test.TestCase):
if context.num_gpus() < 1 and context.executing_eagerly():
self.skipTest("A GPU is not available for this test in eager mode.")
- with self.test_session() as sess:
+ with self.cached_session(config=self.config) as sess:
v, tower_local = _make_tower_local(
variable_scope.VariableAggregation.MEAN)
diff --git a/tensorflow/contrib/eager/python/datasets.py b/tensorflow/contrib/eager/python/datasets.py
index 135095a979..3aed121233 100644
--- a/tensorflow/contrib/eager/python/datasets.py
+++ b/tensorflow/contrib/eager/python/datasets.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import prefetching_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
@@ -54,7 +54,7 @@ class Iterator(iterator_ops.EagerIterator):
"""
if isinstance(dataset, prefetching_ops._PrefetchToDeviceDataset): # pylint: disable=protected-access
raise TypeError(
- "`tf.contrib.data.prefetch_to_device()` is not compatible with "
+ "`tf.data.experimental.prefetch_to_device()` is not compatible with "
"`tf.contrib.eager.Iterator`. Use `for ... in dataset:` to iterate "
"over the dataset instead.")
diff --git a/tensorflow/contrib/eager/python/datasets_test.py b/tensorflow/contrib/eager/python/datasets_test.py
index a753d77580..6a508fc6ba 100644
--- a/tensorflow/contrib/eager/python/datasets_test.py
+++ b/tensorflow/contrib/eager/python/datasets_test.py
@@ -24,11 +24,11 @@ import time
import numpy as np
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import prefetching_ops
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
from tensorflow.contrib.eager.python import datasets
from tensorflow.python.data import Dataset
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.eager import test
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
index 34a9984b0e..d85188de03 100644
--- a/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
+++ b/tensorflow/contrib/eager/python/examples/revnet/imagenet_input.py
@@ -169,11 +169,11 @@ class ImageNetInput(object):
# Read the data from disk in parallel
dataset = dataset.apply(
- tf.contrib.data.parallel_interleave(
+ tf.data.experimental.parallel_interleave(
fetch_dataset, cycle_length=self.num_parallel_calls, sloppy=True))
if self.cache:
dataset = dataset.cache().apply(
- tf.contrib.data.shuffle_and_repeat(1024 * 16))
+ tf.data.experimental.shuffle_and_repeat(1024 * 16))
else:
dataset = dataset.shuffle(1024)
@@ -188,9 +188,11 @@ class ImageNetInput(object):
# batch size. As long as this validation is done with consistent batch size,
# exactly the same images will be used.
dataset = dataset.apply(
- tf.contrib.data.map_and_batch(
- self.dataset_parser, batch_size=batch_size,
- num_parallel_batches=self.num_cores, drop_remainder=True))
+ tf.data.experimental.map_and_batch(
+ self.dataset_parser,
+ batch_size=batch_size,
+ num_parallel_batches=self.num_cores,
+ drop_remainder=True))
# Transpose for performance on TPU
if self.transpose_input:
diff --git a/tensorflow/contrib/eager/python/remote_test.py b/tensorflow/contrib/eager/python/remote_test.py
index ba6fe9701d..7aa4b598b8 100644
--- a/tensorflow/contrib/eager/python/remote_test.py
+++ b/tensorflow/contrib/eager/python/remote_test.py
@@ -47,8 +47,9 @@ def run_sync_and_async(f):
@functools.wraps(f)
def decorator(self, *args, **kwargs):
- with context.execution_mode(context.ASYNC):
- f(self, *args, **kwargs)
+ # TODO(b/117110239): Re-enable.
+ # with context.execution_mode(context.ASYNC):
+ # f(self, *args, **kwargs)
with context.execution_mode(context.SYNC):
f(self, *args, **kwargs)
diff --git a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
index 5faf0aacfe..6ca7aaf989 100644
--- a/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
+++ b/tensorflow/contrib/estimator/python/estimator/dnn_with_layer_annotations.py
@@ -151,7 +151,7 @@ def make_input_layer_with_layer_annotations(original_input_layer):
# spec and looking at the keys.
spec = feature_column_lib.make_parse_example_spec(feature_columns)
for key in spec.keys():
- tensor = ops.convert_to_tensor(features[key])
+ tensor = ops.convert_to_tensor_or_indexed_slices(features[key])
ops.add_to_collection(
LayerAnnotationsCollectionNames.keys(
LayerAnnotationsCollectionNames.UNPROCESSED_FEATURES), key)
diff --git a/tensorflow/contrib/estimator/python/estimator/rnn_test.py b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
index 1aebed348d..89506ee661 100644
--- a/tensorflow/contrib/estimator/python/estimator/rnn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/rnn_test.py
@@ -25,12 +25,12 @@ import tempfile
import numpy as np
import six
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.contrib.estimator.python.estimator import rnn
from tensorflow.contrib.feature_column.python.feature_column import sequence_feature_column as seq_fc
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import parsing_utils
diff --git a/tensorflow/contrib/factorization/python/ops/gmm_ops.py b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
index e076631bc1..d365ad1117 100644
--- a/tensorflow/contrib/factorization/python/ops/gmm_ops.py
+++ b/tensorflow/contrib/factorization/python/ops/gmm_ops.py
@@ -154,10 +154,10 @@ class GmmAlgorithm(object):
def _create_variables(self):
"""Initializes GMM algorithm."""
init_value = array_ops.constant([], dtype=dtypes.float32)
- self._means = variables.Variable(init_value,
- name=self.CLUSTERS_VARIABLE,
- validate_shape=False)
- self._covs = variables.Variable(
+ self._means = variables.VariableV1(init_value,
+ name=self.CLUSTERS_VARIABLE,
+ validate_shape=False)
+ self._covs = variables.VariableV1(
init_value, name=self.CLUSTERS_COVS_VARIABLE, validate_shape=False)
# Mixture weights, representing the probability that a randomly
# selected unobservable data (in EM terms) was generated by component k.
@@ -165,9 +165,9 @@ class GmmAlgorithm(object):
array_ops.tile([1.0 / self._num_classes], [self._num_classes]),
name=self.CLUSTERS_WEIGHT,
validate_shape=False)
- self._cluster_centers_initialized = variables.Variable(False,
- dtype=dtypes.bool,
- name='initialized')
+ self._cluster_centers_initialized = variables.VariableV1(False,
+ dtype=dtypes.bool,
+ name='initialized')
def _initialize_variables(self, data, initial_means=None):
"""Initializes variables.
diff --git a/tensorflow/contrib/factorization/python/ops/wals_test.py b/tensorflow/contrib/factorization/python/ops/wals_test.py
index 9bdbd05015..75d577f429 100644
--- a/tensorflow/contrib/factorization/python/ops/wals_test.py
+++ b/tensorflow/contrib/factorization/python/ops/wals_test.py
@@ -420,13 +420,13 @@ class WALSMatrixFactorizationUnsupportedTest(test.TestCase):
class SweepHookTest(test.TestCase):
def test_sweeps(self):
- is_row_sweep_var = variables.Variable(True)
- is_sweep_done_var = variables.Variable(False)
- init_done = variables.Variable(False)
- row_prep_done = variables.Variable(False)
- col_prep_done = variables.Variable(False)
- row_train_done = variables.Variable(False)
- col_train_done = variables.Variable(False)
+ is_row_sweep_var = variables.VariableV1(True)
+ is_sweep_done_var = variables.VariableV1(False)
+ init_done = variables.VariableV1(False)
+ row_prep_done = variables.VariableV1(False)
+ col_prep_done = variables.VariableV1(False)
+ row_train_done = variables.VariableV1(False)
+ col_train_done = variables.VariableV1(False)
init_op = state_ops.assign(init_done, True)
row_prep_op = state_ops.assign(row_prep_done, True)
@@ -486,7 +486,7 @@ class StopAtSweepHookTest(test.TestCase):
def test_stop(self):
hook = wals_lib._StopAtSweepHook(last_sweep=10)
- completed_sweeps = variables.Variable(
+ completed_sweeps = variables.VariableV1(
8, name=wals_lib.WALSMatrixFactorization.COMPLETED_SWEEPS)
train_op = state_ops.assign_add(completed_sweeps, 1)
hook.begin()
diff --git a/tensorflow/contrib/fused_conv/BUILD b/tensorflow/contrib/fused_conv/BUILD
index 490da9b33b..57a5bfbf43 100644
--- a/tensorflow/contrib/fused_conv/BUILD
+++ b/tensorflow/contrib/fused_conv/BUILD
@@ -145,6 +145,7 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
],
tags = [
+ "manual", # TODO(b/117128481): re-enable after fixing OSS build
"no_pip",
"requires-gpu-sm70",
],
@@ -169,6 +170,7 @@ cuda_py_test(
],
main = "python/ops/fused_conv2d_bias_activation_benchmark.py",
tags = [
+ "manual", # TODO(b/117128481): re-enable after fixing OSS build
"requires-gpu-sm70",
],
)
diff --git a/tensorflow/contrib/ignite/BUILD b/tensorflow/contrib/ignite/BUILD
new file mode 100644
index 0000000000..9393b702d1
--- /dev/null
+++ b/tensorflow/contrib/ignite/BUILD
@@ -0,0 +1,139 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "if_not_windows",
+ "if_windows",
+ "tf_custom_op_library",
+ "tf_custom_op_py_library",
+ "tf_gen_op_libs",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+ "tf_py_test",
+)
+
+py_library(
+ name = "ignite",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":dataset_ops",
+ ],
+)
+
+tf_custom_op_library(
+ name = "_dataset_ops.so",
+ srcs = ["ops/dataset_ops.cc"],
+ deps = [":dataset_kernels"],
+)
+
+tf_gen_op_libs(
+ op_lib_names = ["dataset_ops"],
+)
+
+cc_library(
+ name = "dataset_kernels",
+ srcs = [
+ "kernels/ignite_dataset_ops.cc",
+ "kernels/ignite_client.h",
+ "kernels/ignite_byte_swapper.h",
+ "kernels/ignite_plain_client.h",
+ "kernels/ignite_ssl_wrapper.h",
+ "kernels/ignite_ssl_wrapper.cc",
+ "kernels/ignite_binary_object_parser.h",
+ "kernels/ignite_binary_object_parser.cc",
+ "kernels/ignite_dataset.h",
+ "kernels/ignite_dataset.cc",
+ "kernels/ignite_dataset_iterator.h",
+ "kernels/ignite_dataset_iterator.cc",
+ ] + if_not_windows([
+ "kernels/ignite_plain_client_unix.cc",
+ ]) + if_windows([
+ "kernels/ignite_plain_client_windows.cc",
+ ]),
+ copts = if_windows([
+ "-DWIN32_LEAN_AND_MEAN",
+ ]),
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//third_party/eigen3",
+ "@boringssl//:ssl",
+ "@protobuf_archive//:protobuf_headers",
+ ],
+ alwayslink = 1,
+)
+
+py_library(
+ name = "dataset_ops",
+ srcs = [
+ "python/ops/ignite_dataset_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":ignite_op_loader",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+tf_gen_op_wrapper_py(
+ name = "gen_dataset_ops",
+ out = "python/ops/gen_dataset_ops.py",
+ deps = ["//tensorflow/contrib/ignite:dataset_ops_op_lib"],
+)
+
+tf_kernel_library(
+ name = "dataset_ops_kernels",
+ deps = [
+ ":dataset_kernels",
+ "//tensorflow/core:framework",
+ ],
+ alwayslink = 1,
+)
+
+tf_custom_op_py_library(
+ name = "ignite_op_loader",
+ srcs = ["python/ops/ignite_op_loader.py"],
+ dso = ["//tensorflow/contrib/ignite:_dataset_ops.so"],
+ kernels = [
+ ":dataset_ops_kernels",
+ "//tensorflow/contrib/ignite:dataset_ops_op_lib",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":gen_dataset_ops",
+ "//tensorflow/contrib/util:util_py",
+ "//tensorflow/python:platform",
+ ],
+)
+
+# The Apache Ignite servers have to setup before the test and tear down
+# after the test manually. The docker engine has to be installed.
+#
+# To setup Apache Ignite servers:
+# $ bash ./python/tests/start_ignite.sh
+#
+# To tear down Apache Ignite servers:
+# $ bash ./python/tests/stop_ignite.sh
+tf_py_test(
+ name = "ignite_dataset_test",
+ srcs = ["python/tests/ignite_dataset_test.py"],
+ additional_deps = [
+ ":ignite",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:platform_test",
+ ],
+ tags = [
+ "manual",
+ "no_windows",
+ "notap",
+ ],
+)
diff --git a/tensorflow/contrib/ignite/README.md b/tensorflow/contrib/ignite/README.md
new file mode 100644
index 0000000000..55c89d2799
--- /dev/null
+++ b/tensorflow/contrib/ignite/README.md
@@ -0,0 +1,167 @@
+# Ignite Dataset
+
+- [Overview](#overview)
+- [Features](#features)
+ * [Distributed In-Memory Datasource](#distributed-in-memory-datasource)
+ * [Structured Objects](#structured-objects)
+ * [Distributed Training](#distributed-training)
+ * [SSL Connection](#ssl-connection)
+ * [Windows Support](#windows-support)
+- [Try it out](#try-it-out)
+- [Limitations](#limitations)
+
+## Overview
+
+[Apache Ignite](https://ignite.apache.org/) is a memory-centric distributed database, caching, and processing platform for
+transactional, analytical, and streaming workloads, delivering in-memory speeds at petabyte scale. This contrib package contains an integration between Apache Ignite and TensorFlow. The integration is based on [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) from TensorFlow side and [Binary Client Protocol](https://apacheignite.readme.io/v2.6/docs/binary-client-protocol) from Apache Ignite side. It allows to use Apache Ignite as a data source for neural network training, inference and all other computations supported by TensorFlow.
+
+## Features
+
+Ignite Dataset provides features that that you can use in a wide range of cases. The most important and interesting features are described below.
+
+### Distributed In-Memory Datasource
+[Apache Ignite](https://ignite.apache.org/) is a distributed in-memory database, caching, and processing platform that provides fast data access. It allows you to avoid limitations of hard drive and store and operate with as much data as you need in distributed cluster. You can utilize
+these benefits of Apache Ignite by using Ignite Dataset. Moreover, Ignite Dataset can be used for the following use-cases:
+- If you have a **gigabyte** of data you can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations. At the same time, you can store your data in Apache Ignite on the same machine and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **terabyte** of data you probably still can keep it on a single machine on a hard drive, but you will face with hard drive speed limitations again. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow and thus avoid these limitations.
+- If you have a **petabyte** of data you can't keep it on a single machine. At the same time, you can store your data in Apache Ignite distributed in-memory cluster and use it as a datasource for TensorFlow.
+
+Note that Apache Ignite is not just a step of ETL pipeline between a database or a data warehouse and TensorFlow. Apache Ignite is a high-grade database itself. By choosing Apache Ignite and TensorFlow you are getting everything you need to work with operational or historical data and, at the same time, an ability to use this data for neural network training and inference.
+
+```bash
+$ apache-ignite-fabric/bin/ignite.sh
+$ apache-ignite-fabric/bin/sqlline.sh -u "jdbc:ignite:thin://localhost:10800/"
+
+jdbc:ignite:thin://localhost/> CREATE TABLE KITTEN_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR);
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (1, 'WARM KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (2, 'SOFT KITTY');
+jdbc:ignite:thin://localhost/> INSERT INTO KITTEN_CACHE VALUES (3, 'LITTLE BALL OF FUR');
+```
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="SQL_PUBLIC_KITTEN_CACHE")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> for _ in range(3):
+>>> print(sess.run(next_obj))
+
+{'key': 1, 'val': {'NAME': b'WARM KITTY'}}
+{'key': 2, 'val': {'NAME': b'SOFT KITTY'}}
+{'key': 3, 'val': {'NAME': b'LITTLE BALL OF FUR'}}
+```
+
+### Structured Objects
+[Apache Ignite](https://ignite.apache.org/) allows to store any type of objects. These objects can have any hierarchy. Ignite Dataset provides an ability to work with such objects.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES")
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+{
+ 'key': 'kitten.png',
+ 'val': {
+ 'metadata': {
+ 'file_name': b'kitten.png',
+ 'label': b'little ball of fur',
+ width: 800,
+ height: 600
+ },
+ 'pixels': [0, 0, 0, 0, ..., 0]
+ }
+}
+```
+ Neural network training and other computations require transformations that can be done as part of [tf.data](https://www.tensorflow.org/api_docs/python/tf/data) pipeline if you use Ignite Dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES").map(lambda obj: obj['val']['pixels'])
+>>> iterator = dataset.make_one_shot_iterator()
+>>> next_obj = iterator.get_next()
+>>>
+>>> with tf.Session() as sess:
+>>> print(sess.run(next_obj))
+
+[0, 0, 0, 0, ..., 0]
+```
+
+### Distributed Training
+
+TensorFlow is a machine learning framework that [natively supports](https://www.tensorflow.org/deploy/distributed) distributed neural network training, inference and other computations. The main idea behind the distributed neural network training is the ability to calculate gradients of loss functions (squares of the errors) on every partition of data (in terms of horizontal partitioning) and then sum them to get loss function gradient of the whole dataset.
+
+<a href="https://www.codecogs.com/eqnedit.php?latex=\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" target="_blank"><img src="https://latex.codecogs.com/gif.latex?\nabla[\sum_1^n(y&space;-&space;\hat{y})^2]&space;=&space;\nabla[\sum_1^{n_1}(y&space;-&space;\hat{y})^2]&space;&plus;&space;\nabla[\sum_{n_1}^{n_2}(y&space;-&space;\hat{y})^2]&space;&plus;&space;...&space;&plus;&space;\nabla[\sum_{n_{k-1}}^n(y&space;-&space;\hat{y})^2]" title="\nabla[\sum_1^n(y - \hat{y})^2] = \nabla[\sum_1^{n_1}(y - \hat{y})^2] + \nabla[\sum_{n_1}^{n_2}(y - \hat{y})^2] + ... + \nabla[\sum_{n_{k-1}}^n(y - \hat{y})^2]" /></a>
+
+Using this ability we can calculate gradients on the nodes the data is stored on, reduce them and then finally update model parameters. It allows to avoid data transfers between nodes and thus to avoid network bottlenecks.
+
+Apache Ignite uses horizontal partitioning to store data in distributed cluster. When we create Apache Ignite cache (or table in terms of SQL), we can specify the number of partitions the data will be partitioned on. For example, if an Apache Ignite cluster consists of 10 machines and we create cache with 10 partitions, then every machine will maintain approximately one data partition.
+
+Ignite Dataset allows using these two aspects of distributed neural network training (using TensorFlow) and Apache Ignite partitioning. Ignite Dataset is a computation graph operation that can be performed on a remote worker. The remote worker can override Ignite Dataset parameters (such as `host`, `port` or `part`) by setting correstondent environment variables for worker process (such as `IGNITE_DATASET_HOST`, `IGNITE_DATASET_PORT` or `IGNITE_DATASET_PART`). Using this overriding approach, we can assign a specific partition to every worker so that one worker handles one partition and, at the same time, transparently work with single dataset.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset("IMAGES")
+>>>
+>>> # Compute gradients locally on every worker node.
+>>> gradients = []
+>>> for i in range(5):
+>>> with tf.device("/job:WORKER/task:%d" % i):
+>>> device_iterator = dataset.make_one_shot_iterator()
+>>> device_next_obj = device_iterator.get_next()
+>>> gradient = compute_gradient(device_next_obj)
+>>> gradients.append(gradient)
+>>>
+>>> # Aggregate them on master node.
+>>> result_gradient = tf.reduce_sum(gradients)
+>>>
+>>> with tf.Session("grpc://localhost:10000") as sess:
+>>> print(sess.run(result_gradient))
+```
+
+High-level TensorFlow API for [distributed training](https://www.tensorflow.org/api_docs/python/tf/contrib/distribute/DistributionStrategy) is supported as well.
+
+### SSL Connection
+
+Apache Ignite allows to protect data transfer channels by [SSL](https://en.wikipedia.org/wiki/Transport_Layer_Security) and authentification. Ignite Dataset supports both SSL connection with and without authntication. For more information, please refer to the [Apache Ignite SSL/TLS](https://apacheignite.readme.io/docs/ssltls) documentation.
+
+```python
+>>> import tensorflow as tf
+>>> from tensorflow.contrib.ignite import IgniteDataset
+>>>
+>>> dataset = IgniteDataset(cache_name="IMAGES", certfile="client.pem", cert_password="password", username="ignite", password="ignite")
+>>> ...
+```
+
+### Windows Support
+
+Ignite Dataset is fully compatible with Windows. You can use it as part of TensorFlow on your Windows workstation as well as on Linux/MacOS systems.
+
+## Try it out
+
+The simplest way to try Ignite Dataset is to run a [Docker](https://www.docker.com/) container with Apache Ignite and loaded [MNIST](http://yann.lecun.com/exdb/mnist/) data and after start interruct with it using Ignite Dataset. Such container is available on Docker Hub: [dmitrievanthony/ignite-with-mnist](https://hub.docker.com/r/dmitrievanthony/ignite-with-mnist/). You need to start this container on your machine:
+
+```
+docker run -it -p 10800:10800 dmitrievanthony/ignite-with-mnist
+```
+
+After that you will be able to work with it following way:
+
+![ignite-dataset-mnist](https://s3.amazonaws.com/helloworld23423423ew23/ignite-dataset-mnist.png "Ignite Dataset Mnist")
+
+## Limitations
+
+Presently, Ignite Dataset works with assumption that all objects in the cache have the same structure (homogeneous objects) and the cache contains at least one object. Another limitation concerns structured objects, Ignite Dataset does not support UUID, Maps and Object arrays that might be parts of an object structure.
diff --git a/tensorflow/contrib/ignite/__init__.py b/tensorflow/contrib/ignite/__init__.py
new file mode 100644
index 0000000000..f42947696f
--- /dev/null
+++ b/tensorflow/contrib/ignite/__init__.py
@@ -0,0 +1,42 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""IgniteDataset that allows to get data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and
+processing platform for transactional, analytical, and streaming workloads,
+delivering in-memory speeds at petabyte scale. This contrib package
+contains an integration between Apache Ignite and TensorFlow. The
+integration is based on tf.data from TensorFlow side and Binary Client
+Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+datasource for neural network training, inference and all other
+computations supported by TensorFlow. Ignite Dataset is based on Apache
+Ignite Binary Client Protocol:
+https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+
+@@IgniteDataset
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.ignite.python.ops.ignite_dataset_ops import IgniteDataset
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "IgniteDataset",
+]
+
+remove_undocumented(__name__)
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
new file mode 100644
index 0000000000..2c8a7d44b0
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.cc
@@ -0,0 +1,334 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/lib/core/errors.h"
+
+namespace tensorflow {
+
+BinaryObjectParser::BinaryObjectParser() : byte_swapper_(ByteSwapper(false)) {}
+
+Status BinaryObjectParser::Parse(uint8_t** ptr,
+ std::vector<Tensor>* out_tensors,
+ std::vector<int32_t>* types) const {
+ uint8_t object_type_id = ParseByte(ptr);
+
+ // Skip non-leaf nodes.
+ if (object_type_id != WRAPPED_OBJ && object_type_id != COMPLEX_OBJ)
+ types->push_back(object_type_id);
+
+ switch (object_type_id) {
+ case BYTE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT8, TensorShape({}));
+ out_tensors->back().scalar<uint8>()() = ParseByte(ptr);
+ break;
+ }
+ case SHORT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT16, TensorShape({}));
+ out_tensors->back().scalar<int16>()() = ParseShort(ptr);
+ break;
+ }
+ case USHORT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT16, TensorShape({}));
+ out_tensors->back().scalar<uint16>()() = ParseUnsignedShort(ptr);
+ break;
+ }
+ case INT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT32, TensorShape({}));
+ out_tensors->back().scalar<int32>()() = ParseInt(ptr);
+ break;
+ }
+ case LONG: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+ out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+ break;
+ }
+ case FLOAT: {
+ out_tensors->emplace_back(cpu_allocator(), DT_FLOAT, TensorShape({}));
+ out_tensors->back().scalar<float>()() = ParseFloat(ptr);
+ break;
+ }
+ case DOUBLE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE, TensorShape({}));
+ out_tensors->back().scalar<double>()() = ParseDouble(ptr);
+ break;
+ }
+ case BOOL: {
+ out_tensors->emplace_back(cpu_allocator(), DT_BOOL, TensorShape({}));
+ out_tensors->back().scalar<bool>()() = ParseBool(ptr);
+ break;
+ }
+ case STRING: {
+ out_tensors->emplace_back(cpu_allocator(), DT_STRING, TensorShape({}));
+ out_tensors->back().scalar<string>()() = ParseString(ptr);
+ break;
+ }
+ case DATE: {
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64, TensorShape({}));
+ out_tensors->back().scalar<int64>()() = ParseLong(ptr);
+ break;
+ }
+ case BYTE_ARR: {
+ int32_t length = ParseInt(ptr);
+ uint8_t* arr = ParseByteArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT8,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<uint8>().data());
+ break;
+ }
+ case SHORT_ARR: {
+ int32_t length = ParseInt(ptr);
+ int16_t* arr = ParseShortArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT16,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int16>().data());
+ break;
+ }
+ case USHORT_ARR: {
+ int32_t length = ParseInt(ptr);
+ uint16_t* arr = ParseUnsignedShortArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_UINT16,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<uint16>().data());
+ break;
+ }
+ case INT_ARR: {
+ int32_t length = ParseInt(ptr);
+ int32_t* arr = ParseIntArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT32,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int32>().data());
+ break;
+ }
+ case LONG_ARR: {
+ int32_t length = ParseInt(ptr);
+ int64_t* arr = ParseLongArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+ break;
+ }
+ case FLOAT_ARR: {
+ int32_t length = ParseInt(ptr);
+ float* arr = ParseFloatArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_FLOAT,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<float>().data());
+ break;
+ }
+ case DOUBLE_ARR: {
+ int32_t length = ParseInt(ptr);
+ double* arr = ParseDoubleArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_DOUBLE,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<double>().data());
+ break;
+ }
+ case BOOL_ARR: {
+ int32_t length = ParseInt(ptr);
+ bool* arr = ParseBoolArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_BOOL,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<bool>().data());
+ break;
+ }
+ case STRING_ARR: {
+ int32_t length = ParseInt(ptr);
+ out_tensors->emplace_back(cpu_allocator(), DT_STRING,
+ TensorShape({length}));
+ for (int32_t i = 0; i < length; i++)
+ out_tensors->back().vec<string>()(i) = ParseString(ptr);
+ break;
+ }
+ case DATE_ARR: {
+ int32_t length = ParseInt(ptr);
+ int64_t* arr = ParseLongArr(ptr, length);
+ out_tensors->emplace_back(cpu_allocator(), DT_INT64,
+ TensorShape({length}));
+ std::copy_n(arr, length, out_tensors->back().flat<int64>().data());
+ break;
+ }
+ case WRAPPED_OBJ: {
+ int32_t byte_arr_size = ParseInt(ptr);
+ TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+ int32_t offset = ParseInt(ptr);
+
+ break;
+ }
+ case COMPLEX_OBJ: {
+ uint8_t version = ParseByte(ptr);
+ int16_t flags = ParseShort(ptr);
+ int32_t type_id = ParseInt(ptr);
+ int32_t hash_code = ParseInt(ptr);
+ int32_t length = ParseInt(ptr);
+ int32_t schema_id = ParseInt(ptr);
+ int32_t schema_offset = ParseInt(ptr);
+
+ // 24 is size of header just read.
+ uint8_t* end = *ptr + schema_offset - 24;
+ int32_t i = 0;
+ while (*ptr < end) {
+ i++;
+ TF_RETURN_IF_ERROR(Parse(ptr, out_tensors, types));
+ }
+
+ *ptr += (length - schema_offset);
+
+ break;
+ }
+ default: {
+ return errors::Unknown("Unknowd binary type (type id ",
+ (int)object_type_id, ")");
+ }
+ }
+
+ return Status::OK();
+}
+
+uint8_t BinaryObjectParser::ParseByte(uint8_t** ptr) const {
+ uint8_t res = **ptr;
+ *ptr += 1;
+
+ return res;
+}
+
+int16_t BinaryObjectParser::ParseShort(uint8_t** ptr) const {
+ int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt16(res);
+ *ptr += 2;
+
+ return *res;
+}
+
+uint16_t BinaryObjectParser::ParseUnsignedShort(uint8_t** ptr) const {
+ uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredUnsignedInt16(res);
+ *ptr += 2;
+
+ return *res;
+}
+
+int32_t BinaryObjectParser::ParseInt(uint8_t** ptr) const {
+ int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt32(res);
+ *ptr += 4;
+
+ return *res;
+}
+
+int64_t BinaryObjectParser::ParseLong(uint8_t** ptr) const {
+ int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt64(res);
+ *ptr += 8;
+
+ return *res;
+}
+
+float BinaryObjectParser::ParseFloat(uint8_t** ptr) const {
+ float* res = *reinterpret_cast<float**>(ptr);
+ byte_swapper_.SwapIfRequiredFloat(res);
+ *ptr += 4;
+
+ return *res;
+}
+
+double BinaryObjectParser::ParseDouble(uint8_t** ptr) const {
+ double* res = *reinterpret_cast<double**>(ptr);
+ byte_swapper_.SwapIfRequiredDouble(res);
+ *ptr += 8;
+
+ return *res;
+}
+
+bool BinaryObjectParser::ParseBool(uint8_t** ptr) const {
+ bool res = **reinterpret_cast<bool**>(ptr);
+ *ptr += 1;
+
+ return res;
+}
+
+string BinaryObjectParser::ParseString(uint8_t** ptr) const {
+ int32_t length = ParseInt(ptr);
+ string res(*reinterpret_cast<char**>(ptr), length);
+ *ptr += length;
+
+ return res;
+}
+
+uint8_t* BinaryObjectParser::ParseByteArr(uint8_t** ptr, int length) const {
+ uint8_t* res = *reinterpret_cast<uint8_t**>(ptr);
+ *ptr += length;
+
+ return res;
+}
+
+int16_t* BinaryObjectParser::ParseShortArr(uint8_t** ptr, int length) const {
+ int16_t* res = *reinterpret_cast<int16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt16Arr(res, length);
+ *ptr += length * 2;
+
+ return res;
+}
+
+uint16_t* BinaryObjectParser::ParseUnsignedShortArr(uint8_t** ptr,
+ int length) const {
+ uint16_t* res = *reinterpret_cast<uint16_t**>(ptr);
+ byte_swapper_.SwapIfRequiredUnsignedInt16Arr(res, length);
+ *ptr += length * 2;
+
+ return res;
+}
+
+int32_t* BinaryObjectParser::ParseIntArr(uint8_t** ptr, int length) const {
+ int32_t* res = *reinterpret_cast<int32_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt32Arr(res, length);
+ *ptr += length * 4;
+
+ return res;
+}
+
+int64_t* BinaryObjectParser::ParseLongArr(uint8_t** ptr, int length) const {
+ int64_t* res = *reinterpret_cast<int64_t**>(ptr);
+ byte_swapper_.SwapIfRequiredInt64Arr(res, length);
+ *ptr += length * 8;
+
+ return res;
+}
+
+float* BinaryObjectParser::ParseFloatArr(uint8_t** ptr, int length) const {
+ float* res = *reinterpret_cast<float**>(ptr);
+ byte_swapper_.SwapIfRequiredFloatArr(res, length);
+ *ptr += length * 4;
+
+ return res;
+}
+
+double* BinaryObjectParser::ParseDoubleArr(uint8_t** ptr, int length) const {
+ double* res = *reinterpret_cast<double**>(ptr);
+ byte_swapper_.SwapIfRequiredDoubleArr(res, length);
+ *ptr += length * 8;
+
+ return res;
+}
+
+bool* BinaryObjectParser::ParseBoolArr(uint8_t** ptr, int length) const {
+ bool* res = *reinterpret_cast<bool**>(ptr);
+ *ptr += length;
+
+ return res;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
new file mode 100644
index 0000000000..eb1f856643
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
+
+#include <vector>
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class BinaryObjectParser {
+ public:
+ BinaryObjectParser();
+ Status Parse(uint8_t** ptr, std::vector<Tensor>* out_tensors,
+ std::vector<int32_t>* types) const;
+
+ private:
+ uint8_t ParseByte(uint8_t** ptr) const;
+ int16_t ParseShort(uint8_t** ptr) const;
+ uint16_t ParseUnsignedShort(uint8_t** ptr) const;
+ int32_t ParseInt(uint8_t** ptr) const;
+ int64_t ParseLong(uint8_t** ptr) const;
+ float ParseFloat(uint8_t** ptr) const;
+ double ParseDouble(uint8_t** ptr) const;
+ bool ParseBool(uint8_t** ptr) const;
+ string ParseString(uint8_t** ptr) const;
+ uint8_t* ParseByteArr(uint8_t** ptr, int length) const;
+ int16_t* ParseShortArr(uint8_t** ptr, int length) const;
+ uint16_t* ParseUnsignedShortArr(uint8_t** ptr, int length) const;
+ int32_t* ParseIntArr(uint8_t** ptr, int length) const;
+ int64_t* ParseLongArr(uint8_t** ptr, int length) const;
+ float* ParseFloatArr(uint8_t** ptr, int length) const;
+ double* ParseDoubleArr(uint8_t** ptr, int length) const;
+ bool* ParseBoolArr(uint8_t** ptr, int length) const;
+
+ const ByteSwapper byte_swapper_;
+};
+
+enum ObjectType {
+ BYTE = 1,
+ SHORT = 2,
+ INT = 3,
+ LONG = 4,
+ FLOAT = 5,
+ DOUBLE = 6,
+ USHORT = 7,
+ BOOL = 8,
+ STRING = 9,
+ DATE = 11,
+ BYTE_ARR = 12,
+ SHORT_ARR = 13,
+ INT_ARR = 14,
+ LONG_ARR = 15,
+ FLOAT_ARR = 16,
+ DOUBLE_ARR = 17,
+ USHORT_ARR = 18,
+ BOOL_ARR = 19,
+ STRING_ARR = 20,
+ DATE_ARR = 22,
+ WRAPPED_OBJ = 27,
+ COMPLEX_OBJ = 103
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BINARY_OBJECT_PARSER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
new file mode 100644
index 0000000000..46df3e39dc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h
@@ -0,0 +1,126 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
+
+#include <stdint.h>
+#include "tensorflow/core/platform/byte_order.h"
+
+namespace tensorflow {
+
+class ByteSwapper {
+ public:
+ ByteSwapper(bool big_endian) { swap_ = big_endian == port::kLittleEndian; }
+
+ inline void SwapIfRequiredInt16(int16_t *x) const {
+ if (swap_) {
+ Swap16(x);
+ }
+ }
+
+ inline void SwapIfRequiredUnsignedInt16(uint16_t *x) const {
+ if (swap_) {
+ Swap16(reinterpret_cast<int16_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt32(int32_t *x) const {
+ if (swap_) {
+ Swap32(x);
+ }
+ }
+
+ inline void SwapIfRequiredFloat(float *x) const {
+ if (swap_) {
+ Swap32(reinterpret_cast<int32_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt64(int64_t *x) const {
+ if (swap_) {
+ Swap64(x);
+ }
+ }
+
+ inline void SwapIfRequiredDouble(double *x) const {
+ if (swap_) {
+ Swap64(reinterpret_cast<int64_t *>(x));
+ }
+ }
+
+ inline void SwapIfRequiredInt16Arr(int16_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap16(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredUnsignedInt16Arr(uint16_t *x,
+ int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap16(reinterpret_cast<int16_t *>(&x[i]));
+ }
+ }
+
+ inline void SwapIfRequiredInt32Arr(int32_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap32(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredFloatArr(float *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap32(reinterpret_cast<int32_t *>(&x[i]));
+ }
+ }
+
+ inline void SwapIfRequiredInt64Arr(int64_t *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++) Swap64(&x[i]);
+ }
+ }
+
+ inline void SwapIfRequiredDoubleArr(double *x, int32_t length) const {
+ if (swap_) {
+ for (int32_t i = 0; i < length; i++)
+ Swap64(reinterpret_cast<int64_t *>(&x[i]));
+ }
+ }
+
+ private:
+ inline void Swap16(int16_t *x) const {
+ *x = ((*x & 0xFF) << 8) | ((*x >> 8) & 0xFF);
+ }
+
+ inline void Swap32(int32_t *x) const {
+ *x = ((*x & 0xFF) << 24) | (((*x >> 8) & 0xFF) << 16) |
+ (((*x >> 16) & 0xFF) << 8) | ((*x >> 24) & 0xFF);
+ }
+
+ inline void Swap64(int64_t *x) const {
+ *x = ((*x & 0xFF) << 56) | (((*x >> 8) & 0xFF) << 48) |
+ (((*x >> 16) & 0xFF) << 40) | (((*x >> 24) & 0xFF) << 32) |
+ (((*x >> 32) & 0xFF) << 24) | (((*x >> 40) & 0xFF) << 16) |
+ (((*x >> 48) & 0xFF) << 8) | ((*x >> 56) & 0xFF);
+ }
+
+ bool swap_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_BYTE_SWAPPER_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_client.h b/tensorflow/contrib/ignite/kernels/ignite_client.h
new file mode 100644
index 0000000000..459b50b48f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_client.h
@@ -0,0 +1,84 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_byte_swapper.h"
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/core/status.h"
+
+namespace tensorflow {
+
+class Client {
+ public:
+ Client(bool big_endian) : byte_swapper_(ByteSwapper(big_endian)) {}
+ virtual Status Connect() = 0;
+ virtual Status Disconnect() = 0;
+ virtual bool IsConnected() = 0;
+ virtual int GetSocketDescriptor() = 0;
+ virtual Status ReadData(uint8_t *buf, const int32_t length) = 0;
+ virtual Status WriteData(const uint8_t *buf, const int32_t length) = 0;
+
+ inline Status ReadByte(uint8_t *data) { return ReadData(data, 1); }
+
+ inline Status ReadShort(int16_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 2));
+ byte_swapper_.SwapIfRequiredInt16(data);
+
+ return Status::OK();
+ }
+
+ inline Status ReadInt(int32_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 4));
+ byte_swapper_.SwapIfRequiredInt32(data);
+
+ return Status::OK();
+ }
+
+ inline Status ReadLong(int64_t *data) {
+ TF_RETURN_IF_ERROR(ReadData((uint8_t *)data, 8));
+ byte_swapper_.SwapIfRequiredInt64(data);
+
+ return Status::OK();
+ }
+
+ inline Status WriteByte(const uint8_t data) { return WriteData(&data, 1); }
+
+ inline Status WriteShort(const int16_t data) {
+ int16_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt16(&tmp);
+ return WriteData((uint8_t *)&tmp, 2);
+ }
+
+ inline Status WriteInt(const int32_t data) {
+ int32_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt32(&tmp);
+ return WriteData((uint8_t *)&tmp, 4);
+ }
+
+ inline Status WriteLong(const int64_t data) {
+ int64_t tmp = data;
+ byte_swapper_.SwapIfRequiredInt64(&tmp);
+ return WriteData((uint8_t *)&tmp, 8);
+ }
+
+ private:
+ const ByteSwapper byte_swapper_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
new file mode 100644
index 0000000000..c4a7d3c513
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.cc
@@ -0,0 +1,81 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+IgniteDataset::IgniteDataset(OpKernelContext* ctx, string cache_name,
+ string host, int32 port, bool local, int32 part,
+ int32 page_size, string username, string password,
+ string certfile, string keyfile,
+ string cert_password, std::vector<int32> schema,
+ std::vector<int32> permutation,
+ DataTypeVector dtypes,
+ std::vector<PartialTensorShape> shapes)
+ : DatasetBase(DatasetContext(ctx)),
+ cache_name_(std::move(cache_name)),
+ host_(std::move(host)),
+ port_(port),
+ local_(local),
+ part_(part),
+ page_size_(page_size),
+ username_(std::move(username)),
+ password_(std::move(password)),
+ certfile_(std::move(certfile)),
+ keyfile_(std::move(keyfile)),
+ cert_password_(std::move(cert_password)),
+ schema_(std::move(schema)),
+ permutation_(std::move(permutation)),
+ dtypes_(dtypes),
+ shapes_(shapes) {
+ LOG(INFO) << "Ignite Dataset created [cache_name='" << cache_name_
+ << "', host='" << host_ << "', port=" << port_
+ << ", local=" << local_ << ", part=" << part_
+ << ", page_size=" << page_size_ << ", username='" << username_
+ << "', certfile='" << certfile_ << "', keyfile='"
+ << keyfile_ + "']";
+}
+
+IgniteDataset::~IgniteDataset() { LOG(INFO) << "Ignite Dataset destroyed"; }
+
+std::unique_ptr<IteratorBase> IgniteDataset::MakeIteratorInternal(
+ const string& prefix) const {
+ return std::unique_ptr<IteratorBase>(new IgniteDatasetIterator(
+ {this, strings::StrCat(prefix, "::Ignite")}, std::move(this->host_),
+ this->port_, std::move(this->cache_name_), this->local_, this->part_,
+ this->page_size_, std::move(this->username_), std::move(this->password_),
+ std::move(this->certfile_), std::move(this->keyfile_),
+ std::move(this->cert_password_), std::move(this->schema_),
+ std::move(this->permutation_)));
+}
+
+const DataTypeVector& IgniteDataset::output_dtypes() const { return dtypes_; }
+
+const std::vector<PartialTensorShape>& IgniteDataset::output_shapes() const {
+ return shapes_;
+}
+
+string IgniteDataset::DebugString() const { return "IgniteDatasetOp::Dataset"; }
+
+Status IgniteDataset::AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const {
+ return errors::Unimplemented(
+ "IgniteDataset does not support 'AsGraphDefInternal'");
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset.h b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
new file mode 100644
index 0000000000..66bfdf2e2a
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset.h
@@ -0,0 +1,63 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
+
+#include "tensorflow/core/framework/dataset.h"
+
+namespace tensorflow {
+
+class IgniteDataset : public DatasetBase {
+ public:
+ IgniteDataset(OpKernelContext* ctx, string cache_name, string host,
+ int32 port, bool local, int32 part, int32 page_size,
+ string username, string password, string certfile,
+ string keyfile, string cert_password, std::vector<int32> schema,
+ std::vector<int32> permutation, DataTypeVector dtypes,
+ std::vector<PartialTensorShape> shapes);
+ ~IgniteDataset();
+ std::unique_ptr<IteratorBase> MakeIteratorInternal(
+ const string& prefix) const override;
+ const DataTypeVector& output_dtypes() const override;
+ const std::vector<PartialTensorShape>& output_shapes() const override;
+ string DebugString() const override;
+
+ protected:
+ Status AsGraphDefInternal(SerializationContext* ctx,
+ DatasetGraphDefBuilder* b,
+ Node** output) const override;
+
+ private:
+ const string cache_name_;
+ const string host_;
+ const int32 port_;
+ const bool local_;
+ const int32 part_;
+ const int32 page_size_;
+ const string username_;
+ const string password_;
+ const string certfile_;
+ const string keyfile_;
+ const string cert_password_;
+ const std::vector<int32> schema_;
+ const std::vector<int32> permutation_;
+ const DataTypeVector dtypes_;
+ const std::vector<PartialTensorShape> shapes_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
new file mode 100644
index 0000000000..5da9127aa6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.cc
@@ -0,0 +1,422 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h"
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+IgniteDatasetIterator::IgniteDatasetIterator(
+ const Params& params, string host, int32 port, string cache_name,
+ bool local, int32 part, int32 page_size, string username, string password,
+ string certfile, string keyfile, string cert_password,
+ std::vector<int32> schema, std::vector<int32> permutation)
+ : DatasetIterator<IgniteDataset>(params),
+ cache_name_(std::move(cache_name)),
+ local_(local),
+ part_(part),
+ page_size_(page_size),
+ username_(std::move(username)),
+ password_(std::move(password)),
+ schema_(std::move(schema)),
+ permutation_(std::move(permutation)),
+ remainder_(-1),
+ cursor_id_(-1),
+ last_page_(false),
+ valid_state_(true) {
+ Client* p_client = new PlainClient(std::move(host), port, false);
+
+ if (certfile.empty())
+ client_ = std::unique_ptr<Client>(p_client);
+ else
+ client_ = std::unique_ptr<Client>(
+ new SslWrapper(std::unique_ptr<Client>(p_client), std::move(certfile),
+ std::move(keyfile), std::move(cert_password), false));
+
+ LOG(INFO) << "Ignite Dataset Iterator created";
+}
+
+IgniteDatasetIterator::~IgniteDatasetIterator() {
+ Status status = CloseConnection();
+ if (!status.ok()) LOG(ERROR) << status.ToString();
+
+ LOG(INFO) << "Ignite Dataset Iterator destroyed";
+}
+
+Status IgniteDatasetIterator::GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ mutex_lock l(mutex_);
+
+ if (valid_state_) {
+ Status status =
+ GetNextInternalWithValidState(ctx, out_tensors, end_of_sequence);
+
+ if (!status.ok()) valid_state_ = false;
+
+ return status;
+ }
+
+ return errors::Unknown("Iterator is invalid");
+}
+
+Status IgniteDatasetIterator::SaveInternal(IteratorStateWriter* writer) {
+ return errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'SaveInternal'");
+}
+
+Status IgniteDatasetIterator::RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) {
+ return errors::Unimplemented(
+ "Iterator for IgniteDataset does not support 'RestoreInternal')");
+}
+
+Status IgniteDatasetIterator::GetNextInternalWithValidState(
+ IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) {
+ if (remainder_ == 0 && last_page_) {
+ cursor_id_ = -1;
+ *end_of_sequence = true;
+
+ return Status::OK();
+ } else {
+ TF_RETURN_IF_ERROR(EstablishConnection());
+
+ if (remainder_ == -1) {
+ TF_RETURN_IF_ERROR(ScanQuery());
+ } else if (remainder_ == 0) {
+ TF_RETURN_IF_ERROR(LoadNextPage());
+ }
+
+ uint8_t* initial_ptr = ptr_;
+ std::vector<Tensor> tensors;
+ std::vector<int32_t> types;
+
+ TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse key
+ TF_RETURN_IF_ERROR(parser_.Parse(&ptr_, &tensors, &types)); // Parse val
+
+ remainder_ -= (ptr_ - initial_ptr);
+
+ TF_RETURN_IF_ERROR(CheckTypes(types));
+
+ for (size_t i = 0; i < tensors.size(); i++)
+ out_tensors->push_back(tensors[permutation_[i]]);
+
+ *end_of_sequence = false;
+
+ return Status::OK();
+ }
+
+ *end_of_sequence = true;
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::EstablishConnection() {
+ if (!client_->IsConnected()) {
+ TF_RETURN_IF_ERROR(client_->Connect());
+
+ Status status = Handshake();
+ if (!status.ok()) {
+ Status disconnect_status = client_->Disconnect();
+ if (!disconnect_status.ok()) LOG(ERROR) << disconnect_status.ToString();
+
+ return status;
+ }
+ }
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::CloseConnection() {
+ if (cursor_id_ != -1 && !last_page_) {
+ TF_RETURN_IF_ERROR(EstablishConnection());
+
+ TF_RETURN_IF_ERROR(client_->WriteInt(kCloseConnectionReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kCloseConnectionOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Resource ID
+
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+ if (res_len < kMinResLength)
+ return errors::Unknown("Close Resource Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Close Resource Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Close Resource Error [status=", status, "]");
+ }
+
+ cursor_id_ = -1;
+
+ return client_->Disconnect();
+ } else {
+ LOG(INFO) << "Query Cursor " << cursor_id_ << " is already closed";
+ }
+
+ return client_->IsConnected() ? client_->Disconnect() : Status::OK();
+}
+
+Status IgniteDatasetIterator::Handshake() {
+ int32_t msg_len = kHandshakeReqDefaultLength;
+
+ if (username_.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + username_.length(); // 1 byte header, 4 bytes length.
+
+ if (password_.empty())
+ msg_len += 1;
+ else
+ msg_len += 5 + password_.length(); // 1 byte header, 4 bytes length.
+
+ TF_RETURN_IF_ERROR(client_->WriteInt(msg_len));
+ TF_RETURN_IF_ERROR(client_->WriteByte(1));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMajorVersion));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolMinorVersion));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kProtocolPatchVersion));
+ TF_RETURN_IF_ERROR(client_->WriteByte(2));
+ if (username_.empty()) {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
+ } else {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
+ TF_RETURN_IF_ERROR(client_->WriteInt(username_.length()));
+ TF_RETURN_IF_ERROR(
+ client_->WriteData(reinterpret_cast<const uint8_t*>(username_.c_str()),
+ username_.length()));
+ }
+
+ if (password_.empty()) {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal));
+ } else {
+ TF_RETURN_IF_ERROR(client_->WriteByte(kStringVal));
+ TF_RETURN_IF_ERROR(client_->WriteInt(password_.length()));
+ TF_RETURN_IF_ERROR(
+ client_->WriteData(reinterpret_cast<const uint8_t*>(password_.c_str()),
+ password_.length()));
+ }
+
+ int32_t handshake_res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&handshake_res_len));
+ uint8_t handshake_res;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&handshake_res));
+
+ if (handshake_res != 1) {
+ int16_t serv_ver_major;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_major));
+ int16_t serv_ver_minor;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_minor));
+ int16_t serv_ver_patch;
+ TF_RETURN_IF_ERROR(client_->ReadShort(&serv_ver_patch));
+ uint8_t header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&header));
+
+ if (header == kStringVal) {
+ int32_t length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&length));
+
+ uint8_t* err_msg_c = new uint8_t[length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), length);
+
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, ", message='", err_msg, "']");
+ } else if (header == kNullVal) {
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, "]");
+ } else {
+ return errors::Unknown("Handshake Error [result=", handshake_res,
+ ", version=", serv_ver_major, ".", serv_ver_minor,
+ ".", serv_ver_patch, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::ScanQuery() {
+ TF_RETURN_IF_ERROR(client_->WriteInt(kScanQueryReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kScanQueryOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(
+ client_->WriteInt(JavaHashCode(cache_name_))); // Cache name
+ TF_RETURN_IF_ERROR(client_->WriteByte(0)); // Flags
+ TF_RETURN_IF_ERROR(client_->WriteByte(kNullVal)); // Filter object
+ TF_RETURN_IF_ERROR(client_->WriteInt(page_size_)); // Cursor page size
+ TF_RETURN_IF_ERROR(client_->WriteInt(part_)); // part_ition to query
+ TF_RETURN_IF_ERROR(client_->WriteByte(local_)); // local_ flag
+
+ uint64 wait_start = Env::Default()->NowMicros();
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+ int64_t wait_stop = Env::Default()->NowMicros();
+
+ LOG(INFO) << "Scan Query waited " << (wait_stop - wait_start) / 1000 << " ms";
+
+ if (res_len < kMinResLength)
+ return errors::Unknown("Scan Query Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Scan Query Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Scan Query Error [status=", status, "]");
+ }
+
+ TF_RETURN_IF_ERROR(client_->ReadLong(&cursor_id_));
+
+ int32_t row_cnt;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+ int32_t page_size = res_len - kScanQueryResHeaderLength;
+
+ return ReceivePage(page_size);
+}
+
+Status IgniteDatasetIterator::LoadNextPage() {
+ TF_RETURN_IF_ERROR(client_->WriteInt(kLoadNextPageReqLength));
+ TF_RETURN_IF_ERROR(client_->WriteShort(kLoadNextPageOpcode));
+ TF_RETURN_IF_ERROR(client_->WriteLong(0)); // Request ID
+ TF_RETURN_IF_ERROR(client_->WriteLong(cursor_id_)); // Cursor ID
+
+ uint64 wait_start = Env::Default()->NowMicros();
+ int32_t res_len;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&res_len));
+ uint64 wait_stop = Env::Default()->NowMicros();
+
+ LOG(INFO) << "Load Next Page waited " << (wait_stop - wait_start) / 1000
+ << " ms";
+
+ if (res_len < kMinResLength)
+ return errors::Unknown("Load Next Page Response is corrupted");
+
+ int64_t req_id;
+ TF_RETURN_IF_ERROR(client_->ReadLong(&req_id));
+
+ int32_t status;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&status));
+
+ if (status != 0) {
+ uint8_t err_msg_header;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&err_msg_header));
+
+ if (err_msg_header == kStringVal) {
+ int32_t err_msg_length;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&err_msg_length));
+
+ uint8_t* err_msg_c = new uint8_t[err_msg_length];
+ auto clean = gtl::MakeCleanup([err_msg_c] { delete[] err_msg_c; });
+ TF_RETURN_IF_ERROR(client_->ReadData(err_msg_c, err_msg_length));
+ string err_msg(reinterpret_cast<char*>(err_msg_c), err_msg_length);
+
+ return errors::Unknown("Load Next Page Error [status=", status,
+ ", message=", err_msg, "]");
+ }
+ return errors::Unknown("Load Next Page Error [status=", status, "]");
+ }
+
+ int32_t row_cnt;
+ TF_RETURN_IF_ERROR(client_->ReadInt(&row_cnt));
+
+ int32_t page_size = res_len - kLoadNextPageResHeaderLength;
+
+ return ReceivePage(page_size);
+}
+
+Status IgniteDatasetIterator::ReceivePage(int32_t page_size) {
+ remainder_ = page_size;
+ page_ = std::unique_ptr<uint8_t>(new uint8_t[remainder_]);
+ ptr_ = page_.get();
+
+ uint64 start = Env::Default()->NowMicros();
+ TF_RETURN_IF_ERROR(client_->ReadData(ptr_, remainder_));
+ uint64 stop = Env::Default()->NowMicros();
+
+ double size_in_mb = 1.0 * remainder_ / 1024 / 1024;
+ double time_in_s = 1.0 * (stop - start) / 1000 / 1000;
+ LOG(INFO) << "Page size " << size_in_mb << " Mb, time " << time_in_s * 1000
+ << " ms download speed " << size_in_mb / time_in_s << " Mb/sec";
+
+ uint8_t last_page_b;
+ TF_RETURN_IF_ERROR(client_->ReadByte(&last_page_b));
+
+ last_page_ = !last_page_b;
+
+ return Status::OK();
+}
+
+Status IgniteDatasetIterator::CheckTypes(const std::vector<int32_t>& types) {
+ if (schema_.size() != types.size())
+ return errors::Unknown("Object has unexpected schema");
+
+ for (size_t i = 0; i < schema_.size(); i++) {
+ if (schema_[i] != types[permutation_[i]])
+ return errors::Unknown("Object has unexpected schema");
+ }
+
+ return Status::OK();
+}
+
+int32_t IgniteDatasetIterator::JavaHashCode(string str) const {
+ int32_t h = 0;
+ for (char& c : str) {
+ h = 31 * h + c;
+ }
+ return h;
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
new file mode 100644
index 0000000000..c499e2c9cc
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_iterator.h
@@ -0,0 +1,99 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/platform/mutex.h"
+
+namespace tensorflow {
+
+class IgniteDatasetIterator : public DatasetIterator<IgniteDataset> {
+ public:
+ IgniteDatasetIterator(const Params& params, string host, int32 port,
+ string cache_name, bool local, int32 part,
+ int32 page_size, string username, string password,
+ string certfile, string keyfile, string cert_password,
+ std::vector<int32> schema,
+ std::vector<int32> permutation);
+ ~IgniteDatasetIterator();
+ Status GetNextInternal(IteratorContext* ctx, std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override;
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override;
+ Status RestoreInternal(IteratorContext* ctx,
+ IteratorStateReader* reader) override;
+
+ private:
+ Status GetNextInternalWithValidState(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence);
+
+ Status EstablishConnection();
+ Status CloseConnection();
+ Status Handshake();
+ Status ScanQuery();
+ Status LoadNextPage();
+ Status ReceivePage(int32_t page_size);
+ Status CheckTypes(const std::vector<int32_t>& types);
+ int32_t JavaHashCode(string str) const;
+
+ std::unique_ptr<Client> client_;
+ BinaryObjectParser parser_;
+
+ const string cache_name_;
+ const bool local_;
+ const int32 part_;
+ const int32 page_size_;
+ const string username_;
+ const string password_;
+ const std::vector<int32> schema_;
+ const std::vector<int32> permutation_;
+
+ int32_t remainder_;
+ int64_t cursor_id_;
+ bool last_page_;
+
+ bool valid_state_;
+
+ mutex mutex_;
+
+ std::unique_ptr<uint8_t> page_;
+ uint8_t* ptr_;
+};
+
+constexpr uint8_t kNullVal = 101;
+constexpr uint8_t kStringVal = 9;
+constexpr uint8_t kProtocolMajorVersion = 1;
+constexpr uint8_t kProtocolMinorVersion = 1;
+constexpr uint8_t kProtocolPatchVersion = 0;
+constexpr int16_t kScanQueryOpcode = 2000;
+constexpr int16_t kLoadNextPageOpcode = 2001;
+constexpr int16_t kCloseConnectionOpcode = 0;
+constexpr int32_t kScanQueryReqLength = 25;
+constexpr int32_t kScanQueryResHeaderLength = 25;
+constexpr int32_t kLoadNextPageReqLength = 18;
+constexpr int32_t kLoadNextPageResHeaderLength = 17;
+constexpr int32_t kCloseConnectionReqLength = 18;
+constexpr int32_t kHandshakeReqDefaultLength = 8;
+constexpr int32_t kMinResLength = 12;
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_DATASET_ITERATOR_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
new file mode 100644
index 0000000000..f75b1c5ff5
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_dataset_ops.cc
@@ -0,0 +1,198 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <stdlib.h>
+
+#include "tensorflow/contrib/ignite/kernels/ignite_binary_object_parser.h"
+#include "tensorflow/contrib/ignite/kernels/ignite_dataset.h"
+#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/lib/strings/numbers.h"
+
+namespace tensorflow {
+namespace {
+
+Status SchemaToTypes(const std::vector<int32>& schema, DataTypeVector* dtypes) {
+ for (auto e : schema) {
+ if (e == BYTE || e == BYTE_ARR) {
+ dtypes->push_back(DT_UINT8);
+ } else if (e == SHORT || e == SHORT_ARR) {
+ dtypes->push_back(DT_INT16);
+ } else if (e == INT || e == INT_ARR) {
+ dtypes->push_back(DT_INT32);
+ } else if (e == LONG || e == LONG_ARR) {
+ dtypes->push_back(DT_INT64);
+ } else if (e == FLOAT || e == FLOAT_ARR) {
+ dtypes->push_back(DT_FLOAT);
+ } else if (e == DOUBLE || e == DOUBLE_ARR) {
+ dtypes->push_back(DT_DOUBLE);
+ } else if (e == USHORT || e == USHORT_ARR) {
+ dtypes->push_back(DT_UINT8);
+ } else if (e == BOOL || e == BOOL_ARR) {
+ dtypes->push_back(DT_BOOL);
+ } else if (e == STRING || e == STRING_ARR) {
+ dtypes->push_back(DT_STRING);
+ } else {
+ return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+Status SchemaToShapes(const std::vector<int32>& schema,
+ std::vector<PartialTensorShape>* shapes) {
+ for (auto e : schema) {
+ if (e >= 1 && e < 10) {
+ shapes->push_back(PartialTensorShape({}));
+ } else if (e >= 12 && e < 21) {
+ shapes->push_back(PartialTensorShape({-1}));
+ } else {
+ return errors::Unknown("Unexpected type in schema [type_id=", e, "]");
+ }
+ }
+
+ return Status::OK();
+}
+
+class IgniteDatasetOp : public DatasetOpKernel {
+ public:
+ using DatasetOpKernel::DatasetOpKernel;
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ string cache_name = "";
+ string host = "";
+ int32 port = -1;
+ bool local = false;
+ int32 part = -1;
+ int32 page_size = -1;
+ string username = "";
+ string password = "";
+ string certfile = "";
+ string keyfile = "";
+ string cert_password = "";
+
+ const char* env_cache_name = std::getenv("IGNITE_DATASET_CACHE_NAME");
+ const char* env_host = std::getenv("IGNITE_DATASET_HOST");
+ const char* env_port = std::getenv("IGNITE_DATASET_PORT");
+ const char* env_local = std::getenv("IGNITE_DATASET_LOCAL");
+ const char* env_part = std::getenv("IGNITE_DATASET_PART");
+ const char* env_page_size = std::getenv("IGNITE_DATASET_PAGE_SIZE");
+ const char* env_username = std::getenv("IGNITE_DATASET_USERNAME");
+ const char* env_password = std::getenv("IGNITE_DATASET_PASSWORD");
+ const char* env_certfile = std::getenv("IGNITE_DATASET_CERTFILE");
+ const char* env_keyfile = std::getenv("IGNITE_DATASET_KEYFILE");
+ const char* env_cert_password = std::getenv("IGNITE_DATASET_CERT_PASSWORD");
+
+ if (env_cache_name) {
+ cache_name = string(env_cache_name);
+ } else {
+ OP_REQUIRES_OK(
+ ctx, ParseScalarArgument<string>(ctx, "cache_name", &cache_name));
+ }
+
+ if (env_host) {
+ host = string(env_host);
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<string>(ctx, "host", &host));
+ }
+
+ if (env_port) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_port, &port),
+ errors::InvalidArgument("IGNITE_DATASET_PORT environment "
+ "variable is not a valid integer: ",
+ env_port));
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "port", &port));
+ }
+
+ if (env_local) {
+ local = true;
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<bool>(ctx, "local", &local));
+ }
+
+ if (env_part) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_part, &part),
+ errors::InvalidArgument("IGNITE_DATASET_PART environment "
+ "variable is not a valid integer: ",
+ env_part));
+ } else {
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int32>(ctx, "part", &part));
+ }
+
+ if (env_page_size) {
+ OP_REQUIRES(ctx, strings::safe_strto32(env_page_size, &page_size),
+ errors::InvalidArgument("IGNITE_DATASET_PAGE_SIZE "
+ "environment variable is not a valid "
+ "integer: ",
+ env_page_size));
+ } else {
+ OP_REQUIRES_OK(ctx,
+ ParseScalarArgument<int32>(ctx, "page_size", &page_size));
+ }
+
+ if (env_username) username = string(env_username);
+
+ if (env_password) password = string(env_password);
+
+ if (env_certfile) certfile = string(env_certfile);
+
+ if (env_keyfile) keyfile = string(env_keyfile);
+
+ if (env_cert_password) cert_password = string(env_cert_password);
+
+ const Tensor* schema_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("schema", &schema_tensor));
+ OP_REQUIRES(ctx, schema_tensor->dims() == 1,
+ errors::InvalidArgument("`schema` must be a vector."));
+
+ std::vector<int32> schema;
+ schema.reserve(schema_tensor->NumElements());
+ for (int i = 0; i < schema_tensor->NumElements(); i++) {
+ schema.push_back(schema_tensor->flat<int32>()(i));
+ }
+
+ const Tensor* permutation_tensor;
+ OP_REQUIRES_OK(ctx, ctx->input("permutation", &permutation_tensor));
+ OP_REQUIRES(ctx, permutation_tensor->dims() == 1,
+ errors::InvalidArgument("`permutation` must be a vector."));
+
+ std::vector<int32> permutation;
+ permutation.resize(permutation_tensor->NumElements());
+ for (int i = 0; i < permutation_tensor->NumElements(); i++) {
+ // Inversed permutation.
+ permutation[permutation_tensor->flat<int32>()(i)] = i;
+ }
+
+ DataTypeVector dtypes;
+ std::vector<PartialTensorShape> shapes;
+
+ OP_REQUIRES_OK(ctx, SchemaToTypes(schema, &dtypes));
+ OP_REQUIRES_OK(ctx, SchemaToShapes(schema, &shapes));
+
+ *output = new IgniteDataset(
+ ctx, std::move(cache_name), std::move(host), port, local, part,
+ page_size, std::move(username), std::move(password),
+ std::move(certfile), std::move(keyfile), std::move(cert_password),
+ std::move(schema), std::move(permutation), std::move(dtypes),
+ std::move(shapes));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("IgniteDataset").Device(DEVICE_CPU),
+ IgniteDatasetOp);
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client.h b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
new file mode 100644
index 0000000000..75424c19ee
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client.h
@@ -0,0 +1,43 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+namespace tensorflow {
+
+class PlainClient : public Client {
+ public:
+ PlainClient(string host, int port, bool big_endian);
+ ~PlainClient();
+
+ Status Connect() override;
+ Status Disconnect() override;
+ bool IsConnected() override;
+ int GetSocketDescriptor() override;
+ Status ReadData(uint8_t* buf, const int32_t length) override;
+ Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+ const string host_;
+ const int port_;
+ int sock_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_PLAIN_CLIENT_H_
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
new file mode 100644
index 0000000000..cf672942c6
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_unix.cc
@@ -0,0 +1,123 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#include <arpa/inet.h>
+#include <netdb.h>
+#include <sys/socket.h>
+#include <unistd.h>
+
+#include <iostream>
+#include <map>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+PlainClient::PlainClient(string host, int port, bool big_endian)
+ : Client(big_endian), host_(std::move(host)), port_(port), sock_(-1) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+Status PlainClient::Connect() {
+ if (sock_ == -1) {
+ sock_ = socket(AF_INET, SOCK_STREAM, 0);
+ if (sock_ == -1) return errors::Internal("Failed to create socket");
+ }
+
+ sockaddr_in server;
+
+ server.sin_addr.s_addr = inet_addr(host_.c_str());
+ if (server.sin_addr.s_addr == -1) {
+ hostent* he;
+ in_addr** addr_list;
+
+ if ((he = gethostbyname(host_.c_str())) == NULL)
+ return errors::Internal("Failed to resolve hostname \"", host_, "\"");
+
+ addr_list = (in_addr**)he->h_addr_list;
+ if (addr_list[0] != NULL) server.sin_addr = *addr_list[0];
+ }
+
+ server.sin_family = AF_INET;
+ server.sin_port = htons(port_);
+
+ if (connect(sock_, (sockaddr*)&server, sizeof(server)) < 0)
+ return errors::Internal("Failed to connect to \"", host_, ":", port_, "\"");
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+ return Status::OK();
+}
+
+Status PlainClient::Disconnect() {
+ int close_res = close(sock_);
+ sock_ = -1;
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" is closed";
+
+ return close_res == 0
+ ? Status::OK()
+ : errors::Internal("Failed to correctly close connection");
+}
+
+bool PlainClient::IsConnected() { return sock_ != -1; }
+
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+Status PlainClient::ReadData(uint8_t* buf, const int32_t length) {
+ int received = 0;
+
+ while (received < length) {
+ int res = recv(sock_, buf, length - received, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from socket: ", res,
+ ", ", string(strerror(errno)));
+
+ if (res == 0) return errors::Internal("Server closed connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status PlainClient::WriteData(const uint8_t* buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock_, buf, length - sent, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ", res,
+ ", ", string(strerror(errno)));
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
new file mode 100644
index 0000000000..dad5aace5f
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_plain_client_windows.cc
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_plain_client.h"
+
+#define WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#include <winsock2.h>
+#include <ws2tcpip.h>
+
+#pragma comment(lib, "Ws2_32.lib")
+#pragma comment(lib, "Mswsock.lib")
+#pragma comment(lib, "AdvApi32.lib")
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/gtl/cleanup.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+PlainClient::PlainClient(string host, int port, bool big_endian)
+ : Client(big_endian),
+ host_(std::move(host)),
+ port_(port),
+ sock_(INVALID_SOCKET) {}
+
+PlainClient::~PlainClient() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+}
+
+Status PlainClient::Connect() {
+ WSADATA wsaData;
+ addrinfo *result = NULL, *ptr = NULL, hints;
+
+ int res = WSAStartup(MAKEWORD(2, 2), &wsaData);
+ if (res != 0) return errors::Internal("WSAStartup failed with error: ", res);
+
+ ZeroMemory(&hints, sizeof(hints));
+ hints.ai_family = AF_UNSPEC;
+ hints.ai_socktype = SOCK_STREAM;
+ hints.ai_protocol = IPPROTO_TCP;
+
+ res = getaddrinfo(host_.c_str(), std::to_string(port_).c_str(), &hints,
+ &result);
+ if (res != 0) return errors::Internal("Getaddrinfo failed with error: ", res);
+
+ auto clean = gtl::MakeCleanup([result] { freeaddrinfo(result); });
+
+ for (ptr = result; ptr != NULL; ptr = ptr->ai_next) {
+ sock_ = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);
+ if (sock_ == INVALID_SOCKET) {
+ WSACleanup();
+ return errors::Internal("Socket failed with error: ", WSAGetLastError());
+ }
+
+ res = connect(sock_, ptr->ai_addr, (int)ptr->ai_addrlen);
+ if (res == SOCKET_ERROR) {
+ closesocket(sock_);
+ sock_ = INVALID_SOCKET;
+ continue;
+ }
+
+ break;
+ }
+
+ if (sock_ == INVALID_SOCKET) {
+ WSACleanup();
+ return errors::Internal("Unable to connect to server");
+ }
+
+ LOG(INFO) << "Connection to \"" << host_ << ":" << port_ << "\" established";
+
+ return Status::OK();
+}
+
+Status PlainClient::Disconnect() {
+ int res = shutdown(sock_, SD_SEND);
+ closesocket(sock_);
+ WSACleanup();
+
+ if (res == SOCKET_ERROR)
+ return errors::Internal("Shutdown failed with error: ", WSAGetLastError());
+ else
+ return Status::OK();
+}
+
+bool PlainClient::IsConnected() { return sock_ != INVALID_SOCKET; }
+
+int PlainClient::GetSocketDescriptor() { return sock_; }
+
+Status PlainClient::ReadData(uint8_t *buf, const int32_t length) {
+ int received = 0;
+
+ while (received < length) {
+ int res = recv(sock_, (char *)buf, length - received, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from socket: ",
+ res);
+
+ if (res == 0) return errors::Internal("Server closed connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status PlainClient::WriteData(const uint8_t *buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = send(sock_, (char *)buf, length - sent, 0);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ",
+ res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
new file mode 100644
index 0000000000..ceb479b084
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.cc
@@ -0,0 +1,151 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h"
+
+#include <openssl/err.h>
+#include <openssl/ssl.h>
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+static int PasswordCb(char *buf, int size, int rwflag, void *password) {
+ strncpy(buf, (char *)(password), size);
+ buf[size - 1] = '\0';
+ return (strlen(buf));
+}
+
+SslWrapper::SslWrapper(std::shared_ptr<Client> client, string certfile,
+ string keyfile, string cert_password, bool big_endian)
+ : Client(big_endian),
+ client_(client),
+ certfile_(std::move(certfile)),
+ keyfile_(std::move(keyfile)),
+ cert_password_(std::move(cert_password)),
+ ctx_(nullptr),
+ ssl_(nullptr) {}
+
+SslWrapper::~SslWrapper() {
+ if (IsConnected()) {
+ Status status = Disconnect();
+ if (!status.ok()) LOG(WARNING) << status.ToString();
+ }
+
+ if (ctx_ != nullptr) {
+ SSL_CTX_free(ctx_);
+ ctx_ = nullptr;
+ }
+
+ if (ssl_ != nullptr) {
+ SSL_free(ssl_);
+ ssl_ = nullptr;
+ }
+}
+
+Status SslWrapper::InitSslContext() {
+ OpenSSL_add_all_algorithms();
+ SSL_load_error_strings();
+
+ ctx_ = SSL_CTX_new(SSLv23_method());
+ if (ctx_ == NULL) return errors::Internal("Couldn't create SSL context");
+
+ SSL_CTX_set_default_passwd_cb(ctx_, PasswordCb);
+ SSL_CTX_set_default_passwd_cb_userdata(ctx_, (void *)cert_password_.c_str());
+
+ if (SSL_CTX_use_certificate_chain_file(ctx_, certfile_.c_str()) != 1)
+ return errors::Internal("Couldn't load cetificate chain (file '", certfile_,
+ "')");
+
+ string private_key_file = keyfile_.empty() ? certfile_ : keyfile_;
+ if (SSL_CTX_use_PrivateKey_file(ctx_, private_key_file.c_str(),
+ SSL_FILETYPE_PEM) != 1)
+ return errors::Internal("Couldn't load private key (file '",
+ private_key_file, "')");
+
+ return Status::OK();
+}
+
+Status SslWrapper::Connect() {
+ if (ctx_ == NULL) {
+ TF_RETURN_IF_ERROR(InitSslContext());
+ }
+
+ ssl_ = SSL_new(ctx_);
+ if (ssl_ == NULL)
+ return errors::Internal("Failed to establish SSL connection");
+
+ TF_RETURN_IF_ERROR(client_->Connect());
+
+ SSL_set_fd(ssl_, client_->GetSocketDescriptor());
+ if (SSL_connect(ssl_) != 1)
+ return errors::Internal("Failed to establish SSL connection");
+
+ LOG(INFO) << "SSL connection established";
+
+ return Status::OK();
+}
+
+Status SslWrapper::Disconnect() {
+ SSL_free(ssl_);
+ ssl_ = nullptr;
+
+ LOG(INFO) << "SSL connection closed";
+
+ return client_->Disconnect();
+}
+
+bool SslWrapper::IsConnected() { return client_->IsConnected(); }
+
+int SslWrapper::GetSocketDescriptor() { return client_->GetSocketDescriptor(); }
+
+Status SslWrapper::ReadData(uint8_t *buf, const int32_t length) {
+ int received = 0;
+
+ while (received < length) {
+ int res = SSL_read(ssl_, buf, length - received);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while reading from SSL socket: ",
+ res);
+
+ if (res == 0) return errors::Internal("Server closed SSL connection");
+
+ received += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+Status SslWrapper::WriteData(const uint8_t *buf, const int32_t length) {
+ int sent = 0;
+
+ while (sent < length) {
+ int res = SSL_write(ssl_, buf, length - sent);
+
+ if (res < 0)
+ return errors::Internal("Error occurred while writing into socket: ",
+ res);
+
+ sent += res;
+ buf += res;
+ }
+
+ return Status::OK();
+}
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
new file mode 100644
index 0000000000..0406644bba
--- /dev/null
+++ b/tensorflow/contrib/ignite/kernels/ignite_ssl_wrapper.h
@@ -0,0 +1,51 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+#define TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
+
+#include "tensorflow/contrib/ignite/kernels/ignite_client.h"
+
+#include <openssl/ssl.h>
+
+namespace tensorflow {
+
+class SslWrapper : public Client {
+ public:
+ SslWrapper(std::shared_ptr<Client> client, string certfile, string keyfile,
+ string cert_password, bool big_endian);
+ ~SslWrapper();
+
+ Status Connect() override;
+ Status Disconnect() override;
+ bool IsConnected() override;
+ int GetSocketDescriptor() override;
+ Status ReadData(uint8_t* buf, const int32_t length) override;
+ Status WriteData(const uint8_t* buf, const int32_t length) override;
+
+ private:
+ Status InitSslContext();
+
+ std::shared_ptr<Client> client_;
+ string certfile_;
+ string keyfile_;
+ string cert_password_;
+ SSL_CTX* ctx_;
+ SSL* ssl_;
+};
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CONTRIB_IGNITE_KERNELS_IGNITE_SSL_WRAPPER_H_
diff --git a/tensorflow/contrib/ignite/ops/dataset_ops.cc b/tensorflow/contrib/ignite/ops/dataset_ops.cc
new file mode 100644
index 0000000000..3d6fbe00e6
--- /dev/null
+++ b/tensorflow/contrib/ignite/ops/dataset_ops.cc
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/shape_inference.h"
+
+namespace tensorflow {
+
+REGISTER_OP("IgniteDataset")
+ .Input("cache_name: string")
+ .Input("host: string")
+ .Input("port: int32")
+ .Input("local: bool")
+ .Input("part: int32")
+ .Input("page_size: int32")
+ .Input("schema: int32")
+ .Input("permutation: int32")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+IgniteDataset that allows to get data from Apache Ignite.
+
+Apache Ignite is a memory-centric distributed database, caching, and processing
+platform for transactional, analytical, and streaming workloads, delivering
+in-memory speeds at petabyte scale. This contrib package contains an
+integration between Apache Ignite and TensorFlow. The integration is based on
+tf.data from TensorFlow side and Binary Client Protocol from Apache Ignite side.
+It allows to use Apache Ignite as a datasource for neural network training,
+inference and all other computations supported by TensorFlow. Ignite Dataset
+is based on Apache Ignite Binary Client Protocol.
+
+cache_name: Ignite Cache Name.
+host: Ignite Thin Client Host.
+port: Ignite Thin Client Port.
+local: Local flag that defines that data should be fetched from local host only.
+part: Partition data should be fetched from.
+page_size: Page size for Ignite Thin Client.
+schema: Internal structure that defines schema of cache objects.
+permutation: Internal structure that defines permutation of cache objects.
+)doc");
+
+} // namespace tensorflow
diff --git a/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
new file mode 100644
index 0000000000..cfe59b6b23
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/ops/ignite_dataset_ops.py
@@ -0,0 +1,772 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignite Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import socket
+import ssl
+import struct
+
+from tensorflow.contrib.ignite.python.ops import gen_dataset_ops
+from tensorflow.contrib.ignite.python.ops import ignite_op_loader # pylint: disable=unused-import
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+
+
+class Readable(object):
+ """Readable abstract class that exposes methods to do reading-related
+
+ operations.
+ """
+
+ @abc.abstractmethod
+ def __init__(self):
+ pass
+
+ def read_byte(self):
+ """Reads and returnes byte."""
+ return self._read("b", 1)
+
+ def read_short(self):
+ """Reads and returns short (2 bytes, little-endian)."""
+ return self._read("h", 2)
+
+ def read_int(self):
+ """Reads and returns int (4 bytes, little-endian)."""
+ return self._read("i", 4)
+
+ def read_long(self):
+ """Reads and returns long (8 bytes, little-endian)."""
+ return self._read("q", 8)
+
+ def skip(self, length):
+ """Skips the specified number of bytes."""
+ self.read_data(length)
+
+ @abc.abstractmethod
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ return None
+
+ def _read(self, data_type, length):
+ """Reads, unpacks and returns specified type (little-endian)."""
+ data_buffer = self.read_data(length)
+ return struct.unpack("<" + data_type, data_buffer)[0]
+
+
+class DataBuffer(Readable):
+ """DataBuffer class that exposes methods to read data from a byte buffer."""
+
+ def __init__(self, data_buffer):
+ """Constructs a new instance based on the specified byte buffer.
+
+ Args:
+ data_buffer: Buffer to be read.
+ """
+ Readable.__init__(self)
+ self.buffer = data_buffer
+ self.ptr = 0
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = self.buffer[self.ptr:][:length]
+ self.ptr += length
+ return data_buffer
+
+
+class TcpClient(Readable):
+ """TcpClient class that exposes methods to read data from a socket."""
+
+ def __init__(self, host, port, certfile=None, keyfile=None, password=None):
+ """Constructs a new instance based on the specified host and port.
+
+ Args:
+ host: Host to be connected.
+ port: Port to be connected.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ password: Password to be used if the private key is encrypted and a
+ password is necessary.
+
+ Raises:
+ ValueError: If the wrong combination of arguments is provided.
+ """
+ Readable.__init__(self)
+ self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
+
+ if certfile is not None:
+ context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
+ context.load_cert_chain(certfile, keyfile, password)
+ self.sock = context.wrap_socket(self.sock)
+ else:
+ if keyfile is not None:
+ raise ValueError("SSL is disabled, keyfile must not be specified "
+ "(to enable SSL specify certfile)")
+ if password is not None:
+ raise ValueError("SSL is disabled, password must not be specified "
+ "(to enable SSL specify certfile)")
+
+ self.host = host
+ self.port = port
+
+ def __enter__(self):
+ """Connects to host and port specified in the constructor."""
+ self.sock.connect((self.host, self.port))
+ return self
+
+ def __exit__(self, t, v, traceback):
+ """Disconnects the socket."""
+ self.sock.close()
+
+ def write_byte(self, v):
+ """Writes the specified byte."""
+ self._write(v, "b")
+
+ def write_short(self, v):
+ """Writes the specified short (2 bytes, little-endian)."""
+ self._write(v, "h")
+
+ def write_int(self, v):
+ """Writes the specified short (4 bytes, little-endian)."""
+ self._write(v, "i")
+
+ def write_long(self, v):
+ """Writes the specified int (8 bytes, little-endian)."""
+ self._write(v, "q")
+
+ def write_string(self, v):
+ """Writes the specified string."""
+ self.sock.sendall(v.encode("UTF-8"))
+
+ def read_data(self, length):
+ """Reads the specified number of bytes and returns them as a buffer."""
+ data_buffer = None
+ rem = length
+ while rem > 0:
+ buf = self.sock.recv(rem)
+ rem = rem - len(buf)
+ if data_buffer is None:
+ data_buffer = buf
+ else:
+ data_buffer += buf
+ return data_buffer
+
+ def _write(self, value, data_type):
+ """Packs and writes data using the specified type (little-endian)."""
+ data_buffer = struct.pack("<" + data_type, value)
+ self.sock.sendall(data_buffer)
+
+
+class BinaryType(object):
+ """BinaryType class that encapsulated type id, type name and fields."""
+
+ def __init__(self, type_id, type_name, fields):
+ """Constructs a new instance of BinaryType."""
+ self.type_id = type_id
+ self.type_name = type_name
+ self.fields = fields
+
+
+class BinaryField(object):
+ """BinaryField class that encapsulated field name, type id and field id."""
+
+ def __init__(self, field_name, type_id, field_id):
+ """Constructs a new instance of BinaryField."""
+ self.field_name = field_name
+ self.type_id = type_id
+ self.field_id = field_id
+
+
+# Binary types defined in Apache Ignite Thin client and supported by
+# TensorFlow on Apache Ignite, see
+# https://apacheignite.readme.io/v2.6/docs/binary-client-protocol.
+# True means that type is a vector, False means type is scalar.
+types = {
+ 1: (dtypes.uint8, False),
+ 2: (dtypes.int16, False),
+ 3: (dtypes.int32, False),
+ 4: (dtypes.int64, False),
+ 5: (dtypes.float32, False),
+ 6: (dtypes.float64, False),
+ 7: (dtypes.uint16, False),
+ 8: (dtypes.bool, False),
+ 9: (dtypes.string, False),
+ 12: (dtypes.uint8, True),
+ 13: (dtypes.int16, True),
+ 14: (dtypes.int32, True),
+ 15: (dtypes.int64, True),
+ 16: (dtypes.float32, True),
+ 17: (dtypes.float64, True),
+ 18: (dtypes.uint16, True),
+ 19: (dtypes.bool, True),
+ 20: (dtypes.string, True)
+}
+
+
+class TypeTreeNode(object):
+ """TypeTreeNode class exposes methods to format object tree structure
+
+ data.
+ """
+
+ def __init__(self, name, type_id, fields=None, permutation=None):
+ """Constructs a new instance of TypeTreeNode.
+
+ Args:
+ name: Name of the object tree node.
+ type_id: Type id of the object tree node.
+ fields: List of fields (children of the object tree node).
+ permutation: Permutation that should be applied to order object children.
+ """
+ self.name = name
+ self.type_id = type_id
+ self.fields = fields
+ self.permutation = permutation
+
+ def to_output_classes(self):
+ """Formats the tree object as required by `Dataset.output_classes`."""
+ if self.fields is None:
+ return ops.Tensor
+ output_classes = {}
+ for field in self.fields:
+ output_classes[field.name] = field.to_output_classes()
+ return output_classes
+
+ def to_output_shapes(self):
+ """Formats the tree object as required by `Dataset.output_shapes`."""
+ if self.fields is None:
+ if self.type_id in types:
+ object_type = types[self.type_id]
+ is_array = object_type[1]
+ if is_array:
+ return tensor_shape.TensorShape([None])
+ return tensor_shape.TensorShape([])
+ raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+ output_shapes = {}
+ for field in self.fields:
+ output_shapes[field.name] = field.to_output_shapes()
+ return output_shapes
+
+ def to_output_types(self):
+ """Formats the tree object as required by `Dataset.output_types`."""
+ if self.fields is None:
+ if self.type_id in types:
+ object_type = types[self.type_id]
+ return object_type[0]
+ raise ValueError("Unsupported type [type_id=%d]" % self.type_id)
+ else:
+ output_types = {}
+ for field in self.fields:
+ output_types[field.name] = field.to_output_types()
+ return output_types
+
+ def to_flat(self):
+ """Returns a list of node types."""
+ return self.to_flat_rec([])
+
+ def to_permutation(self):
+ """Returns a permutation that should be applied to order object leaves."""
+ correct_order_dict = {}
+ self.traversal_rec(correct_order_dict, 0)
+ object_order = []
+ self.traversal_permutation_rec(object_order)
+ return [correct_order_dict[o] for o in object_order]
+
+ def to_flat_rec(self, flat):
+ """Formats a list of leaf node types in pre-order."""
+ if self.fields is None:
+ flat.append(self.type_id)
+ else:
+ for field in self.fields:
+ field.to_flat_rec(flat)
+ return flat
+
+ def traversal_permutation_rec(self, permutation):
+ """Collects nodes in accordance with permutation."""
+ if self.fields is None:
+ permutation.append(self)
+ else:
+ for idx in self.permutation:
+ field = self.fields[idx]
+ field.traversal_permutation_rec(permutation)
+
+ def traversal_rec(self, d, i):
+ """Collects nodes in pre-order traversal."""
+ if self.fields is None:
+ d[self] = i
+ i += 1
+ else:
+ for field in self.fields:
+ i = field.traversal_rec(d, i)
+ return i
+
+
+class IgniteClient(TcpClient):
+ """IgniteClient enables working with Apache Ignite using a thin client.
+
+ This client works with assumption that all object in the cache
+ have the same structure (homogeneous objects) and the cache contains at
+ least one object.
+ """
+
+ def __init__(self,
+ host,
+ port,
+ username=None,
+ password=None,
+ certfile=None,
+ keyfile=None,
+ cert_password=None):
+ """Constructs a new instance of IgniteClient.
+
+ Args:
+ host: Apache Ignite Thin client host to be connected.
+ port: Apache Ignite Thin client port to be connected.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ TcpClient.__init__(self, host, port, certfile, keyfile, cert_password)
+ self.username = username
+ self.password = password
+
+ def handshake(self):
+ """Makes a handshake after connect and before any other calls."""
+ msg_len = 8
+
+ if self.username is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.username)
+
+ if self.password is None:
+ msg_len += 1
+ else:
+ msg_len += 5 + len(self.password)
+
+ self.write_int(msg_len) # Message length
+ self.write_byte(1) # Handshake operation
+ self.write_short(1) # Version (1.1.0)
+ self.write_short(1)
+ self.write_short(0)
+ self.write_byte(2) # Thin client
+
+ if self.username is None: # Username
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.username))
+ self.write_string(self.username)
+
+ if self.password is None: # Password
+ self.write_byte(101)
+ else:
+ self.write_byte(9)
+ self.write_int(len(self.password))
+ self.write_string(self.password)
+
+ self.read_int() # Result length
+ res = self.read_byte()
+
+ if res != 1:
+ serv_ver_major = self.read_short()
+ serv_ver_minor = self.read_short()
+ serv_ver_patch = self.read_short()
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError(
+ "Handshake Error [result=%d, version=%d.%d.%d]" %
+ (res, serv_ver_major, serv_ver_minor, serv_ver_patch))
+ else:
+ raise RuntimeError(
+ "Handshake Error [result=%d, version=%d.%d.%d, message='%s']" %
+ (res, serv_ver_major, serv_ver_minor, serv_ver_patch, err_msg))
+
+ def get_cache_type(self, cache_name):
+ """Collects type information about objects stored in the specified cache."""
+ cache_name_hash = self._java_hash_code(cache_name)
+ self.write_int(25) # Message length
+ self.write_short(2000) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(cache_name_hash) # Cache name
+ self.write_byte(0) # Flags
+ self.write_byte(101) # Filter (NULL)
+ self.write_int(1) # Cursor page size
+ self.write_int(-1) # Partition to query
+ self.write_byte(0) # Local flag
+
+ result_length = self.read_int()
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError("Scan Query Error [status=%s]" % status)
+ else:
+ raise RuntimeError(
+ "Scan Query Error [status=%s, message='%s']" % (status, err_msg))
+
+ self.read_long() # Cursor id
+ row_count = self.read_int()
+
+ if row_count == 0:
+ raise RuntimeError("Scan Query returned empty result, so it's "
+ "impossible to derive the cache type")
+
+ payload = DataBuffer(self.read_data(result_length - 25))
+
+ self.read_byte() # Next page
+
+ res = TypeTreeNode("root", 0, [
+ self._collect_types("key", payload),
+ self._collect_types("val", payload)
+ ], [0, 1])
+
+ return res
+
+ def _java_hash_code(self, s):
+ """Computes hash code of the specified string using Java code."""
+ h = 0
+ for c in s:
+ h = (31 * h + ord(c)) & 0xFFFFFFFF
+ return ((h + 0x80000000) & 0xFFFFFFFF) - 0x80000000
+
+ def _collect_types(self, field_name, data):
+ """Extracts type information from the specified object."""
+ type_id = data.read_byte()
+
+ # Byte scalar.
+ if type_id == 1:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short scalar.
+ if type_id == 2:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer scalar.
+ if type_id == 3:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long scalar.
+ if type_id == 4:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float scalar.
+ if type_id == 5:
+ data.skip(4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double scalar.
+ if type_id == 6:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char scalar.
+ if type_id == 7:
+ data.skip(2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool scalar.
+ if type_id == 8:
+ data.skip(1)
+ return TypeTreeNode(field_name, type_id)
+
+ # String scalar.
+ if type_id == 9:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID scalar.
+ if type_id == 10:
+ data.skip(16)
+ return TypeTreeNode(field_name, type_id)
+
+ # Date scalar.
+ if type_id == 11:
+ data.skip(8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Byte array.
+ if type_id == 12:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # Short array.
+ if type_id == 13:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Integer array.
+ if type_id == 14:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Long array.
+ if type_id == 15:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Float array.
+ if type_id == 16:
+ length = data.read_int()
+ data.skip(length * 4)
+ return TypeTreeNode(field_name, type_id)
+
+ # Double array.
+ if type_id == 17:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Char array.
+ if type_id == 18:
+ length = data.read_int()
+ data.skip(length * 2)
+ return TypeTreeNode(field_name, type_id)
+
+ # Bool array.
+ if type_id == 19:
+ length = data.read_int()
+ data.skip(length)
+ return TypeTreeNode(field_name, type_id)
+
+ # String array.
+ if type_id == 20:
+ length = data.read_int()
+ for _ in range(length):
+ header = data.read_byte()
+ if header == 9:
+ str_length = data.read_int()
+ data.skip(str_length)
+ elif header == 101:
+ pass
+ else:
+ raise RuntimeError(
+ "Unknown binary type when expected string [type_id=%d]" % header)
+ return TypeTreeNode(field_name, type_id)
+
+ # UUID array.
+ if type_id == 21:
+ length = data.read_int()
+ data.skip(length * 16) # TODO(dmitrievanthony): support NULL values.
+ return TypeTreeNode(field_name, type_id)
+
+ # Date array.
+ if type_id == 22:
+ length = data.read_int()
+ data.skip(length * 8)
+ return TypeTreeNode(field_name, type_id)
+
+ # Wrapped Binary Object.
+ if type_id == 27:
+ length = data.read_int()
+ inner_data = data.read_data(length)
+ data.read_int() # Offset
+ return self._collect_types(field_name, DataBuffer(inner_data))
+
+ # Complex Object.
+ if type_id == 103:
+ data.read_byte() # Object version
+ data.read_short() # Object flags
+ obj_type_id = data.read_int()
+ data.read_int() # Object hash code
+ obj_length = data.read_int()
+ data.read_int() # Object schema id
+ obj_schema_offset = data.read_int()
+
+ obj_type = self._get_type(obj_type_id)
+ children = []
+
+ for obj_field in obj_type.fields:
+ child = self._collect_types(obj_field.field_name, data)
+ children.append(child)
+
+ children_sorted = sorted(children, key=lambda child: child.name)
+ permutation = [children_sorted.index(child) for child in children]
+ children = children_sorted
+
+ data.skip(obj_length - obj_schema_offset)
+
+ return TypeTreeNode(field_name, type_id, children, permutation)
+
+ raise RuntimeError("Unknown binary type [type_id=%d]" % type_id)
+
+ def _get_type(self, type_id):
+ """Queries Apache Ignite information about type by type id."""
+ self.write_int(14) # Message length
+ self.write_short(3002) # Operation code
+ self.write_long(0) # Request ID
+ self.write_int(type_id) # Type ID
+
+ self.read_int() # Result length
+ self.read_long() # Request id
+ status = self.read_int()
+
+ if status != 0:
+ err_msg = self._parse_string()
+ if err_msg is None:
+ raise RuntimeError("Get Binary Type Error [status=%d, message='%s']" %
+ (status, err_msg))
+ else:
+ raise RuntimeError("Get Binary Type Error [status=%d]" % status)
+
+ binary_type_exists = self.read_byte()
+
+ if binary_type_exists == 0:
+ raise RuntimeError("Binary type not found [type_id=%d] " % type_id)
+
+ binary_type_id = self.read_int()
+ binary_type_name = self._parse_string()
+ self._parse_string() # Affinity field name
+
+ fields = []
+ for _ in range(self.read_int()):
+ field_name = self._parse_string()
+ field_type_id = self.read_int()
+ field_id = self.read_int()
+
+ field = BinaryField(field_name, field_type_id, field_id)
+ fields.append(field)
+
+ is_enum = self.read_byte()
+ if is_enum == 1:
+ raise RuntimeError("Enum fields are not supported yet")
+
+ schema_cnt = self.read_int()
+ for _ in range(schema_cnt):
+ self.read_int() # Schema id
+ field_cnt = self.read_int()
+ self.skip(field_cnt * 4)
+
+ return BinaryType(binary_type_id, binary_type_name, fields)
+
+ def _parse_string(self):
+ """Parses string."""
+ header = self.read_byte()
+ if header == 9:
+ length = self.read_int()
+ return self.read_data(length).decode("utf-8")
+ if header == 101:
+ return None
+ raise RuntimeError(
+ "Unknown binary type when expected string [type_id=%d]" % header)
+
+
+class IgniteDataset(dataset_ops.Dataset):
+ """Apache Ignite is a memory-centric distributed database, caching, and
+
+ processing platform for transactional, analytical, and streaming workloads,
+ delivering in-memory speeds at petabyte scale. This contrib package
+ contains an integration between Apache Ignite and TensorFlow. The
+ integration is based on tf.data from TensorFlow side and Binary Client
+ Protocol from Apache Ignite side. It allows to use Apache Ignite as a
+ datasource for neural network training, inference and all other
+ computations supported by TensorFlow. Ignite Dataset is based on Apache
+ Ignite Binary Client Protocol.
+ """
+
+ def __init__(self,
+ cache_name,
+ host="localhost",
+ port=10800,
+ local=False,
+ part=-1,
+ page_size=100,
+ username=None,
+ password=None,
+ certfile=None,
+ keyfile=None,
+ cert_password=None):
+ """Create a IgniteDataset.
+
+ Args:
+ cache_name: Cache name to be used as datasource.
+ host: Apache Ignite Thin Client host to be connected.
+ port: Apache Ignite Thin Client port to be connected.
+ local: Local flag that defines to query only local data.
+ part: Number of partitions to be queried.
+ page_size: Apache Ignite Thin Client page size.
+ username: Apache Ignite Thin Client authentication username.
+ password: Apache Ignite Thin Client authentication password.
+ certfile: File in PEM format containing the certificate as well as any
+ number of CA certificates needed to establish the certificate's
+ authenticity.
+ keyfile: File containing the private key (otherwise the private key will
+ be taken from certfile as well).
+ cert_password: Password to be used if the private key is encrypted and a
+ password is necessary.
+ """
+ super(IgniteDataset, self).__init__()
+
+ with IgniteClient(host, port, username, password, certfile, keyfile,
+ cert_password) as client:
+ client.handshake()
+ self.cache_type = client.get_cache_type(cache_name)
+
+ self.cache_name = ops.convert_to_tensor(
+ cache_name, dtype=dtypes.string, name="cache_name")
+ self.host = ops.convert_to_tensor(host, dtype=dtypes.string, name="host")
+ self.port = ops.convert_to_tensor(port, dtype=dtypes.int32, name="port")
+ self.local = ops.convert_to_tensor(local, dtype=dtypes.bool, name="local")
+ self.part = ops.convert_to_tensor(part, dtype=dtypes.int32, name="part")
+ self.page_size = ops.convert_to_tensor(
+ page_size, dtype=dtypes.int32, name="page_size")
+ self.schema = ops.convert_to_tensor(
+ self.cache_type.to_flat(), dtype=dtypes.int32, name="schema")
+ self.permutation = ops.convert_to_tensor(
+ self.cache_type.to_permutation(),
+ dtype=dtypes.int32,
+ name="permutation")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.ignite_dataset(self.cache_name, self.host, self.port,
+ self.local, self.part, self.page_size,
+ self.schema, self.permutation)
+
+ @property
+ def output_classes(self):
+ return self.cache_type.to_output_classes()
+
+ @property
+ def output_shapes(self):
+ return self.cache_type.to_output_shapes()
+
+ @property
+ def output_types(self):
+ return self.cache_type.to_output_types()
diff --git a/tensorflow/contrib/data/python/ops/contrib_op_loader.py b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
index 8f495a9dc9..c9af7386cf 100644
--- a/tensorflow/contrib/data/python/ops/contrib_op_loader.py
+++ b/tensorflow/contrib/ignite/python/ops/ignite_op_loader.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Python helper for loading contrib ops and kernels."""
+"""Python helper for loading Ignite ops and kernels."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
diff --git a/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
new file mode 100755
index 0000000000..f4607ce8ad
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/bin/start-plain.sh
@@ -0,0 +1,24 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+nohup apache-ignite-fabric/bin/ignite.sh /data/config/ignite-config-plain.xml &
+sleep 5 # Wait Apache Ignite to be started
+
+./apache-ignite-fabric/bin/sqlline.sh \
+-u "jdbc:ignite:thin://127.0.0.1/" \
+--run=/data/sql/init.sql
+
+tail -f nohup.out
diff --git a/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
new file mode 100644
index 0000000000..d900174a8a
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/config/ignite-config-plain.xml
@@ -0,0 +1,39 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<!--
+ Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<beans xmlns="http://www.springframework.org/schema/beans"
+ xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
+ xmlns:util="http://www.springframework.org/schema/util"
+ xsi:schemaLocation="http://www.springframework.org/schema/beans
+ http://www.springframework.org/schema/beans/spring-beans.xsd
+ http://www.springframework.org/schema/util
+ http://www.springframework.org/schema/util/spring-util.xsd">
+
+ <bean class="org.apache.ignite.configuration.IgniteConfiguration">
+ <property name="discoverySpi">
+ <bean class="org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi">
+ <property name="ipFinder">
+ <bean class="org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder">
+ <property name="addresses">
+ <list>
+ <value>127.0.0.1</value>
+ </list>
+ </property>
+ </bean>
+ </property>
+ </bean>
+ </property>
+ </bean>
+
+</beans>
diff --git a/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
new file mode 100644
index 0000000000..1856a4fba8
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/ignite_dataset_test.py
@@ -0,0 +1,118 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+# ==============================================================================
+"""Tests for IgniteDataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.contrib.ignite import IgniteDataset
+from tensorflow.python.client import session
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.platform import test
+
+
+class IgniteDatasetTest(test.TestCase):
+ """The Apache Ignite servers have to setup before the test and tear down
+
+ after the test manually. The docker engine has to be installed.
+
+ To setup Apache Ignite servers:
+ $ bash start_ignite.sh
+
+ To tear down Apache Ignite servers:
+ $ bash stop_ignite.sh
+ """
+
+ def test_ignite_dataset_with_plain_client(self):
+ """Test Ignite Dataset with plain client.
+
+ """
+ self._clear_env()
+ ds = IgniteDataset(cache_name="SQL_PUBLIC_TEST_CACHE", port=42300)
+ self._check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client(self):
+ """Test Ignite Dataset with ssl client.
+
+ """
+ self._clear_env()
+ os.environ["IGNITE_DATASET_CERTFILE"] = os.path.dirname(
+ os.path.realpath(__file__)) + "/keystore/client.pem"
+ os.environ["IGNITE_DATASET_CERT_PASSWORD"] = "123456"
+
+ ds = IgniteDataset(
+ cache_name="SQL_PUBLIC_TEST_CACHE",
+ port=42301,
+ certfile=os.environ["IGNITE_DATASET_CERTFILE"],
+ cert_password=os.environ["IGNITE_DATASET_CERT_PASSWORD"])
+ self._check_dataset(ds)
+
+ def test_ignite_dataset_with_ssl_client_and_auth(self):
+ """Test Ignite Dataset with ssl client and authentication.
+
+ """
+ self._clear_env()
+ os.environ["IGNITE_DATASET_USERNAME"] = "ignite"
+ os.environ["IGNITE_DATASET_PASSWORD"] = "ignite"
+ os.environ["IGNITE_DATASET_CERTFILE"] = os.path.dirname(
+ os.path.realpath(__file__)) + "/keystore/client.pem"
+ os.environ["IGNITE_DATASET_CERT_PASSWORD"] = "123456"
+
+ ds = IgniteDataset(
+ cache_name="SQL_PUBLIC_TEST_CACHE",
+ port=42302,
+ certfile=os.environ["IGNITE_DATASET_CERTFILE"],
+ cert_password=os.environ["IGNITE_DATASET_CERT_PASSWORD"],
+ username=os.environ["IGNITE_DATASET_USERNAME"],
+ password=os.environ["IGNITE_DATASET_PASSWORD"])
+ self._check_dataset(ds)
+
+ def _clear_env(self):
+ """Clears environment variables used by Ignite Dataset.
+
+ """
+ if "IGNITE_DATASET_USERNAME" in os.environ:
+ del os.environ["IGNITE_DATASET_USERNAME"]
+ if "IGNITE_DATASET_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_PASSWORD"]
+ if "IGNITE_DATASET_CERTFILE" in os.environ:
+ del os.environ["IGNITE_DATASET_CERTFILE"]
+ if "IGNITE_DATASET_CERT_PASSWORD" in os.environ:
+ del os.environ["IGNITE_DATASET_CERT_PASSWORD"]
+
+ def _check_dataset(self, dataset):
+ """Checks that dataset provides correct data."""
+ self.assertEqual(dtypes.int64, dataset.output_types["key"])
+ self.assertEqual(dtypes.string, dataset.output_types["val"]["NAME"])
+ self.assertEqual(dtypes.int64, dataset.output_types["val"]["VAL"])
+
+ it = dataset.make_one_shot_iterator()
+ ne = it.get_next()
+
+ with session.Session() as sess:
+ rows = [sess.run(ne), sess.run(ne), sess.run(ne)]
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(ne)
+
+ self.assertEqual({"key": 1, "val": {"NAME": b"TEST1", "VAL": 42}}, rows[0])
+ self.assertEqual({"key": 2, "val": {"NAME": b"TEST2", "VAL": 43}}, rows[1])
+ self.assertEqual({"key": 3, "val": {"NAME": b"TEST3", "VAL": 44}}, rows[2])
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/ignite/python/tests/sql/init.sql b/tensorflow/contrib/ignite/python/tests/sql/init.sql
new file mode 100644
index 0000000000..5a192aef17
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/sql/init.sql
@@ -0,0 +1,20 @@
+-- Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+--
+-- Licensed under the Apache License, Version 2.0 (the "License");
+-- you may not use this file except in compliance with the License.
+-- You may obtain a copy of the License at
+--
+-- http://www.apache.org/licenses/LICENSE-2.0
+--
+-- Unless required by applicable law or agreed to in writing, software
+-- distributed under the License is distributed on an "AS IS" BASIS,
+-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+-- See the License for the specific language governing permissions and
+-- limitations under the License.
+-- ==============================================================================
+
+CREATE TABLE TEST_CACHE (ID LONG PRIMARY KEY, NAME VARCHAR, VAL LONG);
+
+INSERT INTO TEST_CACHE VALUES (1, 'TEST1', 42);
+INSERT INTO TEST_CACHE VALUES (2, 'TEST2', 43);
+INSERT INTO TEST_CACHE VALUES (3, 'TEST3', 44);
diff --git a/tensorflow/contrib/ignite/python/tests/start_ignite.sh b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
new file mode 100755
index 0000000000..a67bd44f2f
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/start_ignite.sh
@@ -0,0 +1,22 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+IGNITE_VERSION=2.6.0
+SCRIPT_PATH="$( cd "$(dirname "$0")" ; pwd -P )"
+
+# Start Apache Ignite with plain client listener.
+docker run -itd --name ignite-plain -p 42300:10800 \
+-v ${SCRIPT_PATH}:/data apacheignite/ignite:${IGNITE_VERSION} /data/bin/start-plain.sh
diff --git a/tensorflow/contrib/ignite/python/tests/stop_ignite.sh b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
new file mode 100755
index 0000000000..8f03dbd1ed
--- /dev/null
+++ b/tensorflow/contrib/ignite/python/tests/stop_ignite.sh
@@ -0,0 +1,19 @@
+#!/usr/bin/env bash
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+docker rm -f ignite-plain
+docker rm -f ignite-ssl
+docker rm -f ignite-ssl-auth
diff --git a/tensorflow/contrib/image/kernels/image_ops.cc b/tensorflow/contrib/image/kernels/image_ops.cc
index 370a8caf6a..788bf04b28 100644
--- a/tensorflow/contrib/image/kernels/image_ops.cc
+++ b/tensorflow/contrib/image/kernels/image_ops.cc
@@ -156,6 +156,7 @@ namespace functor {
TF_CALL_uint8(DECLARE_FUNCTOR);
TF_CALL_int32(DECLARE_FUNCTOR);
TF_CALL_int64(DECLARE_FUNCTOR);
+TF_CALL_half(DECLARE_FUNCTOR);
TF_CALL_float(DECLARE_FUNCTOR);
TF_CALL_double(DECLARE_FUNCTOR);
@@ -175,6 +176,7 @@ TF_CALL_double(DECLARE_FUNCTOR);
TF_CALL_uint8(REGISTER);
TF_CALL_int32(REGISTER);
TF_CALL_int64(REGISTER);
+TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/contrib/image/kernels/image_ops.h b/tensorflow/contrib/image/kernels/image_ops.h
index 6b63eed130..7fac774d07 100644
--- a/tensorflow/contrib/image/kernels/image_ops.h
+++ b/tensorflow/contrib/image/kernels/image_ops.h
@@ -71,14 +71,7 @@ class ProjectiveGenerator {
(transform[3] * output_x + transform[4] * output_y + transform[5]) /
projection;
- // TODO(ringwalt): Add a fill value input.
-#if (defined __CUDA_ARCH__) && (CUDART_VERSION < 8000)
- // On CUDA versions previous to 8.0, only __shared__ variables
- // could be declared as static in the device code.
const T fill_value = T(0);
-#else
- static const T fill_value = T(0);
-#endif
switch (interpolation_) {
case INTERPOLATION_NEAREST:
// Switch the order of x and y again for indexing into the image.
diff --git a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
index 8743a5ff72..36b9a236a6 100644
--- a/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
+++ b/tensorflow/contrib/image/kernels/image_ops_gpu.cu.cc
@@ -32,6 +32,7 @@ typedef Eigen::GpuDevice GPUDevice;
template class FillProjectiveTransform<GPUDevice, uint8>;
template class FillProjectiveTransform<GPUDevice, int32>;
template class FillProjectiveTransform<GPUDevice, int64>;
+template class FillProjectiveTransform<GPUDevice, Eigen::half>;
template class FillProjectiveTransform<GPUDevice, float>;
template class FillProjectiveTransform<GPUDevice, double>;
diff --git a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
index 376c0751ee..4997c31a7f 100644
--- a/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
+++ b/tensorflow/contrib/image/python/kernel_tests/image_ops_test.py
@@ -272,6 +272,15 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
with self.cached_session():
self.assertAllEqual([[[[1], [0]], [[0], [1]]]], result.eval())
+ def test_transform_data_types(self):
+ for dtype in _DTYPES:
+ image = constant_op.constant([[1, 2], [3, 4]], dtype=dtype)
+ value = image_ops.transform(image, [1] * 8)
+ with self.test_session(use_gpu=True):
+ self.assertAllEqual(
+ value.eval(),
+ np.array([[4, 4], [4, 4]]).astype(dtype.as_numpy_dtype()))
+
class BipartiteMatchTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/lite/BUILD b/tensorflow/contrib/lite/BUILD
index f320b53d94..f3ebe3b245 100644
--- a/tensorflow/contrib/lite/BUILD
+++ b/tensorflow/contrib/lite/BUILD
@@ -26,6 +26,14 @@ config_setting(
},
)
+# Enables inclusion of TensorFlow kernels via the TF Lite Flex delegate.
+# WARNING: This build flag is experimental and subject to change.
+config_setting(
+ name = "with_tflite_flex",
+ define_values = {"with_tflite_flex": "true"},
+ visibility = ["//visibility:public"],
+)
+
cc_library(
name = "schema_fbs_version",
hdrs = ["version.h"],
@@ -157,6 +165,10 @@ cc_library(
"stderr_reporter.h",
],
copts = tflite_copts(),
+ defines = select({
+ ":with_tflite_flex": ["TFLITE_FLEX"],
+ "//conditions:default": [],
+ }),
linkopts = [
] + select({
"//tensorflow:android": [
@@ -180,7 +192,12 @@ cc_library(
"//tensorflow/contrib/lite/nnapi:nnapi_lib",
"//tensorflow/contrib/lite/profiling:profiler",
"//tensorflow/contrib/lite/schema:schema_fbs",
- ],
+ ] + select({
+ ":with_tflite_flex": [
+ "//tensorflow/contrib/lite/delegates/flex:delegate",
+ ],
+ "//conditions:default": [],
+ }),
)
cc_library(
diff --git a/tensorflow/contrib/lite/delegates/flex/BUILD b/tensorflow/contrib/lite/delegates/flex/BUILD
index bf5d91899c..9dd38958e5 100644
--- a/tensorflow/contrib/lite/delegates/flex/BUILD
+++ b/tensorflow/contrib/lite/delegates/flex/BUILD
@@ -20,7 +20,7 @@ cc_library(
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:framework",
@@ -60,7 +60,7 @@ cc_library(
"//tensorflow/contrib/lite:util",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib_lite",
],
"//conditions:default": [
"//tensorflow/core:lib",
@@ -178,7 +178,7 @@ cc_library(
"//tensorflow/contrib/lite:kernel_api",
] + select({
"//tensorflow:android": [
- "//tensorflow/core:android_tensorflow_lib_lite_no_runtime",
+ "//tensorflow/core:android_tensorflow_lib",
],
"//conditions:default": [
"//tensorflow/core:lib",
diff --git a/tensorflow/contrib/lite/examples/android/BUILD b/tensorflow/contrib/lite/examples/android/BUILD
index 4d2437e7d3..d180cb4785 100644
--- a/tensorflow/contrib/lite/examples/android/BUILD
+++ b/tensorflow/contrib/lite/examples/android/BUILD
@@ -28,6 +28,7 @@ android_binary(
srcs = glob([
"app/src/main/java/**/*.java",
]),
+ aapt_version = "aapt",
# Package assets from assets dir as well as all model targets.
# Remove undesired models (and corresponding Activities in source)
# to reduce APK size.
diff --git a/tensorflow/contrib/lite/g3doc/performance.md b/tensorflow/contrib/lite/g3doc/performance.md
index 0ae9400068..6b7943caf8 100644
--- a/tensorflow/contrib/lite/g3doc/performance.md
+++ b/tensorflow/contrib/lite/g3doc/performance.md
@@ -7,12 +7,12 @@ Mobile and embedded devices have limited computational resources and it is impor
Some models may be too large to run on embedded devices. Instead of large models it is better to use a slightly less precise but smaller model for embedded devices. Smaller models not only use less disk space and memory but are generally faster and more energy efficient. One example of models optimized for mobile devices are [MobileNets](https://arxiv.org/abs/1704.04861), which are optimized for mobile vision applications. Tensorflow Lite [models page](models.md) lists several other models that have been optimized specifically for mobile and embedded devices.
You can retrain the listed models on your own dataset by using transfer learning. Check out our transfer learning tutorial for
-[image classification] (https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
+[image classification](https://codelabs.developers.google.com/codelabs/tensorflow-for-poets/#0) and
[object detection](https://medium.com/tensorflow/training-and-serving-a-realtime-mobile-object-detector-in-30-minutes-with-cloud-tpus-b78971cf1193).
## Profile your model
-Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](../tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
+Before starting any optimization, it is a good practice to profile and benchmark your model. Tensorflow Lite [benchmarking tool](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tools/benchmark) has a built-in profiler that shows per operator profiling statistics. This can help in understanding performance bottlenecks and which operators dominate the computation time.
## Profile and optimize operators in the graph
If a particular operator appears frequently in the model and based on profiling you find the operator consuming the most amount of time, you can look into optimizing the operator.
@@ -22,7 +22,7 @@ If a particular operator appears frequently in the model and based on profiling
If your model uses floating point weights or activations then it may be possible to reduce the size of model up to ~4x by using quantization and other model optimizations. Check out our [model optimization toolkit](https://www.tensorflow.org/performance/model_optimization) for details about optimizing your model. Fully quantized models can be remarkably power efficient as well.
## Tweak the number of threads
-Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](../interpreter.h) threads.
+Tensorflow Lite supports multi-threaded kernels for many operators. You can increase the number of threads and speed up execution of operators. Increasing the number of threads will however make your model use more resources and power. For some applications latency may be more important than energy efficiency. You can increase the number of threads by setting the number of [interpreter](https://github.com/tensorflow/tensorflow/blob/1084594657a5d139102ac794f84d1427a710e39a/tensorflow/contrib/lite/interpreter.h#L337) threads.
## Eliminate redundant copies
Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to [mmap a model file](https://github.com/tensorflow/tensorflow/blob/9982fd6c8831cbd2f58954f79ea71f26660393bc/tensorflow/contrib/lite/model.h#L152) and avoid copies. If your application is not careful, there can be redundant copies when feeding the input to the model and reading output from the model. Make sure to eliminate redundant copies. If you are using higher level APIs like Java API, make sure to carefully check the documentation for performance caveats. For example, the Java API is a lot faster if ByteBuffers are used as [inputs](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/java/src/main/java/org/tensorflow/lite/Interpreter.java#L151).
@@ -31,8 +31,8 @@ Tensorflow Lite is optimized to reduce redundant copies. The APIs allow user to
Platform specific tools like [Android profiler](https://developer.android.com/studio/profile/android-profiler) and [Instruments](https://help.apple.com/instruments/mac/current/) provide a wealth of profiling information that can be used to debug your app. Sometimes the performance bug may be not in the model but in parts of application code that interact with the model. Make sure to familiarize yourself with platform specific profiling tools and best practices for your platform.
## Use hardware accelerators available on the device
-Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [NNAPI](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
-You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable NNAPI call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
+Tensorflow Lite is working on adding support for accelerators like GPU and provides acceleration through [Neural Networks API](https://developer.android.com/ndk/guides/neuralnetworks/) on Android.
+You can utilize these hardware accelerator backends to improve the speed and efficiency of your model. To enable Neural Networks API call [UseNNAPI](https://github.com/tensorflow/tensorflow/blob/6305a6d83552ba6a472cd72398b60d9241467f1f/tensorflow/contrib/lite/interpreter.h#L334) on the interpreter instance.
## Need more help
The Tensorflow team is happy to help diagnose and address specific performance issues you may be facing. Please file a bug on [github](https://github.com/tensorflow/tensorflow/issues) with details of the issue.
diff --git a/tensorflow/contrib/lite/java/aar_with_jni.bzl b/tensorflow/contrib/lite/java/aar_with_jni.bzl
index db837cf29e..9d2aead266 100644
--- a/tensorflow/contrib/lite/java/aar_with_jni.bzl
+++ b/tensorflow/contrib/lite/java/aar_with_jni.bzl
@@ -3,12 +3,12 @@
load("@build_bazel_rules_android//android:rules.bzl", "android_binary")
def aar_with_jni(name, android_library):
- # Generate dummy AndroidManifest.xml for dummy apk usage
- # (dummy apk is generated by <name>_dummy_app_for_so target below)
- native.genrule(
- name = name + "_binary_manifest_generator",
- outs = [name + "_generated_AndroidManifest.xml"],
- cmd = """
+ # Generate dummy AndroidManifest.xml for dummy apk usage
+ # (dummy apk is generated by <name>_dummy_app_for_so target below)
+ native.genrule(
+ name = name + "_binary_manifest_generator",
+ outs = [name + "_generated_AndroidManifest.xml"],
+ cmd = """
cat > $(OUTS) <<EOF
<manifest
xmlns:android="http://schemas.android.com/apk/res/android"
@@ -17,27 +17,28 @@ cat > $(OUTS) <<EOF
</manifest>
EOF
""",
- )
+ )
- # Generate dummy apk including .so files and later we extract out
- # .so files and throw away the apk.
- android_binary(
- name = name + "_dummy_app_for_so",
- manifest = name + "_generated_AndroidManifest.xml",
- custom_package = "dummy.package.for.so",
- deps = [android_library],
- # In some platforms we don't have an Android SDK/NDK and this target
- # can't be built. We need to prevent the build system from trying to
- # use the target in that case.
- tags = ["manual"],
- )
+ # Generate dummy apk including .so files and later we extract out
+ # .so files and throw away the apk.
+ android_binary(
+ name = name + "_dummy_app_for_so",
+ aapt_version = "aapt",
+ manifest = name + "_generated_AndroidManifest.xml",
+ custom_package = "dummy.package.for.so",
+ deps = [android_library],
+ # In some platforms we don't have an Android SDK/NDK and this target
+ # can't be built. We need to prevent the build system from trying to
+ # use the target in that case.
+ tags = ["manual"],
+ )
- native.genrule(
- name = name,
- srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
- outs = [name + ".aar"],
- tags = ["manual"],
- cmd = """
+ native.genrule(
+ name = name,
+ srcs = [android_library + ".aar", name + "_dummy_app_for_so_unsigned.apk"],
+ outs = [name + ".aar"],
+ tags = ["manual"],
+ cmd = """
cp $(location {}.aar) $(location :{}.aar)
chmod +w $(location :{}.aar)
origdir=$$PWD
@@ -46,4 +47,4 @@ unzip $$origdir/$(location :{}_dummy_app_for_so_unsigned.apk) "lib/*"
cp -r lib jni
zip -r $$origdir/$(location :{}.aar) jni/*/*.so
""".format(android_library, name, name, name, name),
- )
+ )
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
index 220d6c2159..5ad738389e 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/BUILD
@@ -7,6 +7,7 @@ licenses(["notice"]) # Apache 2.0
android_binary(
name = "TfLiteCameraDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [
"//tensorflow/contrib/lite/java/demo/app/src/main/assets:labels_mobilenet_quant_v1_224.txt",
"@tflite_mobilenet//:mobilenet_quant_v1_224.tflite",
diff --git a/tensorflow/contrib/lite/java/ovic/BUILD b/tensorflow/contrib/lite/java/ovic/BUILD
index bb0be04ca2..ea9b9ed4b6 100644
--- a/tensorflow/contrib/lite/java/ovic/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/BUILD
@@ -9,6 +9,7 @@ licenses(["notice"]) # Apache 2.0
load("//tensorflow/java:build_defs.bzl", "JAVACOPTS")
+# Build targets for OVIC classification.
java_test(
name = "OvicClassifierTest",
size = "medium",
@@ -45,8 +46,9 @@ android_library(
name = "ovicbenchmarkerlib",
srcs = [
"src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
+ "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
- "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java",
],
manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
tags = ["no_oss"],
@@ -60,8 +62,8 @@ android_library(
java_library(
name = "ovicbenchmarkerlib_java",
srcs = [
+ "src/main/java/org/tensorflow/ovic/OvicClassificationResult.java",
"src/main/java/org/tensorflow/ovic/OvicClassifier.java",
- "src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java",
],
javacopts = JAVACOPTS,
tags = ["no_oss"],
@@ -73,3 +75,58 @@ java_library(
"@org_checkerframework_qual",
],
)
+
+# Build targets for OVIC detection.
+java_test(
+ name = "OvicDetectorTest",
+ size = "medium",
+ srcs = ["src/test/java/org/tensorflow/ovic/OvicDetectorTest.java"],
+ data = [
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
+ ],
+ javacopts = JAVACOPTS,
+ tags = ["no_oss"],
+ test_class = "org.tensorflow.ovic.OvicDetectorTest",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib_java",
+ "@com_google_truth",
+ "@junit",
+ ],
+)
+
+android_library(
+ name = "ovicdetectionbenchmarkerlib",
+ srcs = [
+ "src/main/java/org/tensorflow/ovic/BoundingBox.java",
+ "src/main/java/org/tensorflow/ovic/OvicBenchmarker.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetector.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java",
+ ],
+ manifest = "//tensorflow/contrib/lite/java:AndroidManifest.xml",
+ deps = [
+ "//tensorflow/contrib/lite/java:tensorflowlite",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
+
+java_library(
+ name = "ovicdetectionbenchmarkerlib_java",
+ srcs = [
+ "src/main/java/org/tensorflow/ovic/BoundingBox.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetectionResult.java",
+ "src/main/java/org/tensorflow/ovic/OvicDetector.java",
+ ],
+ javacopts = JAVACOPTS,
+ deps = [
+ "//tensorflow/contrib/lite/java:libtensorflowlite_jni.so",
+ "//tensorflow/contrib/lite/java:tensorflowlite_java",
+ "//tensorflow/contrib/lite/java/src/main/native",
+ "//tensorflow/contrib/lite/java/src/testhelper/java/org/tensorflow/lite:testhelper",
+ "@org_checkerframework_qual",
+ ],
+)
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
index b2e3a9bd7d..f567358ea3 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/BUILD
@@ -8,9 +8,12 @@ android_binary(
srcs = [
"OvicBenchmarkerActivity.java",
],
+ aapt_version = "aapt",
assets = [
- "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:coco_labels.txt",
"//tensorflow/contrib/lite/java/ovic/src/testdata:labels.txt",
+ "//tensorflow/contrib/lite/java/ovic/src/testdata:ovic_testdata",
+ "@tflite_mobilenet_ssd_quant//:detect.tflite",
],
assets_dir = "",
custom_package = "ovic.demo.app",
@@ -24,6 +27,7 @@ android_binary(
deps = [
"//tensorflow/contrib/lite/java:tensorflowlite",
"//tensorflow/contrib/lite/java/ovic:ovicbenchmarkerlib",
+ "//tensorflow/contrib/lite/java/ovic:ovicdetectionbenchmarkerlib",
"@androidsdk//com.android.support:support-v13-25.2.0",
"@androidsdk//com.android.support:support-v4-25.2.0",
],
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
index 4adf94aeb6..48c29ecebe 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/OvicBenchmarkerActivity.java
@@ -35,19 +35,18 @@ import java.nio.MappedByteBuffer;
import java.nio.channels.FileChannel;
import java.text.DecimalFormat;
import org.tensorflow.ovic.OvicBenchmarker;
-import org.tensorflow.ovic.OvicSingleImageResult;
-
+import org.tensorflow.ovic.OvicClassifierBenchmarker;
+import org.tensorflow.ovic.OvicDetectorBenchmarker;
/** Class that benchmark image classifier models. */
public class OvicBenchmarkerActivity extends Activity {
/** Tag for the {@link Log}. */
private static final String TAG = "OvicBenchmarkerActivity";
- /** Name of the label file stored in Assets. */
- private static final String LABEL_PATH = "labels.txt";
-
- private static final String TEST_IMAGE_PATH = "test_image_224.jpg";
- private static final String MODEL_PATH = "float_model.lite";
+ /** Name of the task-dependent data files stored in Assets. */
+ private static String labelPath = null;
+ private static String testImagePath = null;
+ private static String modelPath = null;
/**
* Each bottom press will launch a benchmarking experiment. The experiment stops when either the
* total native latency reaches WALL_TIME or the number of iterations reaches MAX_ITERATIONS,
@@ -66,8 +65,6 @@ public class OvicBenchmarkerActivity extends Activity {
private MappedByteBuffer model = null;
private InputStream labelInputStream = null;
private OvicBenchmarker benchmarker;
- /** Inference result of each iteration. */
- OvicSingleImageResult iterResult = null;
private TextView textView = null;
// private Button startButton = null;
@@ -83,21 +80,31 @@ public class OvicBenchmarkerActivity extends Activity {
}
private Bitmap loadTestBitmap() throws IOException {
- InputStream imageStream = getAssets().open(TEST_IMAGE_PATH);
+ InputStream imageStream = getAssets().open(testImagePath);
return BitmapFactory.decodeStream(imageStream);
}
- public void initializeTest() throws IOException {
+ public void initializeTest(boolean benchmarkClassification) throws IOException {
Log.i(TAG, "Initializing benchmarker.");
- benchmarker = new OvicBenchmarker(WALL_TIME);
+ if (benchmarkClassification) {
+ benchmarker = new OvicClassifierBenchmarker(WALL_TIME);
+ labelPath = "labels.txt";
+ testImagePath = "test_image_224.jpg";
+ modelPath = "quantized_model.lite";
+ } else { // Benchmarking detection.
+ benchmarker = new OvicDetectorBenchmarker(WALL_TIME);
+ labelPath = "coco_labels.txt";
+ testImagePath = "test_image_224.jpg";
+ modelPath = "detect.tflite";
+ }
AssetManager am = getAssets();
- AssetFileDescriptor fileDescriptor = am.openFd(MODEL_PATH);
+ AssetFileDescriptor fileDescriptor = am.openFd(modelPath);
FileInputStream modelInputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = modelInputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
model = fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
- labelInputStream = am.open(LABEL_PATH);
+ labelInputStream = am.open(labelPath);
}
public Boolean doTestIteration() throws IOException, InterruptedException {
@@ -117,24 +124,44 @@ public class OvicBenchmarkerActivity extends Activity {
Log.i(TAG, "Going to do test iter.");
// Start testing.
Bitmap testImageBitmap = loadTestBitmap();
- iterResult = benchmarker.doTestIteration(testImageBitmap);
- testImageBitmap.recycle();
- if (iterResult == null) {
+ try {
+ if (!benchmarker.processBitmap(testImageBitmap)) {
+ throw new RuntimeException("Failed to run test.");
+ }
+ } catch (Exception e) {
+ e.printStackTrace();
+ throw e;
+ } finally {
+ testImageBitmap.recycle();
+ }
+ String iterResultString = benchmarker.getLastResultString();
+ if (iterResultString == null) {
throw new RuntimeException("Inference failed to produce a result.");
}
- Log.i(TAG, iterResult.toString());
+ Log.i(TAG, iterResultString);
return true;
}
- public void startPressed(View view) throws IOException {
- Log.i(TAG, "Start pressed");
+ public void detectPressed(View view) throws IOException {
+ benchmarkSession(false);
+ }
+ public void classifyPressed(View view) throws IOException {
+ benchmarkSession(true);
+ }
+
+ private void benchmarkSession(boolean benchmarkClassification) throws IOException {
try {
- initializeTest();
+ initializeTest(benchmarkClassification);
} catch (IOException e) {
Log.e(TAG, "Can't initialize benchmarker.", e);
throw e;
}
String displayText = "";
+ if (benchmarkClassification) {
+ displayText = "Classification benchmark: ";
+ } else {
+ displayText = "Detection benchmark: ";
+ }
try {
setProcessorAffinity(BIG_CORE_MASK);
} catch (IOException e) {
@@ -144,7 +171,6 @@ public class OvicBenchmarkerActivity extends Activity {
Log.i(TAG, "Successfully initialized benchmarker.");
int testIter = 0;
Boolean iterSuccess = false;
- double totalLatency = 0.0f;
while (testIter < MAX_ITERATIONS) {
try {
iterSuccess = doTestIteration();
@@ -153,23 +179,22 @@ public class OvicBenchmarkerActivity extends Activity {
throw e;
} catch (InterruptedException e) {
Log.e(TAG, "Interrupted at iteration " + testIter);
+ displayText += e.getMessage() + "\n";
}
if (!iterSuccess) {
break;
}
testIter++;
- totalLatency += (double) iterResult.latency;
}
- ;
Log.i(TAG, "Benchmarking finished");
if (textView != null) {
if (testIter > 0) {
textView.setText(
displayText
- + MODEL_PATH
+ + modelPath
+ ": Average latency="
- + df2.format(totalLatency / testIter)
+ + df2.format(benchmarker.getTotalRunTime() / testIter)
+ "ms after "
+ testIter
+ " runs.");
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
index e9d83bae54..1bce60ff7d 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/layout/activity_main.xml
@@ -30,14 +30,14 @@
android:layout_height="wrap_content"
android:text="@string/initial_status_msg"
android:id="@+id/textView"
- android:layout_above="@+id/button_start"
+ android:layout_above="@+id/button_clf_start"
android:layout_alignParentTop="true"/>
<Button
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:text="@string/start_label"
- android:id="@id/button_start"
+ android:text="@string/start_clf_label"
+ android:id="@id/button_clf_start"
android:layout_alignParentBottom="true"
android:layout_alignParentLeft="true"
android:background="@drawable/start_button_color"
@@ -49,6 +49,25 @@
android:textColor="#ffffff"
android:enabled="true"
style="?android:attr/buttonBarButtonStyle"
- android:onClick="startPressed"/>
+ android:onClick="classifyPressed"/>
+
+ <Button
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:text="@string/start_det_label"
+ android:id="@+id/button_det_start"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentRight="true"
+ android:layout_toRightOf="@id/button_clf_start"
+ android:background="@drawable/start_button_color"
+ android:padding="10dp"
+ android:layout_marginRight="100dp"
+ android:layout_marginLeft="30dp"
+ android:layout_marginTop="10dp"
+ android:foreground="#000000"
+ android:textColor="#ffffff"
+ android:enabled="true"
+ style="?android:attr/buttonBarButtonStyle"
+ android:onClick="detectPressed"/>
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
index d26beb1d27..53525908d3 100644
--- a/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
+++ b/tensorflow/contrib/lite/java/ovic/demo/app/res/values/strings.xml
@@ -17,6 +17,7 @@
<resources>
<string name="app_name" translatable="false">Benchmarker</string>
- <string name="start_label" translatable="false">Start</string>
+ <string name="start_clf_label" translatable="false">Clf</string>
+ <string name="start_det_label" translatable="false">Det</string>
<string name="initial_status_msg" translatable="false"> Press start to run the benchmarks.</string>
</resources>
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java
new file mode 100644
index 0000000000..9bf7d005d2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/BoundingBox.java
@@ -0,0 +1,68 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+/** Class for holding a detection bounding box with category and confidence. */
+public class BoundingBox {
+ // Upper left point.
+ public float x1;
+ public float y1;
+
+ // Lower right point.
+ public float x2;
+ public float y2;
+
+ // The area of the box
+ public float area;
+
+ // The object category
+ public int category;
+
+ // The confidence of the detection
+ public float score;
+
+ public BoundingBox(float x1, float y1, float x2, float y2, int category, float score) {
+ this.x1 = x1;
+ this.y1 = y1;
+ this.x2 = x2;
+ this.y2 = y2;
+ this.category = category;
+ this.score = score;
+ // -1 stands for area not initialized
+ this.area = -1;
+ }
+
+ // The intersection area of two bounding boxes
+ public float intersect(BoundingBox bbx) {
+ return Math.max(0, Math.min(x2, bbx.x2) - Math.max(x1, bbx.x1))
+ * Math.max(0, Math.min(y2, bbx.y2) - Math.max(y1, bbx.y1));
+ }
+
+ // The union area of two bounding boxes
+ public float union(BoundingBox bbx) {
+ return bbx.getArea() + this.getArea() - this.intersect(bbx);
+ }
+
+ public float getArea() {
+ if (area < 0) {
+ area = (x2 - x1) * (y2 - y1);
+ }
+ return area;
+ }
+
+ public float computeIoU(BoundingBox bbx) {
+ return (float) (this.intersect(bbx) * 1.0 / this.union(bbx));
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
index 4cda258bee..15d9511f50 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicBenchmarker.java
@@ -20,11 +20,10 @@ import android.util.Log;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
-import java.nio.ByteOrder;
import java.nio.MappedByteBuffer;
/**
- * Class that benchmarks image classifier models.
+ * Base class that benchmarks image models.
*
* <p>===================== General workflow =======================
*
@@ -33,37 +32,40 @@ import java.nio.MappedByteBuffer;
* benchmarker.getReadyToTest(labelInputStream, model);
* while (!benchmarker.shouldStop()) {
* Bitmap bitmap = ...
- * benchmarker.doTestIteration(bitmap);
+ * imgId = ...
+ * benchmarker.processBitmap(bitmap, imgId);
* }
* }</pre>
*/
-public class OvicBenchmarker {
+public abstract class OvicBenchmarker {
/** Tag for the {@link Log}. */
private static final String TAG = "OvicBenchmarker";
- /** Evaluation transformation parameters. */
- private static final float CENTRAL_FRACTION = 0.875f;
-
/** Dimensions of inputs. */
- private static final int DIM_BATCH_SIZE = 1;
- private static final int DIM_PIXEL_SIZE = 3;
- private int imgHeight = 224;
- private int imgWidth = 224;
+ protected static final int DIM_BATCH_SIZE = 1;
+ protected static final int DIM_PIXEL_SIZE = 3;
+ protected int imgHeight = 224;
+ protected int imgWidth = 224;
+
+ /** Preprocess parameters (only used when input is float). */
+ protected static final float IMAGE_MEAN = 127.5f;
+ protected static final float IMAGE_STD = 127.5f;
+
+ /** Whether input is float or quantized. */
+ protected Boolean quantizedInput = null;
/* Preallocated buffers for storing image data in. */
- private int[] intValues = null;
+ protected int[] intValues = null;
/** A ByteBuffer to hold image data, to be feed into classifier as inputs. */
- private ByteBuffer imgData = null;
-
- private OvicClassifier classifier;
+ protected ByteBuffer imgData = null;
/** Total runtime in ms. */
- private double totalRuntime = 0.0;
+ protected double totalRuntime = 0.0;
/** Total allowed runtime in ms. */
- private double wallTime = 20000 * 30.0;
-
- private Boolean benchmarkStarted = null;
+ protected double wallTime = 20000 * 30.0;
+ /** Record whether benchmark has started (used to skip the first image). */
+ protected boolean benchmarkStarted = false;
/**
* Initializes an {@link OvicBenchmarker}
@@ -76,6 +78,11 @@ public class OvicBenchmarker {
this.wallTime = wallTime;
}
+ /** Return the cumulative latency of all runs so far. */
+ public double getTotalRunTime() {
+ return totalRuntime;
+ }
+
/** Check whether the benchmarker should stop. */
public Boolean shouldStop() {
if (totalRuntime >= wallTime) {
@@ -90,105 +97,62 @@ public class OvicBenchmarker {
return false;
}
- /** Check whether the benchmarker is ready to start classifying images. */
- public Boolean readyToTest() {
- return (classifier != null);
- }
+ /** Abstract class for checking whether the benchmarker is ready to start processing images */
+ public abstract boolean readyToTest();
/**
- * Getting the benchmarker ready for classifying images.
+ * Abstract class for getting the benchmarker ready.
*
* @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
* read from.
* @param model: a {@link MappedByteBuffer} model to benchmark.
*/
- public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
- try {
- Log.i(TAG, "Creating classifier.");
- classifier = new OvicClassifier(labelInputStream, model);
- int [] inputDims = classifier.getInputDims();
- imgHeight = inputDims[1];
- imgWidth = inputDims[2];
- // Only accept QUANTIZED_UINT8 input.
- imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
- imgData.order(ByteOrder.nativeOrder());
- intValues = new int[imgHeight * imgWidth];
- } catch (Exception e) {
- Log.e(TAG, e.getMessage());
- Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker.");
- }
- }
-
- /** Return how many classes are predicted per image. */
- public int getNumPredictions() {
- return classifier.getNumPredictions();
- }
+ public abstract void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model);
/**
* Perform test on a single bitmap image.
*
- * @param bitmap: a {@link Bitmap} image to classify.
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
*/
- public OvicSingleImageResult doTestIteration(Bitmap bitmap)
- throws IOException, InterruptedException {
- if (shouldStop() || !readyToTest()) {
- return null;
- }
- OvicSingleImageResult iterResult = null;
- try {
- Log.i(TAG, "Converting bitmap.");
- convertBitmapToInput(bitmap);
- Log.i(TAG, "Classifying image.");
- iterResult = classifier.classifyByteBuffer(imgData);
- } catch (RuntimeException e) {
- Log.e(TAG, e.getMessage());
- Log.e(TAG, "Failed to classify image.");
- }
- if (iterResult == null || iterResult.latency == null) {
- throw new RuntimeException("Classification result or timing is invalid.");
- }
- Log.d(TAG, "Native inference latency: " + iterResult.latency);
- Log.i(TAG, iterResult.toString());
+ public abstract boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException;
- if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
- benchmarkStarted = true;
- } else {
- totalRuntime += (double) iterResult.latency;
- }
- return iterResult;
+ /** Perform test on a single bitmap image without an image ID. */
+ public boolean processBitmap(Bitmap bitmap) throws IOException, InterruptedException {
+ return processBitmap(bitmap, /* imageId = */ 0);
}
+ /** Returns the last inference results as string. */
+ public abstract String getLastResultString();
+
/**
- * Writes Image data into a {@link ByteBuffer}.
- *
- * @param bitmap: a {@link Bitmap} source image.
- */
- private void convertBitmapToInput(Bitmap bitmap) throws RuntimeException {
- if (imgData == null) {
+ * Loads input buffer from intValues into ByteBuffer for the interpreter.
+ * Input buffer must be loaded in intValues and output will be placed in imgData.
+ */
+ protected void loadsInputToByteBuffer() {
+ if (imgData == null || intValues == null || quantizedInput == null) {
throw new RuntimeException("Benchmarker is not yet ready to test.");
}
- imgData.rewind();
- // Perform transformations corresponding to evaluation mode.
- float width = (float) bitmap.getWidth();
- float height = (float) bitmap.getHeight();
- int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2);
- int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2);
- int newWidth = Math.round(width - stWidth * 2);
- int newHeight = Math.round(height - stHeight * 2);
- bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight);
- bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
- bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
-
// Convert the image to ByteBuffer.
+ imgData.rewind();
int pixel = 0;
long startTime = SystemClock.uptimeMillis();
for (int i = 0; i < imgHeight; ++i) {
for (int j = 0; j < imgWidth; ++j) {
- final int val = intValues[pixel++];
- imgData.put((byte) ((val >> 16) & 0xFF));
- imgData.put((byte) ((val >> 8) & 0xFF));
- imgData.put((byte) (val & 0xFF));
+ final int pixelValue = intValues[pixel++];
+ if (quantizedInput) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else {
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
}
}
long endTime = SystemClock.uptimeMillis();
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java
index 4af9a65c2f..5ab804e6ee 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicSingleImageResult.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassificationResult.java
@@ -1,4 +1,4 @@
-/*Copyright 2018 Google LLC
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -17,17 +17,17 @@ package org.tensorflow.ovic;
import java.util.ArrayList;
/** Result class for inference run on a single image. */
-public class OvicSingleImageResult {
+public class OvicClassificationResult {
/** Top K classes and probabilities. */
- public ArrayList<String> topKClasses;
- public ArrayList<Float> topKProbs;
- public ArrayList<Integer> topKIndices;
+ public final ArrayList<String> topKClasses;
+ public final ArrayList<Float> topKProbs;
+ public final ArrayList<Integer> topKIndices;
/** Latency (ms). */
public Long latency;
- OvicSingleImageResult() {
+ OvicClassificationResult() {
topKClasses = new ArrayList<>();
topKProbs = new ArrayList<>();
topKIndices = new ArrayList<>();
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
index fd610b054f..d8a54c1f3b 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifier.java
@@ -31,7 +31,7 @@ import java.util.PriorityQueue;
import org.tensorflow.lite.Interpreter;
import org.tensorflow.lite.TestHelper;
-/** Benchmark ImageNet Classifier with Tensorflow Lite. */
+/** Class for running ImageNet classification with a TfLite model. */
public class OvicClassifier {
/** Tag for the {@link Log}. */
@@ -106,7 +106,7 @@ public class OvicClassifier {
/** Classifies a {@link ByteBuffer} image. */
// @throws RuntimeException if model is uninitialized.
- public OvicSingleImageResult classifyByteBuffer(ByteBuffer imgData) {
+ public OvicClassificationResult classifyByteBuffer(ByteBuffer imgData) {
if (tflite == null) {
throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
}
@@ -122,7 +122,7 @@ public class OvicClassifier {
labelProbArray[0][i] = (inferenceOutputArray[0][i] & 0xff) / 255.0f;
}
}
- OvicSingleImageResult iterResult = computeTopKLabels();
+ OvicClassificationResult iterResult = computeTopKLabels();
iterResult.latency = getLastNativeInferenceLatencyMilliseconds();
return iterResult;
}
@@ -174,7 +174,7 @@ public class OvicClassifier {
}
/** Computes top-K labels. */
- private OvicSingleImageResult computeTopKLabels() {
+ private OvicClassificationResult computeTopKLabels() {
if (labelList == null) {
throw new RuntimeException("Label file has not been loaded.");
}
@@ -184,7 +184,7 @@ public class OvicClassifier {
sortedLabels.poll();
}
}
- OvicSingleImageResult singleImageResult = new OvicSingleImageResult();
+ OvicClassificationResult singleImageResult = new OvicClassificationResult();
if (sortedLabels.size() != RESULTS_TO_SHOW) {
throw new RuntimeException(
"Number of returned labels does not match requirement: "
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java
new file mode 100644
index 0000000000..0cdd0f7bec
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicClassifierBenchmarker.java
@@ -0,0 +1,142 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+
+/** Class that benchmarks image classifier models. */
+public final class OvicClassifierBenchmarker extends OvicBenchmarker {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicClassifierBenchmarker";
+
+ /** ImageNet preprocessing parameters. */
+ private static final float CENTRAL_FRACTION = 0.875f;
+ private OvicClassifier classifier;
+ private OvicClassificationResult iterResult = null;
+
+ public OvicClassifierBenchmarker(double wallTime) {
+ super(wallTime);
+ }
+
+ /** Test if the classifier is ready for benchmarking. */
+ @Override
+ public boolean readyToTest() {
+ return (classifier != null);
+ }
+
+ /**
+ * Getting the benchmarker ready for classifying images.
+ *
+ * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
+ * read from.
+ * @param model: a {@link MappedByteBuffer} model to benchmark.
+ */
+ @Override
+ public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
+ try {
+ Log.i(TAG, "Creating classifier.");
+ classifier = new OvicClassifier(labelInputStream, model);
+ int [] inputDims = classifier.getInputDims();
+ imgHeight = inputDims[1];
+ imgWidth = inputDims[2];
+ quantizedInput = true;
+ // Only accept QUANTIZED_UINT8 input.
+ imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
+ imgData.order(ByteOrder.nativeOrder());
+ intValues = new int[imgHeight * imgWidth];
+ } catch (Exception e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to initialize ImageNet classifier for the benchmarker.");
+ }
+ }
+
+ /**
+ * Perform classification on a single bitmap image.
+ *
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ @Override
+ public boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException {
+ if (shouldStop() || !readyToTest()) {
+ return false;
+ }
+ try {
+ Log.i(TAG, "Converting bitmap.");
+ convertBitmapToInput(bitmap);
+ Log.i(TAG, "Classifying image: " + imageId);
+ iterResult = classifier.classifyByteBuffer(imgData);
+ } catch (RuntimeException e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to classify image.");
+ }
+ if (iterResult == null || iterResult.latency == null) {
+ throw new RuntimeException("Classification result or timing is invalid.");
+ }
+ Log.d(TAG, "Native inference latency: " + iterResult.latency);
+ Log.i(TAG, iterResult.toString());
+
+ if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
+ benchmarkStarted = true;
+ } else {
+ totalRuntime += ((double) iterResult.latency);
+ }
+ return true;
+ }
+
+ /** Return how many classes are predicted per image. */
+ public int getNumPredictions() {
+ return classifier.getNumPredictions();
+ }
+
+ public OvicClassificationResult getLastClassificationResult() {
+ return iterResult;
+ }
+
+ @Override
+ public String getLastResultString() {
+ if (iterResult == null) {
+ return null;
+ } else {
+ return iterResult.toString();
+ }
+ }
+
+ /**
+ * Preprocess bitmap according to ImageNet protocol then writes result into a {@link ByteBuffer}.
+ *
+ * @param bitmap: a {@link Bitmap} source image.
+ */
+ private void convertBitmapToInput(Bitmap bitmap) {
+ // Perform transformations corresponding to evaluation mode.
+ float width = (float) bitmap.getWidth();
+ float height = (float) bitmap.getHeight();
+ int stWidth = Math.round((width - width * CENTRAL_FRACTION) / 2);
+ int stHeight = Math.round((height - height * CENTRAL_FRACTION) / 2);
+ int newWidth = Math.round(width - stWidth * 2);
+ int newHeight = Math.round(height - stHeight * 2);
+ bitmap = Bitmap.createBitmap(bitmap, stWidth, stHeight, newWidth, newHeight);
+ bitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
+ bitmap.getPixels(intValues, 0, bitmap.getWidth(), 0, 0, bitmap.getWidth(), bitmap.getHeight());
+ loadsInputToByteBuffer();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java
new file mode 100644
index 0000000000..cf2902a5cb
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectionResult.java
@@ -0,0 +1,91 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.util.ArrayList;
+
+/** Result class for inference run on a single image. */
+public class OvicDetectionResult {
+
+ // Top K classes and probabilities.
+ public final ArrayList<BoundingBox> detections;
+ // Latency (ms).
+ public Long latency = -1L;
+ // id of the image.
+ public int id = -1;
+ // Number of valid detections (separately maintained, maybe different from detections.size()).
+ public int count = 0;
+
+ // Create OvicDetectionResult object with pre-filled capacity. Note that detections.size() will
+ // be equal to capacity after this call.
+ OvicDetectionResult(int capacity) {
+ detections = new ArrayList<BoundingBox>(capacity);
+ for (int i = 0; i < capacity; i++) {
+ detections.add(new BoundingBox(-1.0f, -1.0f, -1.0f, -1.0f, -1, -1.0f));
+ }
+ }
+
+ public void resetTo(Long latency, int id) {
+ count = 0;
+ this.latency = latency;
+ this.id = id;
+ }
+
+ public void addBox(float x1, float y1, float x2, float y2, int category, float score) {
+ detections.get(count).x1 = x1;
+ detections.get(count).y1 = y1;
+ detections.get(count).x2 = x2;
+ detections.get(count).y2 = y2;
+ detections.get(count).category = category;
+ detections.get(count).score = score;
+ count += 1;
+ }
+
+ public void scaleUp(double scaleFactorWidth, double scaleFactorHeight) {
+ for (BoundingBox box : detections) {
+ box.x1 = (float) (box.x1 * scaleFactorWidth);
+ box.y1 = (float) (box.y1 * scaleFactorHeight);
+ box.x2 = (float) (box.x2 * scaleFactorWidth);
+ box.y2 = (float) (box.y2 * scaleFactorHeight);
+ }
+ }
+
+ @Override
+ public String toString() {
+ String textToShow = latency + "ms";
+ int k = 0;
+ for (BoundingBox box : detections) {
+ textToShow +=
+ "\nPrediction ["
+ + k
+ + "] = Class "
+ + box.category
+ + " ("
+ + box.x1
+ + ", "
+ + box.y1
+ + ", "
+ + box.x2
+ + ", "
+ + box.y2
+ + ") : "
+ + box.score;
+ k++;
+ }
+
+
+ return textToShow;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java
new file mode 100644
index 0000000000..56836a79e5
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetector.java
@@ -0,0 +1,184 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.MappedByteBuffer;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.tensorflow.lite.Interpreter;
+import org.tensorflow.lite.TestHelper;
+
+/** Class for running COCO detection with a TfLite model. */
+public class OvicDetector implements AutoCloseable {
+
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicDetector";
+
+ /** An instance of the driver class to run model inference with Tensorflow Lite. */
+ private Interpreter tflite;
+
+ /** Labels corresponding to the output of the vision model. */
+ private final List<String> labelList;
+
+ /** Define the output format. */
+ private final Boolean inputIsFloat;
+
+ /** Number of detections per image. 10 for demo, 100 for the actual competition. */
+ private static final int NUM_RESULTS = 10;
+
+ /** The output arrays for the mobilenet SSD. */
+ private float[][][] outputLocations;
+ private float[][] outputClasses;
+ private float[][] outputScores;
+ private float[] numDetections;
+ private Map<Integer, Object> outputMap;
+
+ /** Input resolution. */
+ private final int[] inputDims;
+
+ /** Final result. */
+ public OvicDetectionResult result = null;
+
+ OvicDetector(InputStream labelInputStream, MappedByteBuffer model) throws IOException {
+ // Load the label list.
+ labelList = loadLabelList(labelInputStream);
+
+ // Create the TfLite interpreter.
+ tflite = new Interpreter(model, new Interpreter.Options().setNumThreads(1));
+ inputDims = TestHelper.getInputDims(tflite, 0);
+ inputIsFloat = TestHelper.getInputDataType(tflite, 0).equals("float");
+ if (inputDims.length != 4) {
+ throw new RuntimeException("The model's input dimensions must be 4 (BWHC).");
+ }
+ if (inputDims[0] != 1) {
+ throw new RuntimeException(
+ "The model must have a batch size of 1, got " + inputDims[0] + " instead.");
+ }
+ if (inputDims[3] != 3) {
+ throw new RuntimeException(
+ "The model must have three color channels, got " + inputDims[3] + " instead.");
+ }
+ // Check the resolution.
+ int minSide = Math.min(inputDims[1], inputDims[2]);
+ int maxSide = Math.max(inputDims[1], inputDims[2]);
+ if (minSide <= 0 || maxSide > 1000) {
+ throw new RuntimeException("The model's resolution must be between (0, 1000].");
+ }
+
+ // Initialize the input array and result arrays. The input images are stored in a list of
+ // Object. Since this function anaylzed one image per time, there is only 1 item.
+ // The output is fomulated as a map of int -> Object. The output arrays are added to the map.
+ outputLocations = new float[1][NUM_RESULTS][4];
+ outputClasses = new float[1][NUM_RESULTS];
+ outputScores = new float[1][NUM_RESULTS];
+ numDetections = new float[1];
+ outputMap = new HashMap<>();
+ outputMap.put(0, outputLocations);
+ outputMap.put(1, outputClasses);
+ outputMap.put(2, outputScores);
+ outputMap.put(3, numDetections);
+ // Preallocate the result. This will be where inference result is stored after each
+ // detectByteBuffer call.
+ result = new OvicDetectionResult(NUM_RESULTS);
+ }
+
+ public Boolean quantizedInput() {
+ return !inputIsFloat;
+ }
+
+ /** Reads label list from Assets. */
+ private static List<String> loadLabelList(InputStream labelInputStream) throws IOException {
+ List<String> labelList = new ArrayList<>();
+ try (BufferedReader reader =
+ new BufferedReader(new InputStreamReader(labelInputStream, StandardCharsets.UTF_8))) {
+ String line;
+ while ((line = reader.readLine()) != null) {
+ labelList.add(line);
+ }
+ }
+ return labelList;
+ }
+
+ /**
+ * The interface to run the detection. This method currently only support float mobilenet_ssd
+ * model. The quantized models will be added in the future.
+ *
+ * @param imgData The image buffer in ByteBuffer format.
+ * @return boolean indicator of whether detection was a success. If success, the detection results
+ * is available in the result member variable.
+ * See OvicDetectionResult.java for details.
+ */
+ boolean detectByteBuffer(ByteBuffer imgData, int imageId) {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": Detector has not been initialized; Failed.");
+ }
+ if (inputIsFloat == null) {
+ throw new RuntimeException(TAG + ": Detector input type has not been resolved.");
+ }
+
+ Object[] inputArray = {imgData};
+ tflite.runForMultipleInputsOutputs(inputArray, outputMap);
+
+ Long latency = getLastNativeInferenceLatencyMilliseconds();
+
+ // Update the results.
+ result.resetTo(latency, imageId);
+ for (int i = 0; i < NUM_RESULTS; i++) {
+ result.addBox(outputLocations[0][i][1] * inputDims[1],
+ outputLocations[0][i][0] * inputDims[1],
+ outputLocations[0][i][3] * inputDims[2],
+ outputLocations[0][i][2] * inputDims[2],
+ Math.round(outputClasses[0][i] + 1 /* Label offset */),
+ outputScores[0][i]);
+ }
+ return true; // Marks that the result is available.
+ }
+
+ /*
+ * Get native inference latency of last image detection run.
+ * @throws RuntimeException if model is uninitialized.
+ * @return The inference latency in millisecond.
+ */
+ public Long getLastNativeInferenceLatencyMilliseconds() {
+ if (tflite == null) {
+ throw new RuntimeException(TAG + ": ImageNet classifier has not been initialized; Failed.");
+ }
+ Long latency = tflite.getLastNativeInferenceDurationNanoseconds();
+ return (latency == null) ? null : (Long) (latency / 1000000);
+ }
+
+ public int[] getInputDims() {
+ return inputDims;
+ }
+
+ public List<String> getLabels() {
+ return labelList;
+ }
+
+ /** Closes tflite to release resources. */
+ @Override
+ public void close() {
+ tflite.close();
+ tflite = null;
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java
new file mode 100644
index 0000000000..1a4e193ff2
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicDetectorBenchmarker.java
@@ -0,0 +1,160 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import android.graphics.Bitmap;
+import android.util.Log;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+
+/**
+ * Class that benchmarks object detection models.
+ */
+public final class OvicDetectorBenchmarker extends OvicBenchmarker {
+ /** Tag for the {@link Log}. */
+ private static final String TAG = "OvicDetectorBenchmarker";
+
+ public double scaleFactorWidth = 1.0f;
+ public double scaleFactorHeight = 1.0f;
+ private Bitmap scaledBitmap = null; // Preallocate bitmap for scaling.
+
+ private OvicDetector detector;
+
+ /**
+ * Initializes an {@link OvicDetectionBenchmarker}
+ *
+ * @param wallTime: a double number specifying the total amount of time to benchmark.
+ */
+ public OvicDetectorBenchmarker(double wallTime) {
+ super(wallTime);
+ }
+
+ /** Check to see if the detector is ready to test. */
+ @Override
+ public boolean readyToTest() {
+ return (detector != null);
+ }
+
+ /**
+ * Getting the benchmarker ready for detecting images.
+ *
+ * @param labelInputStream: an {@link InputStream} specifying where the list of labels should be
+ * read from.
+ * @param model: a {@link MappedByteBuffer} model to benchmark.
+ */
+ @Override
+ public void getReadyToTest(InputStream labelInputStream, MappedByteBuffer model) {
+ try {
+ Log.i(TAG, "Creating detector.");
+ detector = new OvicDetector(labelInputStream, model);
+ quantizedInput = detector.quantizedInput();
+ int[] inputDims = detector.getInputDims();
+ imgHeight = inputDims[1];
+ imgWidth = inputDims[2];
+ if (quantizedInput) {
+ imgData = ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE);
+ } else {
+ imgData =
+ ByteBuffer.allocateDirect(DIM_BATCH_SIZE * imgHeight * imgWidth * DIM_PIXEL_SIZE * 4);
+ }
+ imgData.order(ByteOrder.nativeOrder());
+ intValues = new int[imgHeight * imgWidth];
+ benchmarkStarted = false;
+ } catch (Exception e) {
+ Log.e(TAG, e.getMessage());
+ Log.e(TAG, "Failed to initialize COCO detector for the benchmarker.", e);
+ }
+ }
+
+ /**
+ * Perform detection on a single ByteBuffer {@link ByteBuffer} image. The image must have the
+ * same dimension that the model expects.
+ *
+ * @param image: a {@link ByteBuffer} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ public boolean processBuffer(ByteBuffer image, int imageId) {
+ if (!readyToTest()) {
+ return false;
+ }
+ try {
+ if (!detector.detectByteBuffer(image, imageId)) {
+ return false;
+ }
+ } catch (RuntimeException e) {
+ Log.e(TAG, e.getMessage());
+ return false;
+ }
+
+ if (!benchmarkStarted) { // Skip the first image to discount warming-up time.
+ benchmarkStarted = true;
+ } else {
+ totalRuntime += ((double) detector.result.latency);
+ }
+ return true; // Indicating that result is ready.
+ }
+
+ /**
+ * Perform detection on a single bitmap image.
+ *
+ * @param bitmap: a {@link Bitmap} image to process.
+ * @param imageId: an ID uniquely representing the image.
+ */
+ @Override
+ public boolean processBitmap(Bitmap bitmap, int imageId)
+ throws IOException, InterruptedException {
+ if (shouldStop() || !readyToTest()) {
+ return false;
+ }
+ convertBitmapToInput(bitmap); // Scale bitmap if needed, store result in imgData.
+ if (!processBuffer(imgData, imageId)) {
+ return false;
+ }
+ // Scale results back to original image coordinates.
+ detector.result.scaleUp(scaleFactorWidth, scaleFactorHeight);
+ return true; // Indicating that result is ready.
+ }
+
+ public OvicDetectionResult getLastDetectionResult() {
+ return detector.result;
+ }
+
+ @Override
+ public String getLastResultString() {
+ if (detector.result == null) {
+ return null;
+ }
+ return detector.result.toString();
+ }
+
+ /**
+ * Preprocess bitmap image into {@link ByteBuffer} format for the detector.
+ *
+ * @param bitmap: a {@link Bitmap} source image.
+ */
+ private void convertBitmapToInput(Bitmap bitmap) {
+ int originalWidth = bitmap.getWidth();
+ int originalHeight = bitmap.getHeight();
+ scaledBitmap = Bitmap.createScaledBitmap(bitmap, imgWidth, imgHeight, true);
+ scaleFactorWidth = originalWidth * 1.0 / imgWidth;
+ scaleFactorHeight = originalHeight * 1.0 / imgHeight;
+ scaledBitmap.getPixels(intValues, 0, imgWidth, 0, 0, imgWidth, imgHeight);
+ scaledBitmap.recycle();
+ loadsInputToByteBuffer();
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
index a504ec74a9..baa14baf92 100644
--- a/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
+++ b/tensorflow/contrib/lite/java/ovic/src/main/java/org/tensorflow/ovic/OvicValidator.java
@@ -51,7 +51,7 @@ public class OvicValidator {
MappedByteBuffer model = loadModelFile(modelFile);
OvicClassifier classifier = new OvicClassifier(labelsInputStream, model);
ByteBuffer imgData = createByteBufferForClassifier(classifier);
- OvicSingleImageResult testResult = classifier.classifyByteBuffer(imgData);
+ OvicClassificationResult testResult = classifier.classifyByteBuffer(imgData);
if (testResult.topKClasses.isEmpty()) {
throw new RuntimeException("Failed to return top K predictions.");
}
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
index 1587c3c56f..99e874ca78 100644
--- a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicClassifierTest.java
@@ -1,4 +1,4 @@
-/*Copyright 2018 Google LLC
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -43,7 +43,7 @@ public final class OvicClassifierTest {
private MappedByteBuffer lowResModel = null;
private ByteBuffer testImage = null;
private ByteBuffer lowResTestImage = null;
- private OvicSingleImageResult testResult = null;
+ private OvicClassificationResult testResult = null;
private static final String LABELS_PATH =
"tensorflow/contrib/lite/java/ovic/src/testdata/labels.txt";
private static final String QUANTIZED_MODEL_PATH =
@@ -147,7 +147,7 @@ public final class OvicClassifierTest {
return imgData;
}
- private static void assertCorrectTopK(OvicSingleImageResult testResult) {
+ private static void assertCorrectTopK(OvicClassificationResult testResult) {
assertThat(testResult.topKClasses.size() > 0).isTrue();
Boolean topKAccurate = false;
// Assert that the correct class is in the top K.
diff --git a/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java
new file mode 100644
index 0000000000..4681e26052
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/test/java/org/tensorflow/ovic/OvicDetectorTest.java
@@ -0,0 +1,149 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ https://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+package org.tensorflow.ovic;
+
+import static com.google.common.truth.Truth.assertThat;
+
+import java.awt.Graphics2D;
+import java.awt.image.BufferedImage;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.MappedByteBuffer;
+import java.nio.channels.FileChannel;
+import javax.imageio.ImageIO;
+import org.junit.Before;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.JUnit4;
+
+/** Unit test for {@link org.tensorflow.ovic.OvicDetector}. */
+@RunWith(JUnit4.class)
+public final class OvicDetectorTest {
+ private OvicDetector detector = null;
+ private InputStream labelsInputStream = null;
+ private MappedByteBuffer model = null;
+ private ByteBuffer testImage = null;
+
+ private static final float IMAGE_MEAN = 128f;
+ private static final float IMAGE_STD = 128f;
+
+ private Boolean quantizedInput = null;
+ private static final String LABELS_PATH =
+ "tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt";
+ private static final String MODEL_PATH =
+ "external/tflite_mobilenet_ssd_quant/detect.tflite";
+ private static final String TEST_IMAGE_PATH =
+ "external/tflite_ovic_testdata/test_image_224.jpg";
+ private static final int GROUNDTRUTH = 1 /* Person */;
+
+ @Before
+ public void setUp() {
+ try {
+ // load models.
+ model = loadModelFile(MODEL_PATH);
+
+ // Load label files;
+ File labelsfile = new File(LABELS_PATH);
+ labelsInputStream = new FileInputStream(labelsfile);
+
+ // Create detector.
+ detector = new OvicDetector(labelsInputStream, model);
+ quantizedInput = detector.quantizedInput();
+
+ // Load test image and convert into byte buffer.
+ File imageFile = new File(TEST_IMAGE_PATH);
+ BufferedImage rawimg = ImageIO.read(imageFile);
+ int[] inputDims = detector.getInputDims();
+ BufferedImage img = new BufferedImage(inputDims[1], inputDims[2], rawimg.getType());
+ Graphics2D g = img.createGraphics();
+ g.drawImage(rawimg, 0, 0, inputDims[1], inputDims[2], null);
+ g.dispose();
+ testImage = toByteBuffer(img);
+ } catch (IOException e) {
+ System.out.println(e.getMessage());
+ }
+
+ System.out.println("Successfully setup");
+ }
+
+ private static MappedByteBuffer loadModelFile(String modelFilePath) throws IOException {
+ File modelfile = new File(modelFilePath);
+ FileInputStream inputStream = new FileInputStream(modelfile);
+ FileChannel fileChannel = inputStream.getChannel();
+ long startOffset = 0L;
+ long declaredLength = fileChannel.size();
+ return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
+ }
+
+ private ByteBuffer toByteBuffer(BufferedImage image) {
+ ByteBuffer imgData;
+ if (quantizedInput) {
+ imgData = ByteBuffer.allocateDirect(image.getHeight() * image.getWidth() * 3);
+ } else {
+ imgData = ByteBuffer.allocateDirect(image.getHeight() * image.getWidth() * 12);
+ }
+ imgData.order(ByteOrder.nativeOrder());
+ for (int y = 0; y < image.getHeight(); y++) {
+ for (int x = 0; x < image.getWidth(); x++) {
+ int pixelValue = image.getRGB(x, y);
+ if (quantizedInput) {
+ // Quantized model
+ imgData.put((byte) ((pixelValue >> 16) & 0xFF));
+ imgData.put((byte) ((pixelValue >> 8) & 0xFF));
+ imgData.put((byte) (pixelValue & 0xFF));
+ } else {
+ // Float model
+ imgData.putFloat((((pixelValue >> 16) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat((((pixelValue >> 8) & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ imgData.putFloat(((pixelValue & 0xFF) - IMAGE_MEAN) / IMAGE_STD);
+ }
+ }
+ }
+ return imgData;
+ }
+
+ @Test
+ public void ovicDetector_detectSuccess() throws Exception {
+ assertThat(detector.detectByteBuffer(testImage, 1)).isTrue();
+ assertThat(detector.result != null).isTrue();
+ }
+
+ @Test
+ public void ovicDetector_simpleBatchTest() throws Exception {
+ final int numRepeats = 5;
+ for (int i = 0; i < numRepeats; i++) {
+ assertThat(detector.detectByteBuffer(testImage, 1)).isTrue();
+ OvicDetectionResult result = detector.result;
+ Boolean detectWithinTop5 = false;
+ for (int j = 0; j < Math.min(5, result.count); j++) {
+ if (result.detections.get(j).category == GROUNDTRUTH) {
+ detectWithinTop5 = true;
+ break;
+ }
+ }
+ if (!detectWithinTop5) {
+ System.out.println("---------------- Image " + i + " ---------------------");
+ System.out.println("Expect category " + GROUNDTRUTH);
+ System.out.println("Detection results: ");
+ System.out.println(result.toString());
+ }
+ assertThat(detectWithinTop5).isTrue();
+ }
+ }
+}
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
index 1021ea30dd..051aa2204e 100644
--- a/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/BUILD
@@ -14,6 +14,9 @@ filegroup(
)
exports_files(
- ["labels.txt"],
+ [
+ "labels.txt",
+ "coco_labels.txt",
+ ],
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt b/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt
new file mode 100644
index 0000000000..d91f535b1a
--- /dev/null
+++ b/tensorflow/contrib/lite/java/ovic/src/testdata/coco_labels.txt
@@ -0,0 +1,91 @@
+person
+bicycle
+car
+motorcycle
+airplane
+bus
+train
+truck
+boat
+traffic light
+fire hydrant
+empty
+stop sign
+parking meter
+bench
+bird
+cat
+dog
+horse
+sheep
+cow
+elephant
+bear
+zebra
+giraffe
+empty
+backpack
+umbrella
+empty
+empty
+handbag
+tie
+suitcase
+frisbee
+skis
+snowboard
+sports ball
+kite
+baseball bat
+baseball glove
+skateboard
+surfboard
+tennis racket
+bottle
+empty
+wine glasses
+cup
+fork
+knife
+spoon
+bowl
+banana
+apple
+sandwich
+orange
+broccoli
+carrot
+hot dog
+pizza
+donut
+cake
+chair
+couch
+potted plant
+bed
+empty
+dining table
+empty
+empty
+toilet
+empty
+tv
+laptop
+mouse
+remote
+keyboard
+cell phone
+microwave
+oven
+toaster
+sink
+refrigerator
+empty
+book
+clock
+vase
+scissors
+teddy bear
+hair drier
+toothbrush
+empty
diff --git a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
index f18a2ca07a..2e5033dab1 100644
--- a/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
+++ b/tensorflow/contrib/lite/models/smartreply/demo/app/src/main/BUILD
@@ -20,6 +20,7 @@ filegroup(
android_binary(
name = "SmartReplyDemo",
srcs = glob(["java/**/*.java"]),
+ aapt_version = "aapt",
assets = [":assets"],
assets_dir = "",
custom_package = "com.example.android.smartreply",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
index 613a1530f7..1bf42d7551 100644
--- a/tensorflow/contrib/lite/python/convert.py
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -155,7 +155,8 @@ def build_toco_convert_protos(input_tensors,
post_training_quantize=False,
dump_graphviz_dir=None,
dump_graphviz_video=False,
- converter_mode=ConverterMode.DEFAULT):
+ converter_mode=ConverterMode.DEFAULT,
+ allow_nonexistent_arrays=False):
"""Builds protocol buffers describing a conversion of a model using TOCO.
Typically this is to convert from TensorFlow GraphDef to TFLite, in which
@@ -212,6 +213,8 @@ def build_toco_convert_protos(input_tensors,
every graph transformation. (default False)
converter_mode: Experimental flag, subject to change. ConverterMode
indicating which converter to use. (default ConverterMode.DEFAULT)
+ allow_nonexistent_arrays: Allow specifying array names that don't exist
+ or are unused in the final graph. (default False)
Returns:
model_flags, toco_flags: two protocol buffers describing the conversion
@@ -261,6 +264,9 @@ def build_toco_convert_protos(input_tensors,
for output_tensor in output_tensors:
model.output_arrays.append(tensor_name(output_tensor))
+
+ model.allow_nonexistent_arrays = allow_nonexistent_arrays
+
return model, toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
index 1bc366f555..fb299c31b7 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/quantize.cc
@@ -97,15 +97,6 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
// to allow easily trying out quantization even if the graph
// lacks some minmax information.
if (array.buffer != nullptr) {
- LOG(WARNING)
- << "Constant array " << array_name
- << " lacks MinMax information. To make up for that, we will now compute"
- << " the MinMax from actual array elements. That will result in"
- << " quantization parameters that probably do not match whichever "
- "arithmetic"
- << " was used during training, and thus will probably be a cause of "
- "poor"
- << " inference accuracy.";
CHECK(array.buffer->type == ArrayDataType::kFloat);
const auto& data = array.GetBuffer<ArrayDataType::kFloat>().data;
// We always want [min, max] to contain 0.
@@ -120,6 +111,27 @@ const MinMax& GetOrComputeMinMax(Model* model, const string& array_name) {
// to not be equal.
max = 1.f;
}
+ // No need to warn about accuracy if all array values are equal to either
+ // min or max:
+ // in that case, quantization is exact, and such arrays are not learned
+ // weights arrays for which fake-quantization would make sense, rather
+ // they tend to be hardcoded arrays of zeros or ones used in some graphs.
+ bool is_quantization_trivially_exact = true;
+ for (auto val : data) {
+ is_quantization_trivially_exact &= (val == min || val == max);
+ }
+ if (!is_quantization_trivially_exact) {
+ LOG(WARNING)
+ << "Constant array " << array_name
+ << " lacks MinMax information. To make up for that, we will now "
+ "compute"
+ << " the MinMax from actual array elements. That will result in"
+ << " quantization parameters that probably do not match whichever "
+ "arithmetic"
+ << " was used during training, and thus will probably be a cause of "
+ "poor"
+ << " inference accuracy.";
+ }
auto& minmax = array.GetOrCreateMinMax();
minmax.min = min;
minmax.max = max;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
index 5b41c49bfa..eaa9d3bcda 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/read_array_minmax_and_narrow_range_from_fake_quant.cc
@@ -71,8 +71,10 @@ bool ReadArrayMinmaxAndNarrowRangeFromFakeQuant::Run(Model* model,
CHECK(fq_op->minmax);
CHECK_EQ(1, fq_op->inputs.size());
- return ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]) ||
- ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+ bool changed = false;
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->inputs[0]);
+ changed |= ApplyAttrsToArray(this, model, *fq_op, fq_op->outputs[0]);
+ return changed;
}
} // namespace toco
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index 4bb1217828..b2b2ea151b 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -60,6 +60,10 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
+ if (!IsDiscardableArray(*model, output_array_name)) {
+ return false;
+ }
+
if (output_array.data_type == ArrayDataType::kNone) {
// Yield until the output type has been set by PropagateArrayDataTypes
return false;
@@ -139,14 +143,8 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
}
// Erase input arrays to the multiply if no longer used
- if (IsDiscardableArray(*model, mul_op->inputs[0]) &&
- CountOpsWithInput(*model, mul_op->inputs[0]) == 1) {
- model->EraseArray(mul_op->inputs[0]);
- }
- if (IsDiscardableArray(*model, mul_op->inputs[1]) &&
- CountOpsWithInput(*model, mul_op->inputs[1]) == 1) {
- model->EraseArray(mul_op->inputs[1]);
- }
+ DeleteArrayIfUsedOnce(mul_op->inputs[0], model);
+ DeleteArrayIfUsedOnce(mul_op->inputs[1], model);
// Erase the multiply operator.
model->operators.erase(mul_it);
diff --git a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
index d34da63e43..b6a401aaf2 100644
--- a/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
+++ b/tensorflow/contrib/lite/toco/model_cmdline_flags.cc
@@ -394,12 +394,18 @@ void ReadModelFlagsFromCommandLineFlags(
}
}
- model_flags->set_allow_nonascii_arrays(
- parsed_model_flags.allow_nonascii_arrays.value());
- model_flags->set_allow_nonexistent_arrays(
- parsed_model_flags.allow_nonexistent_arrays.value());
- model_flags->set_change_concat_input_ranges(
- parsed_model_flags.change_concat_input_ranges.value());
+ if (!model_flags->has_allow_nonascii_arrays()) {
+ model_flags->set_allow_nonascii_arrays(
+ parsed_model_flags.allow_nonascii_arrays.value());
+ }
+ if (!model_flags->has_allow_nonexistent_arrays()) {
+ model_flags->set_allow_nonexistent_arrays(
+ parsed_model_flags.allow_nonexistent_arrays.value());
+ }
+ if (!model_flags->has_change_concat_input_ranges()) {
+ model_flags->set_change_concat_input_ranges(
+ parsed_model_flags.change_concat_input_ranges.value());
+ }
if (parsed_model_flags.arrays_extra_info_file.specified()) {
string arrays_extra_info_file_contents;
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index 4a1ae35cb5..e3f27e9e2a 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -843,24 +843,43 @@ void CheckNonAsciiIOArrays(const ModelFlags& model_flags) {
}
void CheckNonExistentIOArrays(const Model& model) {
+ // "non-existent" is interpreted in the stronger sense of
+ // "not actually produced/consumed by an op".
+ // Rationale: we have to artificially fix up TensorFlow graphs by creating
+ // any array that it refers to, so just checking that arrays exist isn't
+ // sufficient. The real invariant here is whether arrays are produced/consumed
+ // by something.
if (model.flags.allow_nonexistent_arrays()) {
return;
}
+ static constexpr char general_comment[] =
+ "Is it a typo? To silence this message, pass this flag: "
+ "allow_nonexistent_arrays";
for (const auto& input_array : model.flags.input_arrays()) {
- CHECK(model.HasArray(input_array.name()))
- << "Input array not found: " << input_array.name();
+ QCHECK(GetOpWithInput(model, input_array.name()))
+ << "Specified input array \"" << input_array.name()
+ << "\" is not consumed by any op in this graph. " << general_comment;
}
for (const string& output_array : model.flags.output_arrays()) {
- CHECK(model.HasArray(output_array))
- << "Output array not found: " << output_array;
+ QCHECK(GetOpWithOutput(model, output_array))
+ << "Specified output array \"" << output_array
+ << "\" is not produced by any op in this graph. " << general_comment;
}
for (const auto& rnn_state : model.flags.rnn_states()) {
if (!rnn_state.discardable()) {
- CHECK(model.HasArray(rnn_state.state_array()));
- CHECK(model.HasArray(rnn_state.back_edge_source_array()));
+ // Check that all RNN states are consumed
+ QCHECK(GetOpWithInput(model, rnn_state.state_array()))
+ << "Specified RNN state \"" << rnn_state.state_array()
+ << "\" is not consumed by any op in this graph. " << general_comment;
+ // Check that all RNN back-edge source arrays are produced
+ QCHECK(GetOpWithOutput(model, rnn_state.back_edge_source_array()))
+ << "Specified RNN back-edge source array \""
+ << rnn_state.back_edge_source_array()
+ << "\" is not produced by any op in this graph. " << general_comment;
}
}
}
+
} // namespace
void CheckNoMissingArray(const Model& model) {
@@ -1597,6 +1616,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
input_array.GetOrCreateMinMax() = input_minmax;
}
}
+
// Creation of the RNN state arrays
for (const auto& rnn_state : model->flags.rnn_states()) {
CreateOrCheckRnnStateArray(rnn_state.state_array(), rnn_state.size(),
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index 89b538d1ba..9e9345e875 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -23,8 +23,8 @@ import numpy as np
import six
from tensorflow.contrib import lookup
-from tensorflow.contrib.data.python.ops import counter
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import counter
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/makefile/Makefile b/tensorflow/contrib/makefile/Makefile
index d962a5e12d..36125c198e 100644
--- a/tensorflow/contrib/makefile/Makefile
+++ b/tensorflow/contrib/makefile/Makefile
@@ -133,7 +133,8 @@ $(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*benchmark*.cc) \
$(wildcard tensorflow/contrib/makefile/downloads/absl/absl/*/*/*/*/*benchmark*.cc) \
-tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc
+tensorflow/contrib/makefile/downloads/absl/absl/synchronization/internal/mutex_nonprod.cc \
+tensorflow/contrib/makefile/downloads/absl/absl/hash/internal/print_hash_of.cc
ABSL_CC_SRCS := $(filter-out $(ABSL_CC_EXCLUDE_SRCS), $(ABSL_CC_ALL_SRCS))
diff --git a/tensorflow/contrib/model_pruning/README.md b/tensorflow/contrib/model_pruning/README.md
index 15d95896d9..b313024e28 100644
--- a/tensorflow/contrib/model_pruning/README.md
+++ b/tensorflow/contrib/model_pruning/README.md
@@ -62,6 +62,7 @@ The pruning library allows for specification of the following hyper parameters:
| sparsity_function_begin_step | integer | 0 | The global step at this which the gradual sparsity function begins to take effect |
| sparsity_function_end_step | integer | 100 | The global step used as the end point for the gradual sparsity function |
| sparsity_function_exponent | float | 3.0 | exponent = 1 is linearly varying sparsity between initial and final. exponent > 1 varies more slowly towards the end than the beginning |
+| use_tpu | bool | False | Training using TPUs? |
The sparsity $$s_t$$ at global step $$t$$ is given by:
diff --git a/tensorflow/contrib/opt/python/training/shampoo_test.py b/tensorflow/contrib/opt/python/training/shampoo_test.py
index 05bcf2cfa3..a2fd8fbd87 100644
--- a/tensorflow/contrib/opt/python/training/shampoo_test.py
+++ b/tensorflow/contrib/opt/python/training/shampoo_test.py
@@ -54,9 +54,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -105,9 +105,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -164,9 +164,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -254,9 +254,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -310,9 +310,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(size[0], size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -383,9 +383,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np_2 = np.random.rand(sample_size_2, size[1])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -463,9 +463,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
grad_np = np.random.rand(sample_size, size[1], size[2])
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = ops.IndexedSlices(
constant_op.constant(grad_np, dtype=dtypes.float32),
@@ -533,9 +533,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
gbar_weight = 0.1
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = constant_op.constant(grad_np, dtype=dtypes.float32)
grad_2 = constant_op.constant(grad_np_2, dtype=dtypes.float32)
@@ -628,9 +628,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
@@ -705,9 +705,9 @@ class ShampooTest(test.TestCase, parameterized.TestCase):
mat_g3 = np.zeros_like(mat_g3_a)
with self.cached_session() as sess:
- global_step = variables.Variable(
+ global_step = variables.VariableV1(
0, dtype=dtypes.int64, use_resource=use_resource_var)
- var = variables.Variable(
+ var = variables.VariableV1(
init_var_np, dtype=dtypes.float32, use_resource=use_resource_var)
grad = array_ops.placeholder(dtypes.float32, shape=size)
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index 6af59dcfbf..53e27c08c4 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -30,7 +30,6 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gradients
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.training import distribute as distribute_lib
@@ -965,8 +964,7 @@ class OptimizerV2(optimizer_v1.Optimizer):
# Use the processors to update the variables.
update_ops = []
for grad, var in grads_and_vars:
- update_ops.extend(distribution.unwrap(distribution.update(
- var, update, grad)))
+ update_ops.extend(distribution.update(var, update, grad, grouped=False))
# Give the child class a chance to do something after applying
# gradients
@@ -978,26 +976,24 @@ class OptimizerV2(optimizer_v1.Optimizer):
update_ops = control_flow_ops.group(update_ops)
with ops.control_dependencies([update_ops]):
- finish_updates = distribution.update_non_slot(non_slot_devices, finish)
- if finish_updates is None:
- finish_updates = update_ops
+ finish_updates = distribution.update_non_slot(
+ non_slot_devices, finish, grouped=False)
+ # We said grouped=False, which means finish_updates is always a list.
+ # It will be [None] when finish() returns None.
+ if finish_updates == [None]:
+ finish_updates = [update_ops]
# Update `global_step` (if any).
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
-
- def update_global_step(global_step):
- if isinstance(global_step, resource_variable_ops.ResourceVariable):
- return global_step.assign_add(
- ops.convert_to_tensor(1, dtype=global_step.dtype),
- read_value=False)
- else:
- return state_ops.assign_add(global_step, 1)
-
- apply_updates = distribution.group(
- distribution.update(global_step, update_global_step), name=name)
+ with ops.control_dependencies(finish_updates):
+
+ def update_global_step(global_step, name):
+ return global_step.assign_add(1, read_value=False, name=name)
+
+ apply_updates = distribution.update(
+ global_step, update_global_step, name)
# Add the training op to the TRAIN_OP graph collection in graph mode.
if not eager_execution:
diff --git a/tensorflow/contrib/quantize/BUILD b/tensorflow/contrib/quantize/BUILD
index 23e3a25d71..94a2d9672d 100644
--- a/tensorflow/contrib/quantize/BUILD
+++ b/tensorflow/contrib/quantize/BUILD
@@ -138,7 +138,6 @@ py_library(
srcs = ["python/quant_ops.py"],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
diff --git a/tensorflow/contrib/quantize/python/quant_ops.py b/tensorflow/contrib/quantize/python/quant_ops.py
index 27069444a4..d9dc7fa62e 100644
--- a/tensorflow/contrib/quantize/python/quant_ops.py
+++ b/tensorflow/contrib/quantize/python/quant_ops.py
@@ -18,8 +18,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.framework.python.ops import add_arg_scope
-from tensorflow.contrib.framework.python.ops import model_variable
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
@@ -29,7 +27,6 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.training import moving_averages
-@add_arg_scope
def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
"""Adds a fake quantize layer with fixed quantization interval.
@@ -46,7 +43,21 @@ def FixedQuantize(inputs, init_min=-6.0, init_max=6.0, scope=None):
inputs, min=init_min, max=init_max)
-@add_arg_scope
+def _ModelVariable(name,
+ shape=None,
+ initializer=None,
+ collections=None,
+ trainable=None):
+ collections = list(collections or [])
+ collections += [ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.MODEL_VARIABLES]
+ return variable_scope.get_variable(
+ name,
+ shape=shape,
+ initializer=initializer,
+ collections=collections,
+ trainable=trainable)
+
+
def LastValueQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -93,13 +104,13 @@ def LastValueQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),
@@ -153,7 +164,6 @@ def LastValueQuantize(inputs,
narrow_range=narrow_range)
-@add_arg_scope
def MovingAvgQuantize(inputs,
per_channel=False,
init_min=-6.0,
@@ -202,13 +212,13 @@ def MovingAvgQuantize(inputs,
else:
min_max_shape = []
- min_var = model_variable(
+ min_var = _ModelVariable(
'min',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_min),
collections=[vars_collection],
trainable=False)
- max_var = model_variable(
+ max_var = _ModelVariable(
'max',
shape=min_max_shape,
initializer=init_ops.constant_initializer(init_max),
diff --git a/tensorflow/contrib/stateless/BUILD b/tensorflow/contrib/stateless/BUILD
index dcbef2881d..a217397c1a 100644
--- a/tensorflow/contrib/stateless/BUILD
+++ b/tensorflow/contrib/stateless/BUILD
@@ -9,19 +9,13 @@ exports_files(["LICENSE"])
load("//tensorflow:tensorflow.bzl", "cuda_py_test")
load("//tensorflow:tensorflow.bzl", "tf_gen_op_wrapper_py")
-tf_gen_op_wrapper_py(
- name = "stateless_random_ops",
- out = "gen_stateless_random_ops.py", # cmake chokes without this
- deps = ["//tensorflow/core:stateless_random_ops_op_lib"],
-)
-
py_library(
name = "stateless",
srcs = ["__init__.py"],
srcs_version = "PY2AND3",
deps = [
- ":stateless_random_ops",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
"//tensorflow/python:util",
],
)
diff --git a/tensorflow/contrib/stateless/__init__.py b/tensorflow/contrib/stateless/__init__.py
index 0cca40f071..fe23fe0dd8 100644
--- a/tensorflow/contrib/stateless/__init__.py
+++ b/tensorflow/contrib/stateless/__init__.py
@@ -32,10 +32,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
+
# pylint: disable=wildcard-import
-from tensorflow.contrib.stateless.gen_stateless_random_ops import *
+from tensorflow.python.ops.gen_stateless_random_ops import *
-from tensorflow.python.framework import ops
from tensorflow.python.util.all_util import remove_undocumented
ops.NotDifferentiable("StatelessMultinomial")
diff --git a/tensorflow/contrib/timeseries/python/timeseries/head_test.py b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
index 647455ae42..04d17bc123 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/head_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/head_test.py
@@ -104,7 +104,7 @@ class EvaluationMetricsTests(test.TestCase):
"ticker":
array_ops.reshape(
math_ops.cast(
- variables.Variable(
+ variables.VariableV1(
name="ticker",
initial_value=0,
dtype=dtypes.int64,
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index e9aa037634..10ed1c2891 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -133,6 +133,12 @@ tf_custom_op_library(
tf_gen_op_wrapper_py(
name = "tpu_ops",
+ hidden = [
+ "SendTPUEmbeddingGradients",
+ "EnqueueTPUEmbeddingIntegerBatch",
+ "EnqueueTPUEmbeddingSparseBatch",
+ "EnqueueTPUEmbeddingSparseTensorBatch",
+ ],
deps = [
":cross_replica_ops_op_lib",
":heartbeat_ops_op_lib",
diff --git a/tensorflow/contrib/tpu/__init__.py b/tensorflow/contrib/tpu/__init__.py
index 766466968a..6ce6b779a2 100644
--- a/tensorflow/contrib/tpu/__init__.py
+++ b/tensorflow/contrib/tpu/__init__.py
@@ -55,7 +55,9 @@
@@TPUDistributionStrategy
@@keras_to_tpu_model
+
@@AsyncCheckpointSaverHook
+@@TPUInMemoryEvalHook
"""
from __future__ import absolute_import
@@ -65,6 +67,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.contrib.tpu.python import profiler
from tensorflow.contrib.tpu.python.ops.tpu_ops import *
+from tensorflow.contrib.tpu.python.tpu.async_checkpoint import *
from tensorflow.contrib.tpu.python.tpu.bfloat16 import *
from tensorflow.contrib.tpu.python.tpu.device_assignment import *
from tensorflow.contrib.tpu.python.tpu.keras_support import tpu_model as keras_to_tpu_model
diff --git a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
index 1bd1a31e11..0ef29bdf73 100644
--- a/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
+++ b/tensorflow/contrib/tpu/ops/tpu_embedding_ops.cc
@@ -46,7 +46,7 @@ namespace tensorflow {
// 5. TPUEmbeddingActivations, when used with appropriate Python libraries,
// enables the automatic differentiation of models that use embeddings.
// 6. TPUEmbeddingSendGradients takes a list of Tensors (of the same shapes
-// as those returned by TPUEmbeddingReceivActivations) containing gradients
+// as those returned by TPUEmbeddingReceiveActivations) containing gradients
// to use in updating the embedding tables.
// 7. Before saving a checkpoint, use the TPUEmbeddingRetrieve Op to update
// the Graph's embedding table Variables from the updated tables in the
@@ -147,7 +147,7 @@ parameters that are loaded from a checkpoint before a training loop is
executed.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ TPUEmbeddingConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
table_id: Index of this table in the EmbeddingLayerConfiguration proto
@@ -283,7 +283,7 @@ the correct embedding table configuration. For example, this op is
used to retrieve updated parameters before saving a checkpoint.
%s
table_name: Name of this table; must match a name in the
- EmbeddingLayerConfiguration proto (overrides table_id).
+ TPUEmbeddingConfiguration proto (overrides table_id).
num_shards: Number of shards into which the embedding tables are divided.
shard_id: Identifier of shard for this operation.
table_id: Index of this table in the EmbeddingLayerConfiguration proto
@@ -335,7 +335,6 @@ void RegisterPerTableLoadAndRetrieveOps() {
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
- // TODO(gkurian): Condition this on being used internally within Google.
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return RegisterPerTableLoadOpsForAlgorithmBody(alg, true,
@@ -353,7 +352,6 @@ void RegisterPerTableLoadAndRetrieveOps() {
tpu::GradientAccumulationSupport grad_accum_support;
TF_CHECK_OK(GetGradientAccumulationSupport(alg, &grad_accum_support));
if (grad_accum_support == tpu::GradientAccumulationSupport::kSupported) {
- // TODO(gkurian): Condition this on being used internally within Google.
OpRegistry::Global()->Register(
[alg](OpRegistrationData* op_reg_data) -> Status {
return RegisterPerTableRetrieveOpsForAlgorithmBody(alg, true,
@@ -366,7 +364,7 @@ void RegisterPerTableLoadAndRetrieveOps() {
} // namespace
REGISTER_OP("RecvTPUEmbeddingActivations")
- .Output("outputs: num_outputs * float")
+ .Output("outputs: num_outputs * float32")
.Attr("num_outputs: int >= 1")
.Attr("config: string")
.SetIsStateful()
@@ -395,11 +393,11 @@ REGISTER_OP("RecvTPUEmbeddingActivations")
An op that receives embedding activations on the TPU.
The TPU system performs the embedding lookups and aggregations specified by
-the arguments to TPUEmbeddingEnqueueSparseBatch. The results of these
-aggregations are visible to the Tensorflow Graph as the outputs of a
-TPUEmbeddingDequeueActivations Op. This op returns a list containing one
-Tensor of activations per table specified in the model. There can be at most
-one ReceieveActivations op in the TPU graph.
+the arguments to TPUEmbeddingEnqueue(Integer/Sparse/SparseTensor)Batch. The
+results of these aggregations are visible to the Tensorflow Graph as the
+outputs of a RecvTPUEmbeddingActivations op. This op returns a list containing
+one Tensor of activations per table specified in the model. There can be at
+most one RecvTPUEmbeddingActivations op in the TPU graph.
outputs: A TensorList of embedding activations containing one Tensor per
embedding table in the model.
@@ -437,10 +435,25 @@ lookup_id: Identifier of the set of embedding indices which produced these
REGISTER_OP("SendTPUEmbeddingGradients")
.Input("inputs: N * float32")
+ .Input("learning_rates: NN * float32")
.Attr("N: int >= 1")
+ .Attr("NN: int >= 0 = 0")
.Attr("config: string")
.SetIsStateful()
- .SetShapeFn(shape_inference::UnknownShape)
+ .SetShapeFn([](shape_inference::InferenceContext* c) -> Status {
+ int nn;
+ TF_RETURN_IF_ERROR(c->GetAttr("NN", &nn));
+ std::vector<shape_inference::ShapeHandle> learning_rates;
+ TF_RETURN_IF_ERROR(c->input("learning_rates", &learning_rates));
+ for (int i = 0; i < nn; ++i) {
+ // Verify that each learning_rates element is scalar
+ shape_inference::ShapeHandle learning_rates_shape;
+ TF_RETURN_IF_ERROR(
+ c->WithRank(learning_rates[i], 0, &learning_rates_shape));
+ }
+
+ return Status::OK();
+ })
.Doc(R"doc(
An op that performs gradient updates of embedding tables.
@@ -451,12 +464,18 @@ from these gradients via the optimizer specified in the configuration given
to tpu.initialize_system.
inputs: A TensorList of gradients with which to update embedding tables.
+ It contains one tensor per embedding table in the model.
+learning_rates: A list of float32 scalars, one for each embedding table,
+ containing the learning rates for each table when dynamic learning rate is
+ enabled through the OptimizationParameters in TPUEmbeddingConfiguration.
+ When the learning rate is constant, the list should be empty.
config: Serialized TPUEmbeddingConfiguration proto.
)doc");
REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
.Input("batch: N * int32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.SetIsStateful()
.SetShapeFn(shape_inference::UnknownShape)
@@ -464,17 +483,21 @@ REGISTER_OP("EnqueueTPUEmbeddingIntegerBatch")
An op that enqueues a list of input batch tensors to TPUEmbedding.
batch: A list of 1D tensors, one for each embedding table, containing the
-batch inputs represented as integers.
-device_ordinal: The TPU device to use. This should be -1 when the Op
-is running on a TPU device, and >= 0 when the Op is running on the CPU
-device.
+ indices into the tables.
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
)doc");
REGISTER_OP("EnqueueTPUEmbeddingSparseBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
.SetIsStateful()
@@ -497,32 +520,41 @@ An op that enqueues TPUEmbedding input indices from a SparseTensor.
This Op eases the porting of code that uses embedding_lookup_sparse(),
although some Python preprocessing of the SparseTensor arguments to
embedding_lookup_sparse() is required to produce the arguments to this Op,
-since only a single EnqueueTPUEmbedding Op is allowed per training step.
+since only a single EnqueueTPUEmbeddingSparseBatch Op is allowed per training
+step.
The tensors at corresponding positions in the three input lists
must have the same shape, i.e. rank 1 with dim_size() equal to the total
number of lookups into the table described by the corresponding table_id.
-sample_indices: A list of Rank 1 Tensors specifying the training example and
+sample_indices: A list of rank 1 Tensors specifying the training example and
feature to which the corresponding embedding_indices and aggregation_weights
values belong. sample_indices[i] must equal b * nf + f, where nf is the
number of features from the corresponding table, f is in [0, nf), and
- b is in [0, training batch size).
-embedding_indices: A list of Rank 1 Tensors, indices into the embedding tables.
-aggregation_weights: A list of Rank 1 Tensors containing per sample -- i.e. per
+ b is in [0, batch size).
+embedding_indices: A list of rank 1 Tensors, indices into the embedding tables.
+aggregation_weights: A list of rank 1 Tensors containing per sample -- i.e. per
(training example, feature) -- aggregation weights.
-device_ordinal: The TPU device to use. This should be -1 when the Op
-is running on a TPU device, and >= 0 when the Op is running on the CPU
-device.
-combiners: A list of string scalars whose values are 'mean', 'sum', or 'sqrtn'
-to specify how to normalize the embedding activations after weighted summation.
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
+combiners: A list of string scalars, one for each embedding table that specify
+ how to normalize the embedding activations after weighted summation.
+ Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+ the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+ 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+ all tables.
)doc");
REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
.Input("sample_indices: N * int32")
.Input("embedding_indices: N * int32")
.Input("aggregation_weights: N * float32")
- .Attr("N: int")
+ .Input("mode_override: string")
+ .Attr("N: int >= 1")
.Attr("device_ordinal: int = -1")
.Attr("combiners: list(string) = []")
.Attr("table_ids: list(int)")
@@ -532,25 +564,39 @@ REGISTER_OP("EnqueueTPUEmbeddingSparseTensorBatch")
This Op eases the porting of code that uses tf.nn.embedding_lookup_sparse().
sample_indices[i], embedding_indices[i] and aggregation_weights[i] correspond
-to ith feature. table_ids[i] indicates which embedding table to look up ith
+to the ith feature. table_ids[i] indicates which embedding table to look up ith
feature.
-sample_indices: A list of Rank 1 Tensors, corresponds to sp_ids.indices[:,0] in
-embedding_lookup_sparse().
-embedding_indices: A list of Rank 1 Tensors, corresponds to sp_ids.values
- in embedding_lookup_sparse().
-aggregation_weights: A list of Rank 1 Tensors, corresponds to sp_weights.values
- in embedding_lookup_sparse().
-device_ordinal: The TPU device to use. This should be -1 when the Op
-is running on a TPU device, and >= 0 when the Op is running on the CPU
-device.
-combiners: A list of strings, one for each embedding table, specifying the
-reduction operation. Currently, 'sum', 'mean' and 'sqrtn' are supported. It is
-invalid to have the sum of the weights be 0 for 'mean' or the sum of the squared
-weights be 0 for 'sqrtn'. If combiners isn't passed, the default is to
-use 'sum' for all tables.
-table_ids: A list of int. table_ids[i] indicates which embedding table to look
-up ith feature.
+The tensors at corresponding positions in the three input lists (sample_indices,
+embedding_indices and aggregation_weights) must have the same shape, i.e. rank 1
+with dim_size() equal to the total number of lookups into the table described by
+the corresponding feature.
+
+sample_indices: A list of rank 1 Tensors specifying the training example to
+ which the corresponding embedding_indices and aggregation_weights values
+ belong. It corresponds to sp_ids.indices[:,0] in embedding_lookup_sparse().
+embedding_indices: A list of rank 1 Tensors, indices into the embedding tables.
+ It corresponds to sp_ids.values in embedding_lookup_sparse().
+aggregation_weights: A list of rank 1 Tensors containing per training example
+ aggregation weights. It corresponds to sp_weights.values in
+ embedding_lookup_sparse().
+mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified', 'inference',
+ 'training', 'backward_pass_only'}. When set to 'unspecified', the mode set
+ in TPUEmbeddingConfiguration is used, otherwise mode_override is used.
+device_ordinal: The TPU device to use. Should be >= 0 and less than the number
+ of TPU cores in the task on which the node is placed.
+combiners: A list of string scalars, one for each embedding table that specify
+ how to normalize the embedding activations after weighted summation.
+ Supported combiners are 'mean', 'sum', or 'sqrtn'. It is invalid to have
+ the sum of the weights be 0 for 'mean' or the sum of the squared weights be
+ 0 for 'sqrtn'. If combiners isn't passed, the default is to use 'sum' for
+ all tables.
+table_ids: A list of integers specifying the identifier of the embedding table
+ (offset of TableDescriptor in the TPUEmbeddingConfiguration) to lookup the
+ corresponding input. The ith input is looked up using table_ids[i]. The size
+ of the table_ids list must be equal to that of sample_indices,
+ embedding_indices and aggregation_weights.
)doc");
} // namespace tensorflow
diff --git a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
index b498599962..1c5ea2d997 100644
--- a/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
+++ b/tensorflow/contrib/tpu/profiler/capture_tpu_profile.cc
@@ -156,8 +156,7 @@ bool NewSession(const string& service_addr,
channel_args));
NewProfileSessionResponse new_session_response;
TF_QCHECK_OK(FromGrpcStatus(
- stub->NewSession(&context, new_session_request, &new_session_response)))
- << new_session_response.error_message();
+ stub->NewSession(&context, new_session_request, &new_session_response)));
std::cout << "Profile session succeed for host(s):"
<< str_util::Join(hostnames, ",") << std::endl;
@@ -238,7 +237,8 @@ void StartMonitoring(const tensorflow::string& service_addr, int duration_ms,
MonitorResponse response;
TF_QCHECK_OK(FromGrpcStatus(stub->Monitor(&context, request, &response)));
- std::cout << "Xprof Monitoring Results (Sample " << query + 1 << "):\n\n"
+ std::cout << "Cloud TPU Monitoring Results (Sample " << query + 1
+ << "):\n\n"
<< response.data() << std::flush;
}
}
diff --git a/tensorflow/contrib/tpu/profiler/op_profile.proto b/tensorflow/contrib/tpu/profiler/op_profile.proto
index b25d06dda8..292108f949 100644
--- a/tensorflow/contrib/tpu/profiler/op_profile.proto
+++ b/tensorflow/contrib/tpu/profiler/op_profile.proto
@@ -66,8 +66,8 @@ message Metrics {
// - it does not reveal the peak core FLOPS of the hardware
double flops = 2;
- // The VMEM bandwidth used to load operands from HBM, as a fraction of
- // thereotical VMEM bandwidth on the specific hardware.
+ // The memory bandwidth used to load operands, as a fraction of
+ // thereotical memory bandwidth on the specific hardware.
double memory_bandwidth = 3;
double raw_time = 11; // Elapsed core-time in picoseconds.
diff --git a/tensorflow/contrib/tpu/profiler/pip_package/setup.py b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
index 2415c46718..f27ae38e04 100644
--- a/tensorflow/contrib/tpu/profiler/pip_package/setup.py
+++ b/tensorflow/contrib/tpu/profiler/pip_package/setup.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from setuptools import setup
-_VERSION = '1.11.0'
+_VERSION = '1.12.0'
CONSOLE_SCRIPTS = [
'capture_tpu_profile=cloud_tpu_profiler.main:run_main',
diff --git a/tensorflow/contrib/tpu/profiler/version.h b/tensorflow/contrib/tpu/profiler/version.h
index 90d34b5ef1..4b6d1b2b07 100644
--- a/tensorflow/contrib/tpu/profiler/version.h
+++ b/tensorflow/contrib/tpu/profiler/version.h
@@ -16,6 +16,6 @@ limitations under the License.
#ifndef TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
#define TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
-#define TPU_PROFILER_VERSION "1.11.0"
+#define TPU_PROFILER_VERSION "1.12.0"
#endif // TENSORFLOW_CONTRIB_TPU_PROFILER_VERSION_H_
diff --git a/tensorflow/contrib/tpu/proto/optimization_parameters.proto b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
index fc1320501b..8529b48c15 100644
--- a/tensorflow/contrib/tpu/proto/optimization_parameters.proto
+++ b/tensorflow/contrib/tpu/proto/optimization_parameters.proto
@@ -22,13 +22,22 @@ message LearningRate {
}
}
+// Each optimizer's parameter proto has a link to its documentation and CPU
+// implementation (if available) for user reference.
+
+// https://www.tensorflow.org/api_docs/python/tf/train/AdagradOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L151
message AdagradParameters {
float initial_accumulator = 1;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/GradientDescentOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L423
message StochasticGradientDescentParameters {
}
+// https://www.tensorflow.org/api_docs/python/tf/train/FtrlOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L192
message FtrlParameters {
float l1 = 1;
float l2 = 2;
@@ -41,21 +50,42 @@ message FtrlParameters {
// learning rate feature instead, setting the learning rate to:
// user learning_rate * sqrt(1 - beta2^t) / (1 - beta1^t)
// Here, t is the current timestep.
+//
+// https://www.tensorflow.org/api_docs/python/tf/train/AdamOptimizer
// https://github.com/tensorflow/tensorflow/blob/ab51450c817674c8ff08a7ae4f8ac50cdc4bed8b/tensorflow/python/training/adam.py#L54
+//
+// Note that the code by default implements the lazy version of Adam
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/LazyAdamOptimizer)
+// unless the use_non_lazy_adam parameter is set, in which case it implements
+// the normal version of Adam that updates all parameters in the embedding
+// table, even for entries that are not used in the current minibatch
+// (https://www.tensorflow.org/api_docs/python/tf/contrib/opt/AdamOptimizer). If
+// use_non_lazy_adam is enabled, use_gradient_accumulation is also required in
+// order to get correct results; a warning will be printed otherwise (which may
+// change to an error in the future). If use_max_with_epsilon is set, the Adam
+// variable update formula will be changed from m / (sqrt(v) + epsilon) to
+// m / max(sqrt(v), abs(epsilon)); this option improves the performance of TPU
+// training and is not expected to harm model quality.
message AdamParameters {
float beta1 = 3;
float beta2 = 4;
float epsilon = 5;
float initial_m = 6;
float initial_v = 7;
+ bool use_non_lazy_adam = 8;
+ bool use_max_with_epsilon = 9;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/MomentumOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L271
message MomentumParameters {
float momentum = 1;
bool use_nesterov = 2;
float initial_accum = 3;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L356
message RmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -64,6 +94,8 @@ message RmsPropParameters {
float initial_mom = 5;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L372
message CenteredRmsPropParameters {
float rho = 1;
float momentum = 2;
@@ -73,6 +105,7 @@ message CenteredRmsPropParameters {
float initial_mg = 6;
}
+// Variant of algorithm in http://proceedings.mlr.press/v44/shamir15.pdf
message MdlAdagradLightParameters {
float l2 = 1;
float lr_power = 2;
@@ -91,6 +124,8 @@ message MdlAdagradLightParameters {
float initial_benefit = 15;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L68
message AdadeltaParameters {
float rho = 1;
float epsilon = 2;
@@ -98,6 +133,8 @@ message AdadeltaParameters {
float initial_update = 4;
}
+// https://www.tensorflow.org/api_docs/python/tf/train/RMSPropOptimizer
+// https://github.com/tensorflow/tensorflow/blob/c19e29306ce1777456b2dbb3a14f511edf7883a8/tensorflow/core/kernels/training_ops.cc#L164
message ProximalAdagradParameters {
float l1 = 1;
float l2 = 2;
diff --git a/tensorflow/contrib/tpu/python/ops/tpu_ops.py b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
index a1aee69691..968adccf2b 100644
--- a/tensorflow/contrib/tpu/python/ops/tpu_ops.py
+++ b/tensorflow/contrib/tpu/python/ops/tpu_ops.py
@@ -200,6 +200,181 @@ if platform.system() != "Windows":
return gen_tpu_ops.infeed_dequeue_tuple(dtypes, shapes, name=name)
# pylint: enable=redefined-outer-name
+ # pylint: disable=protected-access
+ def send_tpu_embedding_gradients(inputs,
+ config,
+ learning_rates=None,
+ name=None):
+ """A placeholder op for feeding per-sample gradients to the embedding layer.
+
+ Args:
+ inputs: A TensorList of gradients with which to update embedding tables.
+ Contains one tensor per embedding table in the model.
+ config: Serialized TPUEmbeddingConfiguration proto.
+ learning_rates: A TensorList of float32 scalars, one for each embedding
+ table, containing the learning rates for each table when dynamic
+ learning rate is enabled through the OptimizationParameters in
+ TPUEmbeddingConfiguration. When the learning rate is constant, the list
+ should be empty (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ A SendTPUEmbeddingGradients operation.
+ """
+ if learning_rates is None:
+ learning_rates = []
+ return gen_tpu_ops._send_tpu_embedding_gradients(
+ inputs=inputs, learning_rates=learning_rates, config=config, name=name)
+
+
+ send_tpu_embedding_gradients.__doc__ = (
+ gen_tpu_ops._send_tpu_embedding_gradients.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_integer_batch(batch,
+ device_ordinal,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ batch: A list of 1D tensors, one for each embedding table, containing the
+ indices into the tables.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingIntegerBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_integer_batch(
+ batch=batch,
+ device_ordinal=device_ordinal,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_integer_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_integer_batch.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_sparse_batch(sample_indices,
+ embedding_indices,
+ aggregation_weights,
+ device_ordinal,
+ combiners=None,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ sample_indices: A list of rank 1 Tensors specifying the training example
+ and feature to which the corresponding embedding_indices and
+ aggregation_weights values belong. sample_indices[i] must equal b * nf +
+ f, where nf is the number of features from the corresponding table, f is
+ in [0, nf), and b is in [0, batch size).
+ embedding_indices: A list of rank 1 Tensors, indices into the embedding
+ tables.
+ aggregation_weights: A list of rank 1 Tensors containing per sample --
+ i.e. per (training example, feature) -- aggregation weights.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ combiners: A list of string scalars, one for each embedding table that
+ specify how to normalize the embedding activations after weighted
+ summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+ invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+ squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+ is to use 'sum' for all tables (optional).
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingSparseBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_sparse_batch(
+ sample_indices=sample_indices,
+ embedding_indices=embedding_indices,
+ aggregation_weights=aggregation_weights,
+ device_ordinal=device_ordinal,
+ combiners=combiners,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_sparse_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_sparse_batch.__doc__)
+
+ # pylint: disable=protected-access
+ def enqueue_tpu_embedding_sparse_tensor_batch(sample_indices,
+ embedding_indices,
+ aggregation_weights,
+ table_ids,
+ device_ordinal,
+ combiners=None,
+ mode_override=None,
+ name=None):
+ """A placeholder op for enqueueing embedding IDs to the TPU.
+
+ Args:
+ sample_indices: A list of rank 1 Tensors specifying the training example
+ to which the corresponding embedding_indices and aggregation_weights
+ values
+ belong. It corresponds to sp_ids.indices[:,0] in
+ embedding_lookup_sparse().
+ embedding_indices: A list of rank 1 Tensors, indices into the embedding
+ tables. It corresponds to sp_ids.values in embedding_lookup_sparse().
+ aggregation_weights: A list of rank 1 Tensors containing per training
+ example aggregation weights. It corresponds to sp_weights.values in
+ embedding_lookup_sparse().
+ table_ids: A list of integers specifying the identifier of the embedding
+ table (offset of TableDescriptor in the TPUEmbeddingConfiguration) to
+ lookup the corresponding input. The ith input is looked up using
+ table_ids[i]. The size of the table_ids list must be equal to that of
+ sample_indices, embedding_indices and aggregation_weights.
+ device_ordinal: The TPU device to use. Should be >= 0 and less than the
+ number of TPU cores in the task on which the node is placed.
+ combiners: A list of string scalars, one for each embedding table that
+ specify how to normalize the embedding activations after weighted
+ summation. Supported combiners are 'mean', 'sum', or 'sqrtn'. It is
+ invalid to have the sum of the weights be 0 for 'mean' or the sum of the
+ squared weights be 0 for 'sqrtn'. If combiners isn't passed, the default
+ is to use 'sum' for all tables (optional).
+ mode_override: A string input that overrides the mode specified in the
+ TPUEmbeddingConfiguration. Supported values are {'unspecified',
+ 'inference', 'training', 'backward_pass_only'}. When set to
+ 'unspecified', the mode set in TPUEmbeddingConfiguration is used,
+ otherwise mode_override is used (optional).
+ name: A name for the operation (optional).
+
+ Returns:
+ An EnqueueTPUEmbeddingSparseTensorBatch operation.
+ """
+ if mode_override is None:
+ mode_override = "unspecified"
+ return gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch(
+ sample_indices=sample_indices,
+ embedding_indices=embedding_indices,
+ aggregation_weights=aggregation_weights,
+ table_ids=table_ids,
+ device_ordinal=device_ordinal,
+ combiners=combiners,
+ mode_override=mode_override,
+ name=name)
+
+ enqueue_tpu_embedding_sparse_tensor_batch.__doc__ = (
+ gen_tpu_ops._enqueue_tpu_embedding_sparse_tensor_batch.__doc__)
+
else:
# We have already built the appropriate libraries into the binary via CMake
# if we have built contrib, so we don't need this
diff --git a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
index e06a720e82..20b7ba0997 100644
--- a/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
+++ b/tensorflow/contrib/tpu/python/tpu/async_checkpoint.py
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ======================================
-
"""Hook for asynchronous checkpointing.
This hook dispatches checkpoint writing operations in a separate thread to
@@ -28,18 +27,16 @@ import threading
import time
from tensorflow.core.util.event_pb2 import SessionLog
-
from tensorflow.python.framework import meta_graph
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import basic_session_run_hooks
-from tensorflow.python.training import session_run_hook
from tensorflow.python.training import training_util
from tensorflow.python.training.session_run_hook import SessionRunArgs
from tensorflow.python.training.summary_io import SummaryWriterCache
-class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
+class AsyncCheckpointSaverHook(basic_session_run_hooks.CheckpointSaverHook):
"""Saves checkpoints every N steps or seconds."""
def __init__(self,
@@ -67,7 +64,7 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
ValueError: One of `save_steps` or `save_secs` should be set.
ValueError: At most one of `saver` or `scaffold` should be set.
"""
- logging.info("Create CheckpointSaverHook.")
+ logging.info("Create AsyncCheckpointSaverHook.")
if saver is not None and scaffold is not None:
raise ValueError("You cannot provide both saver and scaffold.")
self._saver = saver
@@ -144,6 +141,10 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
def _save(self, session, step, asynchronous=True):
"""Saves the latest checkpoint, returns should_stop."""
+ # Skip saving on step 0
+ if step == 0:
+ return
+
def _save_fn():
"""Run the saver process."""
logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
@@ -162,7 +163,6 @@ class AsyncCheckpointSaverHook(session_run_hook.SessionRunHook):
end_time - start_time)
logging.info("Checkpoint finished for %d into %s.", step, self._save_path)
- logging.info("Saving checkpoints for %d into %s.", step, self._save_path)
for l in self._listeners:
l.before_save(session, step)
diff --git a/tensorflow/contrib/tpu/python/tpu/datasets.py b/tensorflow/contrib/tpu/python/tpu/datasets.py
index d879170b68..c694e9c1bc 100644
--- a/tensorflow/contrib/tpu/python/tpu/datasets.py
+++ b/tensorflow/contrib/tpu/python/tpu/datasets.py
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
index 956d0142a3..a3a7fd8bb0 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_support.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -46,6 +46,7 @@ from __future__ import print_function
import abc
import collections
+import contextlib
import re
import sys
import time
@@ -94,21 +95,56 @@ from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
+# TODO(b/114775106): temporary shim to optionally initialize the TPU
+# This increases the odds our session is initialized, but shouldn't be needed.
+def _maybe_initialize_tpu(session):
+ """Initialize the TPU if it has not already been initialized."""
+ try:
+
+ def test_op():
+ return constant_op.constant(1) + constant_op.constant(1)
+
+ session.run(tpu.rewrite(test_op))
+ except errors.FailedPreconditionError as _:
+ session.run(tpu.initialize_system())
+
+
+@contextlib.contextmanager
+def _tpu_session_context():
+ """Initialize the TPU and cleans cache entries for bad sessions."""
+ try:
+ _maybe_initialize_tpu(K.get_session())
+ yield
+ except (errors.FailedPreconditionError, errors.AbortedError) as e:
+ K.clear_session()
+ raise Exception("""
+An error occurred connecting or initializing your TPU.
+
+The session has been reset. re-run keras_to_tpu_model to create a new session.
+""" + e)
+
+
def setup_tpu_session(cluster_resolver):
"""Construct or return a `tf.Session` connected to the given cluster."""
master = cluster_resolver.master()
# Use the existing session if we're already connected to this TPU
- if (K.get_session()._target == master and
- getattr(K.get_session(), '_tpu_initialized', None)):
- return
+ # N.B K.get_session() is a non-trivial operation, and may fail if the remote
+ # session has been reset.
+ try:
+ default_session = K.get_session()
+ if (default_session._target == master and
+ getattr(default_session, '_tpu_initialized', None)):
+ return
+ except errors.AbortedError as _:
+ # We lost the remote session and need to re-initialize.
+ logging.warning('Lost remote session: creating a new session.')
cluster_spec = cluster_resolver.cluster_spec()
config = config_pb2.ConfigProto(isolate_session_state=True)
if cluster_spec:
config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())
- logging.info('Initialize')
tpu_session = tf_session.Session(target=master, config=config)
tpu_session.run(tpu.initialize_system())
tpu_session._tpu_initialized = True
@@ -959,7 +995,16 @@ class TPUFunction(object):
# Compute our outfeed depending on the execution mode
if is_training:
- self._cloned_model._make_train_function()
+ if not isinstance(self._cloned_optimizer, keras_optimizers.TFOptimizer):
+ # For Keras optimizer, we try to place the variable weights on the TPU
+ # device. Keras creates optimizer variables (e.g. momentum values for
+ # the Momentum optimizer) when _make_train_function is invoked.
+ with keras_tpu_variables.replicated_variable_for_optimizer(
+ self._tpu_assignment.num_towers):
+ self._cloned_model._make_train_function()
+ else:
+ self._cloned_model._make_train_function()
+
self._outfeed_spec = [
tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
for tensor in self._cloned_model.train_function.outputs
@@ -1382,97 +1427,74 @@ class KerasTPUModel(models.Model):
raise EnvironmentError('KerasTPUModel currently does not support eager '
'mode.')
- assert not self._numpy_to_infeed_manager_list # Ensure empty.
-
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- with ops.device('/job:%s/device:CPU:0' %
- self._tpu_assignment.worker_name):
- dataset = x()
- if steps_per_epoch is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps_per_epoch argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ with _tpu_session_context():
+ assert not self._numpy_to_infeed_manager_list # Ensure empty.
+
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ with ops.device(
+ '/job:%s/device:CPU:0' % self._tpu_assignment.worker_name):
+ dataset = x()
+ if steps_per_epoch is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps_per_epoch argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must '
+ 'be None')
+ infeed_manager = TPUDatasetInfeedManager(
+ dataset, self._tpu_assignment, model_fn_lib.ModeKeys.TRAIN)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ if isinstance(validation_data, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(validation_data):
+ dataset = validation_data()
+ if validation_steps is None:
+ raise ValueError('When using tf.data as validation for a model, you '
+ 'should specify the validation_steps argument.')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
# Use dummy numpy inputs for the rest of Keras' shape checking. We
# intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
+ val_x = infeed_manager.dummy_x
+ val_y = infeed_manager.dummy_y
+ infeed_managers.append((val_x, infeed_manager))
+ validation_data = (val_x, val_y)
- if isinstance(validation_data, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(validation_data):
- dataset = validation_data()
- if validation_steps is None:
- raise ValueError('When using tf.data as validation for a model, you '
- 'should specify the validation_steps argument.')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- val_x = infeed_manager.dummy_x
- val_y = infeed_manager.dummy_y
- infeed_managers.append((val_x, infeed_manager))
- validation_data = (val_x, val_y)
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- if not kwargs.get('_pipeline', True):
- logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
- kwargs['_pipeline'])
- kwargs.pop('_pipeline')
- return super(KerasTPUModel, self).fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- return self._pipeline_fit(
- x,
- y,
- batch_size,
- epochs,
- verbose,
- callbacks,
- validation_split,
- validation_data,
- shuffle,
- class_weight,
- sample_weight,
- initial_epoch,
- steps_per_epoch,
- validation_steps,
- **kwargs)
- finally:
- self._numpy_to_infeed_manager_list = []
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ if not kwargs.get('_pipeline', True):
+ logging.info('Running non-pipelined training loop (`_pipeline=%s`).',
+ kwargs['_pipeline'])
+ kwargs.pop('_pipeline')
+ return super(KerasTPUModel, self).fit(
+ x, y, batch_size, epochs, verbose, callbacks, validation_split,
+ validation_data, shuffle, class_weight, sample_weight,
+ initial_epoch, steps_per_epoch, validation_steps, **kwargs)
+ return self._pipeline_fit(x, y, batch_size, epochs, verbose, callbacks,
+ validation_split, validation_data, shuffle,
+ class_weight, sample_weight, initial_epoch,
+ steps_per_epoch, validation_steps, **kwargs)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def evaluate(self,
x=None,
@@ -1483,37 +1505,38 @@ class KerasTPUModel(models.Model):
steps=None):
assert not self._numpy_to_infeed_manager_list # Ensure empty.
- infeed_managers = [] # Managers to clean up at the end of the fit call.
- if isinstance(x, dataset_ops.Dataset):
- # TODO(b/111413240): Support taking a tf.data.Dataset directly.
- raise ValueError(
- 'Taking a Dataset directly is not yet supported. Please '
- 'wrap your dataset construction code in a function and '
- 'pass that to fit instead. For examples, see: '
- 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
- '/keras')
- if callable(x):
- dataset = x()
- if steps is None:
- raise ValueError('When using tf.data as input to a model, you '
- 'should specify the steps argument.')
- if y is not None:
- raise ValueError('When using tf.data as input to a model, y must be '
- 'None')
- infeed_manager = TPUDatasetInfeedManager(
- dataset, self._tpu_assignment, model_fn_lib.ModeKeys.EVAL)
- # Use dummy numpy inputs for the rest of Keras' shape checking. We
- # intercept them when building the model.
- x = infeed_manager.dummy_x
- y = infeed_manager.dummy_y
- infeed_managers.append((x, infeed_manager))
-
- self._numpy_to_infeed_manager_list = infeed_managers
- try:
- return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
- sample_weight, steps)
- finally:
- self._numpy_to_infeed_manager_list = []
+ with _tpu_session_context():
+ infeed_managers = [] # Managers to clean up at the end of the fit call.
+ if isinstance(x, dataset_ops.Dataset):
+ # TODO(b/111413240): Support taking a tf.data.Dataset directly.
+ raise ValueError(
+ 'Taking a Dataset directly is not yet supported. Please '
+ 'wrap your dataset construction code in a function and '
+ 'pass that to fit instead. For examples, see: '
+ 'https://github.com/tensorflow/tpu/tree/master/models/experimental'
+ '/keras')
+ if callable(x):
+ dataset = x()
+ if steps is None:
+ raise ValueError('When using tf.data as input to a model, you '
+ 'should specify the steps argument.')
+ if y is not None:
+ raise ValueError('When using tf.data as input to a model, y must be '
+ 'None')
+ infeed_manager = TPUDatasetInfeedManager(dataset, self._tpu_assignment,
+ model_fn_lib.ModeKeys.EVAL)
+ # Use dummy numpy inputs for the rest of Keras' shape checking. We
+ # intercept them when building the model.
+ x = infeed_manager.dummy_x
+ y = infeed_manager.dummy_y
+ infeed_managers.append((x, infeed_manager))
+
+ self._numpy_to_infeed_manager_list = infeed_managers
+ try:
+ return super(KerasTPUModel, self).evaluate(x, y, batch_size, verbose,
+ sample_weight, steps)
+ finally:
+ self._numpy_to_infeed_manager_list = []
def _pipeline_fit(self, x, y, batch_size, epochs, verbose, callbacks,
validation_split, validation_data, shuffle, class_weight,
@@ -1901,6 +1924,24 @@ class KerasTPUModel(models.Model):
return val_x, val_y, val_sample_weights
+ def predict(self,
+ x,
+ batch_size=None,
+ verbose=0,
+ steps=None,
+ max_queue_size=10,
+ workers=1,
+ use_multiprocessing=False):
+ with _tpu_session_context():
+ return super(KerasTPUModel, self).predict(
+ x,
+ batch_size=batch_size,
+ verbose=verbose,
+ steps=steps,
+ max_queue_size=max_queue_size,
+ workers=workers,
+ use_multiprocessing=use_multiprocessing)
+
@property
def optimizer(self):
if self._tpu_model:
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
index 170977d8ab..004b1012e5 100644
--- a/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
+++ b/tensorflow/contrib/tpu/python/tpu/keras_tpu_variables.py
@@ -25,10 +25,15 @@ from __future__ import print_function
import contextlib
+import numpy as np
+
from tensorflow.python.client import session as session_lib
+from tensorflow.python.framework import dtypes as dtypes_module
from tensorflow.python.framework import ops
+from tensorflow.python.keras import backend
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_resource_variable_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
@@ -73,7 +78,7 @@ class ReplicatedVariable(object):
if tpu_context is None:
return self._primary_var.handle
- return tpu_context.get_replicated_var_handle(self)
+ return tpu_context.get_replicated_var_handle(self._name, self._vars)
@contextlib.contextmanager
def _assign_dependencies(self):
@@ -285,3 +290,51 @@ def replicated_scope(num_replicas):
return variable_scope.variable_scope(
"", custom_getter=_replicated_variable_getter)
+
+
+@contextlib.contextmanager
+def replicated_variable_for_optimizer(num_replicas):
+ """Context manager for optimizer weights. Overrides K.variable."""
+ if num_replicas == 1:
+ yield
+ return
+
+ try:
+ old_v = backend.variable
+
+ def opt_variable(value, dtype=None, name=None, constraint=None):
+ """Instantiates a variable and returns it."""
+ if dtype is None:
+ dtype = backend.floatx()
+
+ variables = []
+ for i in range(num_replicas):
+ # Keras holds the variables in optimizer class instance , so the name
+ # does not matter here. ResourceVariable constructor will find a unique
+ # name (including name=None) for each replica.
+ with ops.device("device:TPU:{}".format(i)):
+ v = resource_variable_ops.ResourceVariable(
+ value,
+ dtype=dtypes_module.as_dtype(dtype),
+ name=name,
+ constraint=constraint)
+ variables.append(v)
+ name = "replicate_{}_{}".format("variable" if name is None else name,
+ ops.uid())
+ v = ReplicatedVariable(name, variables)
+
+ # pylint: disable=protected-access
+
+ if isinstance(value, np.ndarray):
+ v._keras_shape = value.shape
+ elif hasattr(value, "shape"):
+ v._keras_shape = backend.int_shape(value)
+ v._uses_learning_phase = False
+ backend.track_variable(v)
+ return v
+
+ backend.variable = opt_variable
+ yield
+
+ finally:
+ backend.variable = old_v
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index 883e08bf47..11aaa1c66a 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -155,19 +155,20 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
self._pivot = pivot
self._replicated_vars = {}
- def get_replicated_var_handle(self, var):
+ def get_replicated_var_handle(self, name, vars_):
"""Returns a variable handle for replicated TPU variable 'var'.
This is a method used by an experimental replicated variable implementation
and is not intended as a public API.
Args:
- var: The replicated TPU variable.
+ name: The common name of the variable.
+ vars_: The replicated TPU variables.
Returns:
The handle of the TPU replicated input node.
"""
- handle = self._replicated_vars.get(var)
+ handle = self._replicated_vars.get(name)
if handle is not None:
return handle
@@ -183,10 +184,10 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
saved_context = graph._get_control_flow_context()
graph._set_control_flow_context(self.outer_context)
handle = tpu_ops.tpu_replicated_input(
- [v.handle for v in var._vars], name=var.name + "/handle")
+ [v.handle for v in vars_], name=name + "/handle")
graph._set_control_flow_context(saved_context)
# pylint: enable=protected-access
- self._replicated_vars[var] = handle
+ self._replicated_vars[name] = handle
return handle
def report_unsupported_operations(self):
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_context.py b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
index 7cfb6c38fa..da6bdf67d6 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_context.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_context.py
@@ -154,6 +154,20 @@ class TPUContext(object):
# as far as model is replicated to all cores in the system.
return self._internal_ctx.device_for_replica(replica_id)
+ @property
+ def tpu_host_placement_function(self):
+ """Returns the TPU host place function.
+
+ The place function takes host_id as the input and returns the TF device
+ for the correspoding host.
+ """
+
+ def _placement_function(host_id):
+ """Return the host device given host_id."""
+ return self._internal_ctx.tpu_host_placement_function(host_id=host_id)
+
+ return _placement_function
+
class _InternalTPUContext(object):
"""A context holds immutable states of TPU computation.
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index 23c54511ca..3aa5b6efa1 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -231,7 +231,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
`metric_fn` runs on CPU to generate metrics and `tensors` represents the
`Tensor`s transferred from TPU system to CPU host and passed to `metric_fn`.
To be precise, TPU evaluation expects a slightly different signature from the
- @{tf.estimator.Estimator}. While `EstimatorSpec.eval_metric_ops` expects a
+ `tf.estimator.Estimator`. While `EstimatorSpec.eval_metric_ops` expects a
dict, `TPUEstimatorSpec.eval_metrics` is a tuple of `metric_fn` and `tensors`.
The `tensors` could be a list of `Tensor`s or dict of names to `Tensor`s. The
`tensors` usually specify the model logits, which are transferred back from
@@ -254,7 +254,7 @@ class TPUEstimatorSpec(model_fn_lib._TPUEstimatorSpec): # pylint: disable=prote
sending tensors from TPU to CPU. To reduce the overhead, try reducing the
size of the tensors. The `tensors` are concatenated along their major (batch)
dimension, and so must be >= rank 1. The `host_call` is useful for writing
- summaries with @{tf.contrib.summary.create_file_writer}.
+ summaries with `tf.contrib.summary.create_file_writer`.
"""
def __new__(cls,
@@ -404,12 +404,17 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
self._feed_error = None
self._finished = False
+ self._should_initialize_tpu = True
def begin(self):
logging.info('TPU job name %s', self._master_job)
self._iterations_per_loop_var = _create_or_get_iterations_per_loop()
- self._init_ops = [tpu.initialize_system(job=self._master_job)]
- self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ if self._should_initialize_tpu:
+ self._init_ops = [tpu.initialize_system(job=self._master_job)]
+ self._finalize_ops = [tpu.shutdown_system(job=self._master_job)]
+ else:
+ self._init_ops = []
+ self._finalize_ops = []
summary_writer_init_ops = contrib_summary.summary_writer_initializer_op()
self._init_ops.extend(summary_writer_init_ops)
@@ -421,10 +426,10 @@ class TPUInfeedOutfeedSessionHook(session_run_hook.SessionRunHook):
def _run_infeed(self, queue_ctx, session):
logging.info('Starting infeed thread controller.')
if self._initial_infeed_sleep_secs:
- logging.info('%s thread sleeping for %d seconds.', self._name,
+ logging.info('Infeed thread sleeping for %d seconds.',
self._initial_infeed_sleep_secs)
time.sleep(self._initial_infeed_sleep_secs)
- logging.info('%s thread starting after sleep', self._name)
+ logging.info('Infeed thread starting after sleep')
with self._rendezvous.catch_errors(source='infeed', session=session):
if self._run_infeed_loop_on_coordinator:
diff --git a/tensorflow/contrib/tpu/tpu_estimator.md b/tensorflow/contrib/tpu/tpu_estimator.md
index 639e708169..b6514e19dc 100644
--- a/tensorflow/contrib/tpu/tpu_estimator.md
+++ b/tensorflow/contrib/tpu/tpu_estimator.md
@@ -87,7 +87,7 @@ handle training:
label = tf.cast(features["label"], tf.int32)
return image, label
- dataset = tf.contrib.data.TFRecordDataset(
+ dataset = tf.data.TFRecordDataset(
filename, buffer_size=FLAGS.dataset_reader_buffer_size)
dataset = dataset.map(parser).cache().repeat().batch(batch_size)
images, labels = dataset.make_one_shot_iterator().get_next()
diff --git a/tensorflow/contrib/training/BUILD b/tensorflow/contrib/training/BUILD
index ddf8365d61..00295f57f6 100644
--- a/tensorflow/contrib/training/BUILD
+++ b/tensorflow/contrib/training/BUILD
@@ -295,7 +295,6 @@ py_test(
tags = ["notsan"],
deps = [
":training_py",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:gradients",
@@ -305,6 +304,7 @@ py_test(
"//tensorflow/python:training",
"//tensorflow/python:variables",
"//tensorflow/python/data",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
"//third_party/py/numpy",
],
)
@@ -313,6 +313,5 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
index d9b0511a98..c1657fec7b 100644
--- a/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
+++ b/tensorflow/contrib/training/python/training/tensor_queue_dataset_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.contrib.training.python.training import tensor_queue_dataset as tqd
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index ca247dc56b..0aae29d10c 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -239,7 +239,6 @@ tf_proto_library(
srcs = [],
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":protos_all_proto",
@@ -272,6 +271,12 @@ proto_library(
visibility = ["//visibility:public"],
)
+java_proto_library(
+ name = "example_java_proto",
+ visibility = ["//visibility:public"],
+ deps = [":example_protos"],
+)
+
closure_proto_library(
name = "example_protos_closure",
visibility = ["//visibility:public"],
@@ -1039,6 +1044,7 @@ tf_gen_op_libs(
"dataset_ops",
"decode_proto_ops",
"encode_proto_ops",
+ "experimental_dataset_ops",
"function_ops",
"functional_ops",
"image_ops",
@@ -1169,6 +1175,7 @@ cc_library(
":dataset_ops_op_lib",
":decode_proto_ops_op_lib",
":encode_proto_ops_op_lib",
+ ":experimental_dataset_ops_op_lib",
":function_ops_op_lib",
":functional_ops_op_lib",
":image_ops_op_lib",
@@ -2383,7 +2390,6 @@ tf_proto_library(
srcs = ERROR_CODES_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
provide_cc_alias = True,
)
@@ -2404,7 +2410,6 @@ tf_proto_library(
srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
cc_api_version = 2,
default_header = True,
- java_api_version = 2,
js_api_version = 2,
protodeps = [
":error_codes_proto",
@@ -2484,6 +2489,8 @@ FRAMEWORK_INTERNAL_PUBLIC_HEADERS = [
"framework/op_segment.h",
"framework/rendezvous.h", # only needed for tests
"framework/resource_var.h",
+ "framework/run_handler.h",
+ "framework/run_handler_util.h",
"framework/tensor_reference.h",
"framework/tracking_allocator.h", # only needed for tests
"framework/unique_tensor_references.h",
@@ -2970,6 +2977,7 @@ tf_cuda_library(
":core_cpu_internal",
":device_tracer",
":framework",
+ ":framework_internal",
":graph",
":lib",
":lib_internal",
@@ -4117,6 +4125,19 @@ tf_cc_test(
],
)
+tf_cc_test(
+ name = "framework_run_handler_util_test",
+ size = "small",
+ srcs = ["framework/run_handler_util_test.cc"],
+ linkstatic = tf_kernel_tests_linkstatic(),
+ deps = [
+ ":framework_internal",
+ ":lib",
+ ":test",
+ ":test_main",
+ ],
+)
+
tf_cuda_cc_test(
name = "common_runtime_direct_session_test",
size = "small",
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
new file mode 100644
index 0000000000..fa8fc96bb2
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalAssertNextDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalAssertNextDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
new file mode 100644
index 0000000000..5fd88e7a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalCSVDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalCSVDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
new file mode 100644
index 0000000000..ac1f9719fe
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalDirectedInterleaveDataset.pbtxt
@@ -0,0 +1,21 @@
+op {
+ graph_op_name: "ExperimentalDirectedInterleaveDataset"
+ in_arg {
+ name: "selector_input_dataset"
+ description: <<END
+A dataset of scalar `DT_INT64` elements that determines which of the
+`N` data inputs should produce the next output element.
+END
+ }
+ in_arg {
+ name: "data_input_datasets"
+ description: <<END
+`N` datasets with the same type that will be interleaved according to
+the values of `selector_input_dataset`.
+END
+ }
+ summary: <<END
+A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
new file mode 100644
index 0000000000..66511eff60
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResource.pbtxt
@@ -0,0 +1,58 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResource"
+ in_arg {
+ name: "string_arg"
+ description: <<END
+String argument to the function call.
+END
+ }
+ in_arg {
+ name: "target_device"
+ description: <<END
+Target device to execute the function on.
+END
+ }
+ out_arg {
+ name: "resource"
+ description: <<END
+Handle to the resource created.
+END
+ }
+ attr {
+ name: "shared_name"
+ description: <<END
+If non-empty, this resource will be shared under the given name across
+multiple sessions.
+END
+ }
+ attr {
+ name: "container"
+ description: <<END
+If non-empty, this resource is placed in the given container.
+Otherwise, a default container is used.
+END
+ }
+ attr {
+ name: "f"
+ description: <<END
+Function to be executed.
+END
+ }
+ attr {
+ name: "buffer_size"
+ description: <<END
+Size of the buffer.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Creates a resource that fills up a buffer by making function calls.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
new file mode 100644
index 0000000000..bf4b66b22b
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceGetNext.pbtxt
@@ -0,0 +1,25 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceGetNext"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ out_arg {
+ name: "output"
+ description: <<END
+A list of return values.
+END
+ }
+ attr {
+ name: "output_types"
+ description: <<END
+The type list for the return values.
+END
+ }
+ summary: <<END
+Gets the next element from a FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
new file mode 100644
index 0000000000..729718ddb3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalFunctionBufferingResourceReset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalFunctionBufferingResourceReset"
+ in_arg {
+ name: "function_buffer_resource"
+ description: <<END
+The FunctionBufferingResource handle.
+END
+ }
+ summary: <<END
+Resets the FunctionBufferingResource.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
new file mode 100644
index 0000000000..fe266c111f
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIdentityIndexedDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIdentityIndexedDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
new file mode 100644
index 0000000000..d42546516d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIgnoreErrorsDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIgnoreErrorsDataset"
+ summary: <<END
+Creates a dataset that contains the elements of `input_dataset` ignoring errors.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
new file mode 100644
index 0000000000..e285f87e10
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetGet.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetGet"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
new file mode 100644
index 0000000000..60c32473b5
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIndexedDatasetMaterialize.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalIndexedDatasetMaterialize"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
new file mode 100644
index 0000000000..b72b229e9a
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalIteratorGetDevice.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalIteratorGetDevice"
+ summary: <<END
+Returns the name of the device on which `resource` has been placed.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
new file mode 100644
index 0000000000..b38b23a51d
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalLMDBDataset.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalLMDBDataset"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
new file mode 100644
index 0000000000..9676b9d284
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalMaterializedIndexDatasetHandle.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "ExperimentalMaterializedIndexDatasetHandle"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
new file mode 100644
index 0000000000..d73b5bfda3
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolDataset.pbtxt
@@ -0,0 +1,13 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolDataset"
+ in_arg {
+ name: "thread_pool"
+ description: <<END
+A resource produced by the ThreadPoolHandle op.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
new file mode 100644
index 0000000000..48bf93406c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalThreadPoolHandle.pbtxt
@@ -0,0 +1,35 @@
+op {
+ graph_op_name: "ExperimentalThreadPoolHandle"
+ out_arg {
+ name: "handle"
+ description: <<END
+A resource that can be consumed by one or more ExperimentalThreadPoolDataset
+ops.
+END
+ }
+ attr {
+ name: "num_threads"
+ description: <<END
+The number of threads in the thread pool.
+END
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ description: <<END
+The maximum degree of parallelism to use within operations that execute on this
+threadpool.
+END
+ }
+ attr {
+ name: "display_name"
+ description: <<END
+A human-readable name for the threads that may be visible in some
+visualizations.
+threadpool.
+END
+ }
+ summary: <<END
+Creates a dataset that uses a custom thread pool to compute `input_dataset`.
+END
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
new file mode 100644
index 0000000000..68ed797a0c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ExperimentalUniqueDataset.pbtxt
@@ -0,0 +1,8 @@
+op {
+ graph_op_name: "ExperimentalUniqueDataset"
+ summary: <<END
+Creates a dataset that contains the unique elements of `input_dataset`.
+END
+ visibility: HIDDEN
+}
+
diff --git a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
index 40d7d371ca..7142a0e3f2 100644
--- a/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Igamma.pbtxt
@@ -9,7 +9,7 @@ The lower regularized incomplete Gamma function is defined as:
where
-\\(gamma(a, x) = int_{0}^{x} t^{a-1} exp(-t) dt\\)
+\\(gamma(a, x) = \\int_{0}^{x} t^{a-1} exp(-t) dt\\)
is the lower incomplete Gamma function.
diff --git a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
index b17806b338..5020844204 100644
--- a/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
+++ b/tensorflow/core/api_def/python_api/api_def_RegexReplace.pbtxt
@@ -1,10 +1,4 @@
op {
graph_op_name: "RegexReplace"
- endpoint {
- name: "strings.regex_replace"
- }
- endpoint {
- name: "regex_replace"
- deprecated: true
- }
+ visibility: HIDDEN
}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
new file mode 100644
index 0000000000..d3c70190dd
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessMultinomial.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessMultinomial"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
new file mode 100644
index 0000000000..e294325fb8
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
new file mode 100644
index 0000000000..95d414c54a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessRandomUniform.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessRandomUniform"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
new file mode 100644
index 0000000000..c72bdda94a
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_StatelessTruncatedNormal.pbtxt
@@ -0,0 +1,4 @@
+op {
+ graph_op_name: "StatelessTruncatedNormal"
+ visibility: HIDDEN
+}
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 841181f8c3..458e133b68 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -40,6 +40,7 @@ limitations under the License.
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/log_memory.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/run_handler.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -244,6 +245,21 @@ void DirectSession::SchedClosure(thread::ThreadPool* pool,
#endif // __ANDROID__
}
+static RunHandlerPool* GetOrCreateRunHandlerPool(
+ const SessionOptions& options) {
+ static RunHandlerPool* pool =
+ new RunHandlerPool(NumInterOpThreadsFromSessionOptions(options));
+ return pool;
+}
+
+bool DirectSession::ShouldUseRunHandlerPool() const {
+ if (options_.config.session_inter_op_thread_pool_size() > 0 ||
+ options_.config.use_per_session_threads()) {
+ return false;
+ }
+ return true;
+}
+
DirectSession::DirectSession(const SessionOptions& options,
const DeviceMgr* device_mgr,
DirectSessionFactory* const factory)
@@ -582,16 +598,37 @@ Status DirectSession::RunInternal(int64 step_id, const RunOptions& run_options,
}
}
- Executor::Args::Runner default_runner = [this,
- pool](Executor::Args::Closure c) {
- SchedClosure(pool, std::move(c));
- };
+ std::unique_ptr<RunHandler> handler;
+ if (ShouldUseRunHandlerPool() &&
+ run_options.experimental().use_run_handler_pool()) {
+ // Non-null only when a global inter-op pool is used.
+ VLOG(1) << "Using RunHandler to scheduler inter-op closures.";
+ handler = GetOrCreateRunHandlerPool(options_)->Get();
+ }
+ auto* handler_ptr = handler.get();
+
+ Executor::Args::Runner default_runner = nullptr;
+
+ if (pool == nullptr) {
+ default_runner = [](Executor::Args::Closure c) { c(); };
+ } else if (handler_ptr != nullptr) {
+ default_runner = [handler_ptr](Executor::Args::Closure c) {
+ handler_ptr->ScheduleInterOpClosure(std::move(c));
+ };
+ } else {
+ default_runner = [this, pool](Executor::Args::Closure c) {
+ SchedClosure(pool, std::move(c));
+ };
+ }
+
for (const auto& item : executors_and_keys->items) {
- // TODO(zhengxq): support partial run.
- // TODO(zhengxq): if the device picks its own threadpool, we need to assign
+ // TODO(azaks): support partial run.
+ // TODO(azaks): if the device picks its own threadpool, we need to assign
// less threads to the main compute pool by default.
thread::ThreadPool* device_thread_pool =
item.device->tensorflow_device_thread_pool();
+ // TODO(crk): Investigate usage of RunHandlerPool when using device specific
+ // thread pool(s).
if (!device_thread_pool) {
args.runner = default_runner;
} else {
diff --git a/tensorflow/core/common_runtime/direct_session.h b/tensorflow/core/common_runtime/direct_session.h
index 4a6a921ea7..3a168bbe3f 100644
--- a/tensorflow/core/common_runtime/direct_session.h
+++ b/tensorflow/core/common_runtime/direct_session.h
@@ -247,6 +247,9 @@ class DirectSession : public Session {
ExecutorsAndKeys* executors_and_keys,
RunMetadata* run_metadata);
+ // Returns whether inter-op execution uses a global pool.
+ bool ShouldUseRunHandlerPool() const;
+
::tensorflow::Status ExtendLocked(const GraphDef& graph)
EXCLUSIVE_LOCKS_REQUIRED(graph_state_lock_);
diff --git a/tensorflow/core/common_runtime/direct_session_test.cc b/tensorflow/core/common_runtime/direct_session_test.cc
index 65e816c202..e3e431f800 100644
--- a/tensorflow/core/common_runtime/direct_session_test.cc
+++ b/tensorflow/core/common_runtime/direct_session_test.cc
@@ -625,6 +625,34 @@ TEST_F(DirectSessionMinusAXTest, RunSimpleNetworkWithOpts_Callable) {
EXPECT_EQ(run_metadata.step_stats().dev_stats_size(), 2);
}
+TEST_F(DirectSessionMinusAXTest, UseRunHandlerPool) {
+ Initialize({3, 2, -1, 0});
+ auto session = CreateSession();
+ ASSERT_TRUE(session != nullptr);
+ TF_ASSERT_OK(session->Create(def_));
+ std::vector<std::pair<string, Tensor>> inputs;
+
+ // Request two targets: one fetch output and one non-fetched output.
+ std::vector<string> output_names = {y_ + ":0"};
+ std::vector<string> target_nodes = {y_neg_};
+ std::vector<Tensor> outputs;
+
+ // Prepares RunOptions and RunMetadata
+ RunOptions run_options;
+ run_options.mutable_experimental()->set_use_run_handler_pool(true);
+
+ Status s = session->Run(run_options, inputs, output_names, target_nodes,
+ &outputs, nullptr);
+ TF_ASSERT_OK(s);
+
+ ASSERT_EQ(1, outputs.size());
+ // The first output should be initialized and have the correct
+ // output.
+ auto mat = outputs[0].matrix<float>();
+ ASSERT_TRUE(outputs[0].IsInitialized());
+ EXPECT_FLOAT_EQ(5.0, mat(0, 0));
+}
+
TEST(DirectSessionTest, KeepsStateAcrossRunsOfSession) {
GraphDef def;
Graph g(OpRegistry::Global());
diff --git a/tensorflow/core/common_runtime/eager/BUILD b/tensorflow/core/common_runtime/eager/BUILD
index be5f3bae3a..7b74c67c85 100644
--- a/tensorflow/core/common_runtime/eager/BUILD
+++ b/tensorflow/core/common_runtime/eager/BUILD
@@ -147,10 +147,11 @@ tf_cuda_library(
"kernel_and_device.h",
],
visibility = ["//tensorflow:internal"],
- deps = select({
+ deps = [
+ "@farmhash_archive//:farmhash",
+ ] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu_lib",
@@ -219,13 +220,13 @@ tf_cuda_library(
visibility = ["//tensorflow:internal"],
deps = [
":kernel_and_device",
+ "@farmhash_archive//:farmhash",
# Only the TF_AttrType enum is required, so pull in just the C headers.
# TODO(b/113535673): Break this dependency and avoid the C header completely.
"//tensorflow/c:c_api_headers",
] + select({
"//tensorflow:android": [
"//tensorflow/core:android_tensorflow_lib_lite",
- "//util/hash:farmhash_fingerprint",
],
"//conditions:default": [
"//tensorflow/core:core_cpu",
diff --git a/tensorflow/core/common_runtime/lower_if_op.cc b/tensorflow/core/common_runtime/lower_if_op.cc
index dfce7c23e7..a02084f223 100644
--- a/tensorflow/core/common_runtime/lower_if_op.cc
+++ b/tensorflow/core/common_runtime/lower_if_op.cc
@@ -38,11 +38,12 @@ class CondBuilder {
public:
enum Branch { kElseBranch = 0, kThenBranch = 1 };
- // Create a CondBuilder to create the lowering of If op. that has then and
+ // Create a CondBuilder to create the lowered form of `if_op` with then and
// else functions named `then_fn_name` and `else_fn_name` respectively in the
- // given graph.
+ // `graph`. The functions should be available in `flib`.
CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph);
+ const string& else_fn_name, const FunctionLibraryDefinition& flib,
+ Graph* graph);
// Constructs the basic conditional control flow using switch and merge nodes.
Status CreatePivotNodes();
@@ -89,6 +90,7 @@ class CondBuilder {
Node* then_call_node_;
Node* else_call_node_;
Graph* graph_;
+ const FunctionLibraryDefinition& flib_;
string name_;
NodeBuilder then_call_builder_;
@@ -96,9 +98,11 @@ class CondBuilder {
};
CondBuilder::CondBuilder(Node* if_op, const string& then_fn_name,
- const string& else_fn_name, Graph* graph)
+ const string& else_fn_name,
+ const FunctionLibraryDefinition& flib, Graph* graph)
: if_op_(if_op),
graph_(graph),
+ flib_(flib),
name_(if_op->name()),
then_call_builder_(NewName("then"), then_fn_name, graph->op_registry()),
else_call_builder_(NewName("else"), else_fn_name, graph->op_registry()) {
@@ -193,15 +197,15 @@ Status CondBuilder::AddOutputs() {
return Status::OK();
}
-Status InlineCallInGraph(Node* n, Graph* g) {
- const auto& lib = g->flib_def();
- const FunctionDef* fdef = lib.Find(n->type_string());
+Status InlineCallInGraph(Node* n, const FunctionLibraryDefinition& flib,
+ Graph* g) {
+ const FunctionDef* fdef = flib.Find(n->type_string());
CHECK(fdef != nullptr);
FunctionBody* fbody;
TF_RETURN_IF_ERROR(
- FunctionDefToBodyHelper(*fdef, n->attrs(), &lib,
- [&lib](const string& op, const OpDef** sig) {
- return lib.LookUpOpDef(op, sig);
+ FunctionDefToBodyHelper(*fdef, n->attrs(), &flib,
+ [&flib](const string& op, const OpDef** sig) {
+ return flib.LookUpOpDef(op, sig);
},
&fbody));
// TODO(jpienaar): Improve this interface to make the need to delete it
@@ -219,8 +223,8 @@ Status CondBuilder::BuildLoweredIfOutput() {
}
Status CondBuilder::InlineCallNodes() {
- TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, graph_));
- TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(then_call_node_, flib_, graph_));
+ TF_RETURN_IF_ERROR(InlineCallInGraph(else_call_node_, flib_, graph_));
return Status::OK();
}
@@ -240,6 +244,12 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
return errors::Internal("Lowering If op requires a graph to be available.");
}
+ FunctionLibraryDefinition* flib = options.flib_def;
+ if (flib == nullptr) {
+ return errors::Internal(
+ "Lowering If op requires a FunctionLibraryDefinition to be available.");
+ }
+
// Match all the nodes that need to be rewritten.
gtl::InlinedVector<Node*, 2> matches;
for (Node* n : g->op_nodes()) {
@@ -251,12 +261,14 @@ Status LowerIfOpPass::Run(const GraphOptimizationPassOptions& options) {
}
}
for (Node* n : matches) {
- TF_RETURN_IF_ERROR(RewriteNode(n, g));
+ TF_RETURN_IF_ERROR(RewriteNode(n, *flib, g));
}
return Status::OK();
}
-Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
+Status LowerIfOpPass::RewriteNode(Node* n,
+ const FunctionLibraryDefinition& flib,
+ Graph* g) {
const AttrValue* then_attr = n->attrs().Find("then_branch");
if (then_attr == nullptr) {
return errors::InvalidArgument("Then branch function missing");
@@ -266,7 +278,8 @@ Status LowerIfOpPass::RewriteNode(Node* n, Graph* g) {
return errors::InvalidArgument("Else branch function missing");
}
- CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), g);
+ CondBuilder cb(n, then_attr->func().name(), else_attr->func().name(), flib,
+ g);
TF_RETURN_IF_ERROR(cb.CreatePivotNodes());
TF_RETURN_IF_ERROR(cb.AddInputs());
TF_RETURN_IF_ERROR(cb.AddOutputs());
diff --git a/tensorflow/core/common_runtime/lower_if_op.h b/tensorflow/core/common_runtime/lower_if_op.h
index a9ef39ae5c..5ab1123e3f 100644
--- a/tensorflow/core/common_runtime/lower_if_op.h
+++ b/tensorflow/core/common_runtime/lower_if_op.h
@@ -29,8 +29,9 @@ class LowerIfOpPass : public GraphOptimizationPass {
Status Run(const GraphOptimizationPassOptions& options) override;
private:
- // Rewrite the given If node `n` in graph `g` to use the switch-merge form.
- Status RewriteNode(Node* n, Graph* g);
+ // Rewrite the given If node `n` in graph `g` to use the switch-merge
+ // form. `flib` should contain the branch functions referenced by `n`.
+ Status RewriteNode(Node* n, const FunctionLibraryDefinition& flib, Graph* g);
};
} // namespace tensorflow
diff --git a/tensorflow/core/common_runtime/lower_if_op_test.cc b/tensorflow/core/common_runtime/lower_if_op_test.cc
index 319a617b32..044a355d06 100644
--- a/tensorflow/core/common_runtime/lower_if_op_test.cc
+++ b/tensorflow/core/common_runtime/lower_if_op_test.cc
@@ -36,9 +36,7 @@ namespace tensorflow {
namespace {
Status Rewrite(std::unique_ptr<Graph>* graph) {
- FunctionDefLibrary flib;
- FunctionLibraryDefinition flib_def((*graph)->op_registry(), flib);
-
+ FunctionLibraryDefinition flib_def((*graph)->flib_def());
GraphOptimizationPassOptions opt_options;
opt_options.graph = graph;
opt_options.flib_def = &flib_def;
diff --git a/tensorflow/core/framework/dataset.h b/tensorflow/core/framework/dataset.h
index 697e0604bf..8c1151cb56 100644
--- a/tensorflow/core/framework/dataset.h
+++ b/tensorflow/core/framework/dataset.h
@@ -657,15 +657,15 @@ class DatasetBaseIterator : public IteratorBase {
// When performance modeling is enabled, this method adds a tunable parameter
// to the model node corresponding to this iterator.
//
- // The performance modeling logic may use `value` to set the value of the
+ // The performance modeling logic may use `state` to set the value of the
// tunable parameter at any point during the lifetime of this iterator. When
- // it does, it notifies `cond_var`.
+ // it does, it acquires `state->mu` and notifies `state->cond_var`.
void AddTunableParameter(IteratorContext* ctx, const string& name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) {
+ std::shared_ptr<model::SharedState> state, int64 min,
+ int64 max) {
if (ctx->model()) {
- ctx->model()->AddTunableParameter(prefix(), name, value, min, max,
- cond_var);
+ ctx->model()->AddTunableParameter(prefix(), name, std::move(state), min,
+ max);
}
}
diff --git a/tensorflow/core/framework/function.cc b/tensorflow/core/framework/function.cc
index a17959a448..20f957190b 100644
--- a/tensorflow/core/framework/function.cc
+++ b/tensorflow/core/framework/function.cc
@@ -1101,6 +1101,14 @@ Status FunctionLibraryDefinition::ReplaceFunction(const string& func,
return Status::OK();
}
+Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) {
+ mutex_lock l(mu_);
+ bool added;
+ TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name()));
+ TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added));
+ return Status::OK();
+}
+
Status FunctionLibraryDefinition::RemoveFunction(const string& func) {
const auto& i = function_defs_.find(func);
if (i == function_defs_.end()) {
diff --git a/tensorflow/core/framework/function.h b/tensorflow/core/framework/function.h
index e01eb7503d..4d6d68e214 100644
--- a/tensorflow/core/framework/function.h
+++ b/tensorflow/core/framework/function.h
@@ -331,6 +331,11 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// a non-OK status if "func" was not found in the library, OK otherwise.
Status ReplaceFunction(const string& func, const FunctionDef& fdef);
+ // Replaces the gradient corresponding to `grad.function_name()`. Returns
+ // a non-OK status if "grad.function_name()" was not found in the library, OK
+ // otherwise.
+ Status ReplaceGradient(const GradientDef& grad);
+
// Adds the functions and gradients in 'other' to this function library.
// Duplicate functions and gradients are ignored.
// This operation is atomic.
diff --git a/tensorflow/core/framework/model.cc b/tensorflow/core/framework/model.cc
index b0330ec990..bfdb3a6658 100644
--- a/tensorflow/core/framework/model.cc
+++ b/tensorflow/core/framework/model.cc
@@ -296,12 +296,12 @@ void Model::AddProcessingTime(const string& name, int64 delta) {
void Model::AddTunableParameter(const string& node_name,
const string& parameter_name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) {
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) {
tf_shared_lock l(mu_);
auto node = *gtl::FindOrNull(lookup_table_, node_name);
DCHECK(node);
- node->add_tunable_param(parameter_name, value, min, max, cond_var);
+ node->add_tunable_param(parameter_name, std::move(state), min, max);
}
// The optimization algorithm starts by setting all tunable parallelism
@@ -311,54 +311,55 @@ void Model::AddTunableParameter(const string& node_name,
// is less than or equal to the processing time needed to produce an element
// divided by CPU budget.
void Model::Optimize(int64 cpu_budget) {
- tf_shared_lock lock(mu_);
std::vector<std::shared_ptr<Model::Node::Tunable>> tunables;
- const int64 processing_time = ProcessingTime();
- tunables = CollectTunables();
- for (auto tunable : tunables) {
- tunable->value = 1;
- }
- while (true) {
- const int64 output_time = OutputTime();
- bool all_tunables = true;
- for (auto& tunable : tunables) {
- if (tunable->value < tunable->max) {
- all_tunables = false;
+ {
+ tf_shared_lock lock(mu_);
+ const int64 processing_time = ProcessingTime();
+ tunables = CollectTunables();
+ for (auto tunable : tunables) {
+ tunable->value = 1;
+ }
+ while (true) {
+ const int64 output_time = OutputTime();
+ bool all_tunables = true;
+ for (auto& tunable : tunables) {
+ if (tunable->value < tunable->max) {
+ all_tunables = false;
+ break;
+ }
+ }
+ if (output_time < processing_time / cpu_budget || all_tunables) {
break;
}
- }
- if (output_time < processing_time / cpu_budget || all_tunables) {
- break;
- }
- int64 best_delta = -1;
- Model::Node::Tunable* best_tunable = nullptr;
- for (auto& tunable : tunables) {
- if (tunable->value == tunable->max) {
- continue;
+ int64 best_delta = -1;
+ Model::Node::Tunable* best_tunable = nullptr;
+ for (auto& tunable : tunables) {
+ if (tunable->value == tunable->max) {
+ continue;
+ }
+ tunable->value++;
+ int64 delta = output_time - OutputTime();
+ if (delta > best_delta) {
+ best_delta = delta;
+ best_tunable = tunable.get();
+ }
+ tunable->value--;
}
- tunable->value++;
- int64 delta = output_time - OutputTime();
- if (delta > best_delta) {
- best_delta = delta;
- best_tunable = tunable.get();
+ if (!best_tunable) {
+ // NOTE: This can happen because we are performing the optimization
+ // while the model data is changing. If this becomes an issue, we should
+ // look into performing the optimization using a model snapshot.
+ break;
}
- tunable->value--;
+ best_tunable->value++;
}
- if (!best_tunable) {
- // NOTE: This can happen because we are performing the optimization
- // while the model data is changing. If this becomes an issue, we should
- // look into performing the optimization using a model snapshot.
- break;
- }
- best_tunable->value++;
}
VLOG(2) << "Number of knobs: " << tunables.size();
for (auto& tunable : tunables) {
VLOG(2) << "Setting tunable parameter: " << tunable->value;
- tunable->value_ptr->store(tunable->value);
- if (tunable->cond_var) {
- tunable->cond_var->notify_all();
- }
+ mutex_lock l(*tunable->state->mu);
+ tunable->state->value = tunable->value;
+ tunable->state->cond_var->notify_all();
}
}
diff --git a/tensorflow/core/framework/model.h b/tensorflow/core/framework/model.h
index 26402f5cd3..eae0fa70e8 100644
--- a/tensorflow/core/framework/model.h
+++ b/tensorflow/core/framework/model.h
@@ -33,6 +33,19 @@ namespace tensorflow {
namespace data {
namespace model {
+// Represents thread-safe state that can be shared between an input pipeline and
+// the performance model.
+struct SharedState {
+ public:
+ explicit SharedState(int64 value, std::shared_ptr<mutex> mu,
+ std::shared_ptr<condition_variable> cond_var)
+ : value(value), mu(std::move(mu)), cond_var(std::move(cond_var)) {}
+
+ std::shared_ptr<mutex> mu;
+ std::shared_ptr<condition_variable> cond_var;
+ int64 value;
+};
+
// Abstract representation of a TensorFlow input pipeline that can be used
// for collecting runtime information and optimizing performance. It collects
// runtime information about execution of the input pipeline that is used to
@@ -62,8 +75,8 @@ class Model {
// Adds a tunable parameter for the given node.
void AddTunableParameter(const string& node_name,
const string& parameter_name,
- std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var) LOCKS_EXCLUDED(mu_);
+ std::shared_ptr<SharedState> value, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_);
// Runs optimization.
void Optimize(int64 cpu_budget) LOCKS_EXCLUDED(mu_);
@@ -109,13 +122,8 @@ class Model {
public:
// Represents a tunable parameter.
struct Tunable {
- Tunable(std::atomic<int64>* value, int64 min, int64 max,
- condition_variable* cond_var)
- : value(*value),
- min(min),
- max(max),
- value_ptr(value),
- cond_var(cond_var) {}
+ Tunable(std::shared_ptr<SharedState> state, int64 min, int64 max)
+ : value(state->value), min(min), max(max), state(std::move(state)) {}
// Identifies the model value of the parameter. This can be different from
// the actual value (e.g. during optimization search).
@@ -127,12 +135,8 @@ class Model {
// Identifies the maximum value of the parameter.
int64 max;
- // Points to the actual value of the parameter. Not owned.
- std::atomic<int64>* value_ptr;
-
- // If non-null, this condition variable is notified when the model updates
- // the actual value of the parameter (via `value_ptr`). Not owned.
- condition_variable* cond_var;
+ // Shared state of the parameter.
+ std::shared_ptr<SharedState> state;
};
Node(int64 id, const string& name, std::shared_ptr<Node> output)
@@ -158,12 +162,12 @@ class Model {
}
// Adds a tunable parameter.
- void add_tunable_param(const string& name, std::atomic<int64>* value,
- int64 min, int64 max, condition_variable* cond_var)
- LOCKS_EXCLUDED(mu_) {
+ void add_tunable_param(const string& name,
+ std::shared_ptr<SharedState> state, int64 min,
+ int64 max) LOCKS_EXCLUDED(mu_) {
mutex_lock l(mu_);
tunable_params_[name] =
- std::make_shared<Tunable>(value, min, max, cond_var);
+ std::make_shared<Tunable>(std::move(state), min, max);
}
// Returns the unique node ID.
diff --git a/tensorflow/core/framework/node_def_util.h b/tensorflow/core/framework/node_def_util.h
index 187bfa2c88..0ff67554eb 100644
--- a/tensorflow/core/framework/node_def_util.h
+++ b/tensorflow/core/framework/node_def_util.h
@@ -17,7 +17,6 @@ limitations under the License.
#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
#include <string>
-#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/attr_value_util.h"
diff --git a/tensorflow/core/framework/op.h b/tensorflow/core/framework/op.h
index 25f8de8dcc..81ed5f95f0 100644
--- a/tensorflow/core/framework/op.h
+++ b/tensorflow/core/framework/op.h
@@ -209,16 +209,16 @@ template <>
class OpDefBuilderWrapper<true> {
public:
OpDefBuilderWrapper(const char name[]) : builder_(name) {}
- OpDefBuilderWrapper<true>& Attr(StringPiece spec) {
- builder_.Attr(spec);
+ OpDefBuilderWrapper<true>& Attr(string spec) {
+ builder_.Attr(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Input(StringPiece spec) {
- builder_.Input(spec);
+ OpDefBuilderWrapper<true>& Input(string spec) {
+ builder_.Input(std::move(spec));
return *this;
}
- OpDefBuilderWrapper<true>& Output(StringPiece spec) {
- builder_.Output(spec);
+ OpDefBuilderWrapper<true>& Output(string spec) {
+ builder_.Output(std::move(spec));
return *this;
}
OpDefBuilderWrapper<true>& SetIsCommutative() {
@@ -237,12 +237,12 @@ class OpDefBuilderWrapper<true> {
builder_.SetAllowsUninitializedInput();
return *this;
}
- OpDefBuilderWrapper<true>& Deprecated(int version, StringPiece explanation) {
- builder_.Deprecated(version, explanation);
+ OpDefBuilderWrapper<true>& Deprecated(int version, string explanation) {
+ builder_.Deprecated(version, std::move(explanation));
return *this;
}
- OpDefBuilderWrapper<true>& Doc(StringPiece text) {
- builder_.Doc(text);
+ OpDefBuilderWrapper<true>& Doc(string text) {
+ builder_.Doc(std::move(text));
return *this;
}
OpDefBuilderWrapper<true>& SetShapeFn(
diff --git a/tensorflow/core/framework/op_def_builder.cc b/tensorflow/core/framework/op_def_builder.cc
index 34a7a43d38..8a9bb63182 100644
--- a/tensorflow/core/framework/op_def_builder.cc
+++ b/tensorflow/core/framework/op_def_builder.cc
@@ -526,32 +526,32 @@ void FinalizeDoc(const string& text, OpDef* op_def,
} // namespace
-OpDefBuilder::OpDefBuilder(StringPiece op_name) {
- op_def()->set_name(string(op_name)); // NOLINT
+OpDefBuilder::OpDefBuilder(string op_name) {
+ op_def()->set_name(std::move(op_name));
}
-OpDefBuilder& OpDefBuilder::Attr(StringPiece spec) {
- attrs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Attr(string spec) {
+ attrs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Input(StringPiece spec) {
- inputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Input(string spec) {
+ inputs_.push_back(std::move(spec));
return *this;
}
-OpDefBuilder& OpDefBuilder::Output(StringPiece spec) {
- outputs_.emplace_back(spec.data(), spec.size());
+OpDefBuilder& OpDefBuilder::Output(string spec) {
+ outputs_.push_back(std::move(spec));
return *this;
}
#ifndef TF_LEAN_BINARY
-OpDefBuilder& OpDefBuilder::Doc(StringPiece text) {
+OpDefBuilder& OpDefBuilder::Doc(string text) {
if (!doc_.empty()) {
errors_.push_back(
strings::StrCat("Extra call to Doc() for Op ", op_def()->name()));
} else {
- doc_.assign(text.data(), text.size());
+ doc_ = std::move(text);
}
return *this;
}
@@ -577,14 +577,14 @@ OpDefBuilder& OpDefBuilder::SetAllowsUninitializedInput() {
return *this;
}
-OpDefBuilder& OpDefBuilder::Deprecated(int version, StringPiece explanation) {
+OpDefBuilder& OpDefBuilder::Deprecated(int version, string explanation) {
if (op_def()->has_deprecation()) {
errors_.push_back(
strings::StrCat("Deprecated called twice for Op ", op_def()->name()));
} else {
OpDeprecation* deprecation = op_def()->mutable_deprecation();
deprecation->set_version(version);
- deprecation->set_explanation(string(explanation));
+ deprecation->set_explanation(std::move(explanation));
}
return *this;
}
diff --git a/tensorflow/core/framework/op_def_builder.h b/tensorflow/core/framework/op_def_builder.h
index 0b39d6e848..8077b20598 100644
--- a/tensorflow/core/framework/op_def_builder.h
+++ b/tensorflow/core/framework/op_def_builder.h
@@ -51,7 +51,7 @@ struct OpRegistrationData {
class OpDefBuilder {
public:
// Constructs an OpDef with just the name field set.
- explicit OpDefBuilder(StringPiece op_name);
+ explicit OpDefBuilder(string op_name);
// Adds an attr to this OpDefBuilder (and returns *this). The spec has
// format "<name>:<type>" or "<name>:<type>=<default>"
@@ -84,7 +84,7 @@ class OpDefBuilder {
// * Ability to restrict the type of the tensor like the existing
// restrictions for type attrs.
// Perhaps by linking the type of the tensor to a type attr?
- OpDefBuilder& Attr(StringPiece spec);
+ OpDefBuilder& Attr(string spec);
// Adds an input or output to this OpDefBuilder (and returns *this).
// The spec has form "<name>:<type-expr>" or "<name>:Ref(<type-expr>)"
@@ -101,8 +101,8 @@ class OpDefBuilder {
// in the spec?
// TODO(josh11b): SparseInput() and SparseOutput() matching the Python
// handling?
- OpDefBuilder& Input(StringPiece spec);
- OpDefBuilder& Output(StringPiece spec);
+ OpDefBuilder& Input(string spec);
+ OpDefBuilder& Output(string spec);
// Turns on the indicated boolean flag in this OpDefBuilder (and
// returns *this).
@@ -112,7 +112,7 @@ class OpDefBuilder {
OpDefBuilder& SetAllowsUninitializedInput();
// Deprecate the op at a certain GraphDef version.
- OpDefBuilder& Deprecated(int version, StringPiece explanation);
+ OpDefBuilder& Deprecated(int version, string explanation);
// Adds docs to this OpDefBuilder (and returns *this).
// Docs have the format:
@@ -128,9 +128,9 @@ class OpDefBuilder {
// to suppress the automatically-generated type documentation in
// generated output.
#ifndef TF_LEAN_BINARY
- OpDefBuilder& Doc(StringPiece text);
+ OpDefBuilder& Doc(string text);
#else
- OpDefBuilder& Doc(StringPiece text) { return *this; }
+ OpDefBuilder& Doc(string text) { return *this; }
#endif
// Sets the shape function to be used for shape inference.
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index abb6635984..4a531648d9 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -248,10 +248,16 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
ResourceHandle* handle);
// Create a resource pointed by a given resource handle.
+//
+// If successful, the caller transfers the ownership of one ref on `resource` to
+// `ctx->resource_mgr()`.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value);
// Looks up a resource pointed by a given resource handle.
+//
+// If the lookup is successful, the caller takes the ownership of one ref on
+// `*value`, and must call its `Unref()` method when it has finished using it.
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p, T** value);
@@ -262,6 +268,11 @@ Status LookupResources(
std::vector<std::unique_ptr<T, core::RefCountDeleter>>* values);
// Looks up or creates a resource.
+//
+// If successful, the caller takes the ownership of one ref on `*value`, and
+// must call its `Unref()` method when it has finished using it. If the
+// `creator` is invoked, its reference on the created resource is transferred
+// to `ctx->resource_mgr()`.
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator);
diff --git a/tensorflow/core/framework/run_handler.cc b/tensorflow/core/framework/run_handler.cc
new file mode 100644
index 0000000000..0c4007eafc
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.cc
@@ -0,0 +1,249 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#define EIGEN_USE_THREADS
+
+#include "tensorflow/core/framework/run_handler.h"
+
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/framework/run_handler_util.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/util/ptr_util.h"
+
+namespace tensorflow {
+
+// Contains the concrete implementation of the RunHandler.
+// Externally visible RunHandler class simply forwards the work to this one.
+class RunHandler::Impl {
+ public:
+ explicit Impl(RunHandlerPool::Impl* pool_impl) : pool_impl_(pool_impl) {
+ Reset();
+ }
+
+ ~Impl() {}
+
+ void set_inter_op_scheduling_range(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ inter_op_scheduling_range_.store(EncodePartition(start, limit),
+ std::memory_order_release);
+ }
+
+ std::uint_fast32_t inter_op_scheduling_range() const {
+ return inter_op_scheduling_range_.load(std::memory_order_acquire);
+ }
+
+ // Stores now time (in microseconds) since unix epoch when the handler is
+ // requested via RunHandlerPool::Get().
+ uint64 start_time_us() const { return start_time_us_; }
+
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ void Reset();
+
+ RunHandlerPool::Impl* pool_impl() { return pool_impl_; }
+
+ private:
+ // Encoding/decoding logic for storing [start, limit) into a single
+ // uint_fast32_t int. We assume that pool_num_threads < (1 << 16).
+ const int kMaxPartitionBits = 16;
+ const int kMaxThreads = 1 << kMaxPartitionBits;
+
+ std::uint_fast32_t EncodePartition(std::uint_fast32_t start,
+ std::uint_fast32_t limit) {
+ return (start << kMaxPartitionBits) | limit;
+ }
+
+ void DecodePartition(std::uint_fast32_t val, std::uint_fast32_t* start,
+ std::uint_fast32_t* limit) {
+ *limit = val & (kMaxThreads - 1);
+ val >>= kMaxPartitionBits;
+ *start = val;
+ }
+
+ std::atomic_uint_fast32_t inter_op_scheduling_range_;
+ RunHandlerPool::Impl* pool_impl_; // NOT OWNED.
+ uint64 start_time_us_;
+};
+
+// Contains shared state across all run handlers present in the pool. Also
+// responsible for pool management decisions.
+// This class is thread safe.
+class RunHandlerPool::Impl {
+ public:
+ explicit Impl(int num_inter_op_threads)
+ : max_handlers_(128),
+ inter_op_thread_pool_(new thread::ThreadPool(
+ Env::Default(), ThreadOptions(), "inter_op", num_inter_op_threads)),
+ iterations_(0) {
+ VLOG(1) << "Creating a RunHandlerPool with max handlers: " << max_handlers_;
+ for (int i = 0; i < max_handlers_; ++i) {
+ handlers_.emplace_back(new RunHandler::Impl(this));
+ free_handlers_.push_back(handlers_.back().get());
+ }
+ }
+
+ ~Impl() {
+ // Sanity check that all handlers have been returned back to the pool before
+ // destruction.
+ DCHECK_EQ(handlers_.size(), max_handlers_);
+ DCHECK_EQ(free_handlers_.size(), handlers_.size());
+ DCHECK_EQ(sorted_active_handlers_.size(), 0);
+ }
+
+ thread::ThreadPool* inter_op_thread_pool() const {
+ return inter_op_thread_pool_.get();
+ }
+
+ std::unique_ptr<RunHandler> Get() LOCKS_EXCLUDED(mu_) {
+ mutex_lock l(mu_);
+ while (free_handlers_.empty()) {
+ one_handler_free_.wait(l);
+ }
+ // Remove the last entry from free_handlers_ and add to the end of
+ // sorted_active_handlers_.
+ auto* handler_impl = free_handlers_.back();
+ handler_impl->Reset();
+ // Sortedness isn't violated if we simply add at the end of the list, since
+ // handlers are expected to be obtained in increasing order of time.
+ sorted_active_handlers_.push_back(handler_impl);
+ DCHECK_LE(sorted_active_handlers_.size(), max_handlers_);
+ free_handlers_.pop_back();
+
+ RecomputePoolStatsLocked();
+ return WrapUnique<RunHandler>(new RunHandler(handler_impl));
+ }
+
+ void ReleaseHandler(RunHandler::Impl* handler) LOCKS_EXCLUDED(mu_) {
+ {
+ mutex_lock l(mu_);
+ DCHECK_GT(sorted_active_handlers_.size(), 0);
+
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ double elapsed = (now - handler->start_time_us()) / 1000.0;
+ time_hist_.Add(elapsed);
+
+ // Erase from and update sorted_active_handlers_. Add it to the end of
+ // free_handlers_.
+ auto iter = std::find(sorted_active_handlers_.begin(),
+ sorted_active_handlers_.end(), handler);
+ DCHECK(iter != sorted_active_handlers_.end())
+ << "Unexpected handler: " << handler
+ << " is being requested for release";
+
+ // Remove this handler from this list and add it to the list of free
+ // handlers.
+ sorted_active_handlers_.erase(iter);
+ free_handlers_.push_back(handler);
+ DCHECK_LE(free_handlers_.size(), max_handlers_);
+
+ RecomputePoolStatsLocked();
+ }
+ one_handler_free_.notify_one();
+ }
+
+ private:
+ void RecomputePoolStatsLocked() EXCLUSIVE_LOCKS_REQUIRED(mu_);
+
+ // Maximum number of handlers pre-created during pool construction time. The
+ // number has been chosen expecting each handler might at least want 1
+ // inter-op thread for execution (during compute intensive workloads like
+ // inference).
+ const int max_handlers_;
+
+ // Thread safe part.
+ const std::unique_ptr<thread::ThreadPool> inter_op_thread_pool_;
+
+ // Thread compatible part used only by lock under RunHandlerPool.
+ // Handlers are sorted by start time.
+ std::vector<RunHandler::Impl*> sorted_active_handlers_ GUARDED_BY(mu_);
+ std::vector<RunHandler::Impl*> free_handlers_ GUARDED_BY(mu_);
+ std::vector<std::unique_ptr<RunHandler::Impl>> handlers_ GUARDED_BY(mu_);
+ // Histogram of elapsed runtime of every handler (in ms).
+ histogram::Histogram time_hist_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_start_ GUARDED_BY(mu_);
+ std::vector<std::uint_fast32_t> inter_op_limit_ GUARDED_BY(mu_);
+ int64 iterations_ GUARDED_BY(mu_);
+ condition_variable one_handler_free_;
+ mutex mu_;
+};
+
+void RunHandlerPool::Impl::RecomputePoolStatsLocked() {
+ int num_active_requests = sorted_active_handlers_.size();
+ if (num_active_requests == 0) return;
+
+ int num_threads = inter_op_thread_pool_->NumThreads();
+
+ inter_op_start_.resize(num_active_requests);
+ inter_op_limit_.resize(num_active_requests);
+
+ const int kMinThreadsPerRequest = 3;
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ kMinThreadsPerRequest, &inter_op_start_,
+ &inter_op_limit_);
+
+ for (int i = 0; i < num_active_requests; ++i) {
+ sorted_active_handlers_[i]->set_inter_op_scheduling_range(
+ inter_op_start_[i], inter_op_limit_[i]);
+ }
+
+ if (iterations_++ % 5000 == 0 && VLOG_IS_ON(1)) {
+ VLOG(1) << "Printing time histogram: " << time_hist_.ToString();
+ VLOG(1) << "Active session runs: " << num_active_requests;
+ uint64 now = tensorflow::Env::Default()->NowMicros();
+ string ranges_str = "";
+ string times_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) {
+ times_str += " ";
+ ranges_str += " ";
+ }
+
+ times_str += strings::StrCat(
+ (now - sorted_active_handlers_[i]->start_time_us()) / 1000.0, " ms.");
+ ranges_str += strings::StrCat("[", inter_op_start_[i], ", ",
+ inter_op_limit_[i], ")");
+ }
+ VLOG(1) << "Elapsed times are: " << times_str;
+ VLOG(1) << "Ranges are: " << ranges_str;
+ }
+}
+
+void RunHandler::Impl::ScheduleInterOpClosure(std::function<void()> fn) {
+ std::uint_fast32_t start = 0, limit = 0;
+ DecodePartition(inter_op_scheduling_range(), &start, &limit);
+ pool_impl_->inter_op_thread_pool()->Schedule(std::move(fn));
+}
+
+void RunHandler::Impl::Reset() {
+ set_inter_op_scheduling_range(
+ 0, pool_impl_->inter_op_thread_pool()->NumThreads());
+ start_time_us_ = tensorflow::Env::Default()->NowMicros();
+}
+
+RunHandlerPool::RunHandlerPool(int num_inter_op_threads)
+ : impl_(new Impl(num_inter_op_threads)) {}
+
+RunHandlerPool::~RunHandlerPool() {}
+
+std::unique_ptr<RunHandler> RunHandlerPool::Get() { return impl_->Get(); }
+
+RunHandler::RunHandler(Impl* impl) : impl_(impl) {}
+
+void RunHandler::ScheduleInterOpClosure(std::function<void()> fn) {
+ impl_->ScheduleInterOpClosure(std::move(fn));
+}
+
+RunHandler::~RunHandler() { impl_->pool_impl()->ReleaseHandler(impl_); }
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler.h b/tensorflow/core/framework/run_handler.h
new file mode 100644
index 0000000000..72fa6301b4
--- /dev/null
+++ b/tensorflow/core/framework/run_handler.h
@@ -0,0 +1,95 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
+
+#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/histogram/histogram.h"
+#include "tensorflow/core/platform/mutex.h"
+#include "tensorflow/core/platform/thread_annotations.h"
+#include "tensorflow/core/protobuf/config.pb.h"
+
+namespace tensorflow {
+
+class RunHandler;
+
+// RunHandlerPool is a fixed size pool of pre-allocated RunHandlers
+// that can be used for tracking inter-op work for a given Session::Run().
+// RunHandler(s) in the pool are initially 'inactive'. A RunHandler becomes
+// 'active' when its unique_ptr is returned by Get() and is being used by a
+// client. It becomes 'inactive' once more when its unique_ptr gets destroyed.
+//
+// Expected usage:
+//
+// * Create a single RunHandlerPool (say run_handler_pool_).
+//
+// * When a Session::Run() is invoked, obtain a handler by:
+// auto handler = run_handler_pool_->Get();
+//
+// * Use handler for scheduling all inter-op work by:
+// handler->ScheduleInterOpClosure(closure);
+//
+// This class is thread safe.
+class RunHandlerPool {
+ public:
+ explicit RunHandlerPool(int num_inter_op_threads);
+ ~RunHandlerPool();
+
+ // Returns an inactive RunHandler from the pool.
+ //
+ // RunHandlers in RunHandlerPool are initially 'inactive'.
+ // A RunHandler becomes 'active' when its unique_ptr its returned by Get()
+ // and is being used by a client. It becomes 'inactive' once more when the
+ // unique_ptr is destroyed.
+ //
+ // Will block unless there is an inactive handler.
+ std::unique_ptr<RunHandler> Get();
+
+ private:
+ class Impl;
+ friend class RunHandler;
+
+ std::unique_ptr<Impl> impl_;
+};
+
+// RunHandler can be used to schedule inter-op closures to run on a global pool
+// shared across all Session::Run(s).
+//
+// It can only be created via RunHandlerPool::Get().
+//
+// This class can be used instead of directly scheduling closures on a global
+// pool since it maintains a global view across all sessions and optimizes pool
+// scheduling to improve (median and tail) latency.
+//
+// This class is thread safe.
+class RunHandler {
+ public:
+ void ScheduleInterOpClosure(std::function<void()> fn);
+
+ ~RunHandler();
+
+ private:
+ class Impl;
+ friend class RunHandlerPool::Impl;
+
+ explicit RunHandler(Impl* impl);
+
+ Impl* impl_; // NOT OWNED.
+};
+
+} // end namespace tensorflow.
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_H_
diff --git a/tensorflow/core/framework/run_handler_util.cc b/tensorflow/core/framework/run_handler_util.cc
new file mode 100644
index 0000000000..3087998c69
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.cc
@@ -0,0 +1,57 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <algorithm>
+#include <cmath>
+#include "tensorflow/core/platform/logging.h"
+
+namespace tensorflow {
+
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec) {
+ // Each request is expected to have weight W[i] = num_active_requests - i.
+ // Therefore, total_weight = sum of all request weights.
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float demand_factor = static_cast<float>(num_threads) / total_weight;
+ float last_cumulative_weight = 0.0;
+ min_threads_per_request = std::max(1, min_threads_per_request);
+ for (int i = 0; i != num_active_requests; i++) {
+ float cumulative_weight =
+ static_cast<float>(i + 1) *
+ (num_active_requests - static_cast<float>(i) * 0.5f);
+ float weight = cumulative_weight - last_cumulative_weight;
+ // Quantize thread_demand by rounding up, and also satisfying
+ // `min_threads_per_request` constraint.
+ // Note: We subtract a small epsilon (0.00001) to prevent ceil(..) from
+ // rounding weights like 4.0 to 5.
+ int demand =
+ std::max(min_threads_per_request,
+ static_cast<int>(ceil(weight * demand_factor - 0.00001f)));
+ // For the quantized range [start, end); compute the floor of real start,
+ // and expand downwards from there with length `demand` and adjust for
+ // boundary conditions.
+ int start = last_cumulative_weight * demand_factor;
+ int end = std::min(num_threads, start + demand);
+ start = std::max(0, std::min(start, end - demand));
+ start_vec->at(i) = start;
+ end_vec->at(i) = end;
+ last_cumulative_weight = cumulative_weight;
+ }
+}
+} // namespace tensorflow
diff --git a/tensorflow/core/framework/run_handler_util.h b/tensorflow/core/framework/run_handler_util.h
new file mode 100644
index 0000000000..c0c36aeccb
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util.h
@@ -0,0 +1,43 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
+
+#include <cstdint>
+#include <vector>
+
+namespace tensorflow {
+
+// Assign thread ranges to requests.
+// Requests are numbered 0...num_active_requests-1, and
+// threads are numbered 0...num_threads-1.
+// On return, the range start_vec->at(i)...end_vec->at(i)-1
+// indicates the subrange of the threads available to request i.
+// The ranges given to different requests may overlap.
+// Lower numbered requests will tend to be assigned more threads.
+// Thus, a client might associate older requests with lower
+// array indices so they receive access to more threads.
+// However, the routine ensures that each request is given access
+// to at least min(min_threads_per_request, num_threads) threads.
+// Every thread will be assigned to at least one request range,
+// assuming there is at least one request.
+void ComputeInterOpSchedulingRanges(int num_active_requests, int num_threads,
+ int min_threads_per_request,
+ std::vector<std::uint_fast32_t>* start_vec,
+ std::vector<std::uint_fast32_t>* end_vec);
+
+} // end namespace tensorflow
+#endif // TENSORFLOW_CORE_FRAMEWORK_RUN_HANDLER_UTIL_H_
diff --git a/tensorflow/core/framework/run_handler_util_test.cc b/tensorflow/core/framework/run_handler_util_test.cc
new file mode 100644
index 0000000000..a1928c132b
--- /dev/null
+++ b/tensorflow/core/framework/run_handler_util_test.cc
@@ -0,0 +1,93 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/framework/run_handler_util.h"
+
+#include <vector>
+#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/test.h"
+namespace tensorflow {
+namespace {
+
+void VerifyFunction(int num_active_requests, int num_threads,
+ int min_threads_per_request, bool print_stats = false) {
+ if (print_stats) {
+ LOG(INFO) << "Test case# num_active_requests: " << num_active_requests
+ << " num_threads: " << num_threads
+ << " min_threads: " << min_threads_per_request;
+ }
+ std::vector<std::uint_fast32_t> start(num_active_requests);
+ std::vector<std::uint_fast32_t> end(num_active_requests);
+
+ ComputeInterOpSchedulingRanges(num_active_requests, num_threads,
+ min_threads_per_request, &start, &end);
+ string range_str = "";
+ for (int i = 0; i < num_active_requests; ++i) {
+ if (i > 0) range_str += " ";
+ range_str += strings::StrCat("[", start[i], ", ", end[i], ")");
+
+ ASSERT_GE(start[i], 0) << range_str;
+ ASSERT_LE(end[i], num_threads) << range_str;
+ if (i > 0) {
+ // Due to linearly decreasing demand, #threads(i - 1) >= #threads(i)
+ ASSERT_GE(end[i - 1] - start[i - 1], end[i] - start[i]) << range_str;
+ // No missing threads.
+ ASSERT_GE(end[i - 1], start[i]) << range_str;
+ }
+ // Each interval is at least of size 'min_threads_per_request'.
+ ASSERT_GE((end[i] - start[i]), min_threads_per_request) << range_str;
+ // Verify that assigned (quantized) threads is not overly estimated
+ // from real demand, when the demand is high (>=
+ // min_threads_per_request).
+ float entry_weight = num_active_requests - i;
+ float total_weight = 0.5f * num_active_requests * (num_active_requests + 1);
+ float thread_demand = (entry_weight * num_threads) / total_weight;
+ if (thread_demand > min_threads_per_request) {
+ // We expect some over-estimation of threads due to quantization,
+ // but we hope it's not more than 1 extra thread.
+ ASSERT_NEAR(end[i] - start[i], thread_demand, 1.0)
+ << "Ranges: " << range_str << " thread_demand: " << thread_demand
+ << " i: " << i;
+ }
+ }
+ ASSERT_EQ(end[num_active_requests - 1], num_threads);
+ ASSERT_EQ(start[0], 0);
+ if (print_stats) {
+ LOG(INFO) << "Assigned ranges: " << range_str;
+ }
+}
+
+TEST(RunHandlerUtilTest, TestComputeInterOpSchedulingRanges) {
+ const int kMinThreadsPerRequestBound = 12;
+ const int kMaxActiveRequests = 128;
+ const int kMaxThreads = 128;
+
+ for (int min_threads_per_request = 1;
+ min_threads_per_request <= kMinThreadsPerRequestBound;
+ ++min_threads_per_request) {
+ for (int num_active_requests = 1; num_active_requests <= kMaxActiveRequests;
+ ++num_active_requests) {
+ for (int num_threads = min_threads_per_request;
+ num_threads <= kMaxThreads; ++num_threads) {
+ VerifyFunction(num_active_requests, num_threads,
+ min_threads_per_request);
+ }
+ }
+ }
+}
+
+} // namespace
+} // namespace tensorflow
diff --git a/tensorflow/core/graph/graph.cc b/tensorflow/core/graph/graph.cc
index 1630ab7a15..4c0cd14ff1 100644
--- a/tensorflow/core/graph/graph.cc
+++ b/tensorflow/core/graph/graph.cc
@@ -643,7 +643,7 @@ Status Graph::IsValidNode(const Node* node) const {
Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_outputs()) {
+ if (idx >= node->num_outputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of outputs: ", node->num_outputs(),
@@ -654,7 +654,7 @@ Status Graph::IsValidOutputTensor(const Node* node, int idx) const {
Status Graph::IsValidInputTensor(const Node* node, int idx) const {
TF_RETURN_IF_ERROR(IsValidNode(node));
- if (idx >= node->num_inputs()) {
+ if (idx >= node->num_inputs() || idx < 0) {
return errors::OutOfRange("Node '", node->name(), "' (type: '",
node->op_def().name(),
"', num of inputs: ", node->num_inputs(),
diff --git a/tensorflow/core/graph/graph.h b/tensorflow/core/graph/graph.h
index 52e9f23a76..72cef07072 100644
--- a/tensorflow/core/graph/graph.h
+++ b/tensorflow/core/graph/graph.h
@@ -590,12 +590,12 @@ class Graph {
// Returns OK if `node` is non-null and belongs to this graph
Status IsValidNode(const Node* node) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_outputs()
+ // Returns OK if IsValidNode(`node`) and `idx` is a valid output. Does not
+ // accept control outputs.
Status IsValidOutputTensor(const Node* node, int idx) const;
- // Returns OK if IsValidNode(`node`) and `idx` is less than
- // node->num_inputs()
+ // Returns OK if IsValidNode(`node`) and `idx` a valid input. Does not accept
+ // control inputs.
Status IsValidInputTensor(const Node* node, int idx) const;
// Create and return a new WhileContext owned by this graph. This is called
diff --git a/tensorflow/core/graph/node_builder.cc b/tensorflow/core/graph/node_builder.cc
index a446e0d136..d92874909f 100644
--- a/tensorflow/core/graph/node_builder.cc
+++ b/tensorflow/core/graph/node_builder.cc
@@ -99,6 +99,11 @@ NodeBuilder& NodeBuilder::Device(StringPiece device_spec) {
return *this;
}
+NodeBuilder& NodeBuilder::AssignedDevice(StringPiece device) {
+ assigned_device_ = string(device);
+ return *this;
+}
+
Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
// In case of error, set *created_node to nullptr.
if (created_node != nullptr) *created_node = nullptr;
@@ -115,6 +120,8 @@ Status NodeBuilder::Finalize(Graph* graph, Node** created_node) const {
Node* node = graph->AddNode(node_def, &status);
if (!status.ok()) return status;
+ node->set_assigned_device_name(assigned_device_);
+
for (size_t i = 0; i < inputs_.size(); ++i) {
if (inputs_[i].node != nullptr) { // Skip back edges.
graph->AddEdge(inputs_[i].node, inputs_[i].index, node, i);
diff --git a/tensorflow/core/graph/node_builder.h b/tensorflow/core/graph/node_builder.h
index 4727ee7b56..d576985a23 100644
--- a/tensorflow/core/graph/node_builder.h
+++ b/tensorflow/core/graph/node_builder.h
@@ -100,6 +100,9 @@ class NodeBuilder {
// "assigned device" in the Node).
NodeBuilder& Device(StringPiece device_spec);
+ // Sets the device name in the "assigned device" field in tensorflow::Node.
+ NodeBuilder& AssignedDevice(StringPiece device);
+
// Set the value of an attr. attr_name must match the name of one of
// attrs defined by the Op, and value must have the corresponding type
// (see SetAttrValue() in ../framework/attr_value_util.h for legal
@@ -141,6 +144,7 @@ class NodeBuilder {
std::vector<NodeOut> inputs_;
std::vector<Node*> control_inputs_;
std::vector<string> errors_;
+ string assigned_device_;
};
// IMPLEMENTATION -------------------------------------------------------------
diff --git a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
index c94ee2f227..0ec95dd684 100644
--- a/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
+++ b/tensorflow/core/grappler/costs/graph_properties_testdata/function_functional_while.pbtxt
@@ -88,6 +88,13 @@ library {
}
}
}
+ attr {
+ key: "output_shapes"
+ value {
+ list {
+ }
+ }
+ }
}
ret {
key: "while"
diff --git a/tensorflow/core/grappler/grappler_item.cc b/tensorflow/core/grappler/grappler_item.cc
index bbc0fedd22..2c490f3966 100644
--- a/tensorflow/core/grappler/grappler_item.cc
+++ b/tensorflow/core/grappler/grappler_item.cc
@@ -38,6 +38,7 @@ GrapplerItem::GrapplerItem(const GrapplerItem& other, GraphDef* graph_def) {
restore_op = other.restore_op;
save_restore_loc_tensor = other.save_restore_loc_tensor;
queue_runners = other.queue_runners;
+ allowed_optimizations = other.allowed_optimizations;
graph.Swap(graph_def);
}
diff --git a/tensorflow/core/grappler/grappler_item.h b/tensorflow/core/grappler/grappler_item.h
index 939e5fa046..a0748abfe6 100644
--- a/tensorflow/core/grappler/grappler_item.h
+++ b/tensorflow/core/grappler/grappler_item.h
@@ -77,6 +77,15 @@ struct GrapplerItem {
// Return a set of node names that must be preserved. This includes feed and
// fetch nodes, keep_ops, init_ops.
std::unordered_set<string> NodesToPreserve() const;
+
+ // Restrict types of optimizations that are allowed for this GrapplerItem.
+ struct AllowedOptimizations {
+ // Is it allowed to add nodes to the graph that do not have registered
+ // gradient function.
+ bool non_differentiable_rewrites = true;
+ };
+
+ AllowedOptimizations allowed_optimizations;
};
// Return the transitive fanin of a set of terminal nodes.
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 3521669b63..9f0d9dbf28 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -425,6 +425,10 @@ bool IsSwitch(const NodeDef& node) {
return op == "Switch" || op == "RefSwitch";
}
+bool IsSymbolicGradient(const NodeDef& node) {
+ return node.op() == "SymbolicGradient";
+}
+
bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 25ab6b65ac..7f86a5f295 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -149,6 +149,7 @@ bool IsStridedSliceGrad(const NodeDef& node);
bool IsSub(const NodeDef& node);
bool IsSum(const NodeDef& node);
bool IsSwitch(const NodeDef& node);
+bool IsSymbolicGradient(const NodeDef& node);
bool IsTanhGrad(const NodeDef& node);
bool IsTile(const NodeDef& node);
bool IsTranspose(const NodeDef& node);
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 960d1addb3..c708f84948 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -525,6 +525,7 @@ cc_library(
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
@@ -541,6 +542,7 @@ tf_cuda_cc_test(
":custom_graph_optimizer_registry",
":meta_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3388ee8035..7d5014ee0a 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -3249,6 +3249,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
+ // Disable restricted graph rewrites.
+ options_.unary_ops_composition &=
+ item.allowed_optimizations.non_differentiable_rewrites;
+
if (options_.dedup_computations) {
DedupComputations();
}
diff --git a/tensorflow/core/grappler/optimizers/data/BUILD b/tensorflow/core/grappler/optimizers/data/BUILD
index 81c1bddf67..5a3abbb545 100644
--- a/tensorflow/core/grappler/optimizers/data/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/BUILD
@@ -124,10 +124,10 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core/grappler:mutable_graph_view",
- "//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
] + tf_protos_all(),
)
@@ -523,6 +523,7 @@ cc_library(
":function_utils",
":graph_utils",
"@com_google_absl//absl/strings",
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
@@ -538,6 +539,7 @@ tf_cc_test(
srcs = ["vectorization_utils_test.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":graph_utils",
":function_utils",
":vectorization_utils",
"//tensorflow/core:framework",
@@ -547,7 +549,10 @@ tf_cc_test(
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer",
"//tensorflow/core/grappler/optimizers:custom_graph_optimizer_registry",
+ # For ops we need registered
+ "//tensorflow/core/kernels/data:dataset_ops",
"//tensorflow/core/kernels:cast_op",
+ "//tensorflow/core/kernels:logging_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/graph_utils.h b/tensorflow/core/grappler/optimizers/data/graph_utils.h
index 5dd7819100..3af34f6904 100644
--- a/tensorflow/core/grappler/optimizers/data/graph_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/graph_utils.h
@@ -116,8 +116,8 @@ std::vector<int> FindAllGraphNodesWithOp(const string& op,
// is unique across the graph.
void SetUniqueGraphNodeName(StringPiece prefix, GraphDef* graph, NodeDef* node);
-// Sets the node name using the `prefix` name as a prefix while guaranteeing the
-// name is unique across the graph.
+// Sets the function name using the `prefix` name as a prefix while guaranteeing
+// the name is unique across the function library.
void SetUniqueGraphFunctionName(StringPiece prefix, FunctionDefLibrary* library,
FunctionDef* function);
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
index 32ab912619..9328a7ca99 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization.cc
@@ -86,21 +86,19 @@ FunctionDef* AddVectorizedFunction(const NodeDef& map_node,
// efficient vectorization with VectorizeMapDefun.
FunctionDef* vectorized_func =
CreateMapDefunWrapper(map_node, orig_func, library);
- NodeDef* map_defun_node = vectorized_func->mutable_node_def()->Mutable(0);
- DCHECK_EQ(map_defun_node->op(), "MapDefun");
-
- // Create a copy of the original function so that we can mutate it, and
- // attach that to the map defun node.
- FunctionDef* map_defun_fn = library->add_function();
- *map_defun_fn = orig_func;
- graph_utils::SetUniqueGraphFunctionName(orig_func.signature().name(), library,
- map_defun_fn);
- (*map_defun_node->mutable_attr())["f"].mutable_func()->set_name(
- map_defun_fn->signature().name());
-
- vectorization_utils::VectorizeMapDefun(vectorized_func, map_defun_fn,
- map_defun_node);
- return vectorized_func;
+ const NodeDef& map_defun_node = vectorized_func->node_def(0);
+ DCHECK_EQ(map_defun_node.op(), "MapDefun");
+
+ // TODO(b/116285210): Unreferenced functions should get cleaned up later
+ FunctionDef* result;
+ Status s = vectorization_utils::VectorizeMapDefun(
+ *vectorized_func, map_defun_node, library, &result);
+
+ if (!s.ok()) {
+ LOG(ERROR) << "VectorizeMapDefun failed: " << s;
+ return vectorized_func;
+ }
+ return result;
}
bool IsOutputShapesFullyDefined(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
index ed1bd6bc97..f4faf41549 100644
--- a/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/map_vectorization_test.cc
@@ -30,72 +30,51 @@ namespace {
using test::function::GDef;
using test::function::NDef;
-void MakeTensorShapeProtoHelper(const gtl::ArraySlice<int> dims,
- TensorShapeProto* t) {
- for (size_t i = 0; i < dims.size(); ++i) {
- auto* d = t->add_dim();
- d->set_size(dims[i]);
- }
-}
-
-AttrValue MakeShapeListAttr(
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& shapes) {
- AttrValue shapes_attr;
- for (size_t i = 0; i < shapes.size(); ++i) {
- MakeTensorShapeProtoHelper(shapes[i],
- shapes_attr.mutable_list()->add_shape());
- }
-
- return shapes_attr;
-}
-
-NodeDef MakeMapNodeHelper(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- StringPiece map_op_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNodeHelper(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name, StringPiece map_op_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return test::function::NDef(
name, map_op_name, {string(input_node_name)},
{{"f", FunctionDefHelper::FunctionRef(string(function_name))},
{"Targuments", {}},
- {"output_shapes", MakeShapeListAttr(output_shapes)},
+ {"output_shapes", output_shapes},
{"output_types", output_types}});
}
-NodeDef MakeMapNode(
- StringPiece name, StringPiece input_node_name, StringPiece function_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
+NodeDef MakeMapNode(StringPiece name, StringPiece input_node_name,
+ StringPiece function_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
return MakeMapNodeHelper(name, input_node_name, function_name, "MapDataset",
output_shapes, output_types);
}
-NodeDef MakeBatchNode(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDataset",
- {string(input_node_name), string(input_batch_size_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchNode(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDataset",
+ {string(input_node_name), string(input_batch_size_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeBatchV2Node(
- StringPiece name, StringPiece input_node_name,
- StringPiece input_batch_size_name, StringPiece input_drop_remainder_name,
- const gtl::ArraySlice<const gtl::ArraySlice<int>>& output_shapes,
- const gtl::ArraySlice<DataType>& output_types) {
- return NDef(name, "BatchDatasetV2",
- {string(input_node_name), string(input_batch_size_name),
- string(input_drop_remainder_name)},
- {{"output_types", output_types},
- {"output_shapes", MakeShapeListAttr(output_shapes)}});
+NodeDef MakeBatchV2Node(StringPiece name, StringPiece input_node_name,
+ StringPiece input_batch_size_name,
+ StringPiece input_drop_remainder_name,
+ gtl::ArraySlice<PartialTensorShape> output_shapes,
+ gtl::ArraySlice<DataType> output_types) {
+ return NDef(
+ name, "BatchDatasetV2",
+ {string(input_node_name), string(input_batch_size_name),
+ string(input_drop_remainder_name)},
+ {{"output_types", output_types}, {"output_shapes", output_shapes}});
}
-NodeDef MakeRangeNode(StringPiece name, const gtl::ArraySlice<string>& inputs) {
+NodeDef MakeRangeNode(StringPiece name, gtl::ArraySlice<string> inputs) {
return NDef(name, "RangeDataset", inputs,
- {{"output_shapes", MakeShapeListAttr({{}})},
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})},
{"output_types", gtl::ArraySlice<DataType>({DT_INT64})}});
}
@@ -184,7 +163,7 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
item.graph = GDef(
{NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
NDef("input", "InputDataset", {},
- {{"output_shapes", MakeShapeListAttr({{}})}}),
+ {{"output_shapes", gtl::ArraySlice<TensorShape>({{}})}}),
MakeMapNode("map", "input", "XTimesTwo", {{}}, {DT_INT32}),
MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
// FunctionLib
@@ -196,6 +175,37 @@ TEST(MapVectorizationTest, VectorizeWithUndefinedOutputTypes) {
TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
}
+TEST(MapVectorizationTest, VectorizeWithFullyDefinedFunction) {
+ GrapplerItem item;
+ item.graph = GDef(
+ {NDef("start", "Const", {}, {{"value", 0}, {"dtype", DT_INT32}}),
+ NDef("stop", "Const", {}, {{"value", 10}, {"dtype", DT_INT32}}),
+ NDef("step", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ NDef("batch_size", "Const", {}, {{"value", 1}, {"dtype", DT_INT32}}),
+ MakeRangeNode("range", {"start", "stop", "step"}),
+ MakeMapNode("map", "range", "Func", {{}}, {DT_INT32}),
+ MakeBatchNode("batch", "map", "batch_size", {{-1}}, {DT_INT32})},
+ // FunctionLib
+ {FunctionDefHelper::Create(
+ "Func", {"x: int64", "y: int64"}, {"res: int64", "res2: int64"}, {},
+ {{{"o"}, "Mul", {"x", "x"}, {{"T", DT_INT64}}}},
+ {{"res", "o:z"}, {"res2", "o:z"}})});
+ MapVectorization optimizer;
+ GraphDef output;
+ TF_ASSERT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("MapDataset", output).size(),
+ 1);
+ EXPECT_EQ(graph_utils::FindAllGraphNodesWithOp("BatchDataset", output).size(),
+ 1);
+ const NodeDef& map_node =
+ output.node(graph_utils::FindGraphNodeWithOp("MapDataset", output));
+ const NodeDef& batch_node =
+ output.node(graph_utils::FindGraphNodeWithOp("BatchDataset", output));
+ EXPECT_EQ(map_node.input(0), batch_node.name());
+ EXPECT_EQ(batch_node.input(0), "range");
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
index 1462cb234d..37aa24b947 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/BUILD
@@ -9,13 +9,14 @@ load("//tensorflow/core:platform/default/build_config.bzl", "tf_protos_all")
VECTORIZER_DEPS = [
":vectorizer_registry",
- "//tensorflow/core/grappler/optimizers/data:function_utils",
+ "//tensorflow/core/grappler/optimizers/data:graph_utils",
] + tf_protos_all()
cc_library(
name = "vectorizer",
hdrs = ["vectorizer.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
] + tf_protos_all(),
)
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
index c1739737a0..3af6bab409 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/cast_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,26 +23,21 @@ namespace vectorization_utils {
class CastVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Cast op should only have one input.");
}
- // Add new Cast node
- NodeDef* new_cast_node = outer_scope->add_node_def();
- *new_cast_node = node;
- new_cast_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_cast_node);
- new_cast_node->set_input(0, inputs[0]);
-
- // Add the output mapping to conversion map
- (*conversion_map)[strings::StrCat(node.name(), ":y:0")] =
- strings::StrCat(new_cast_node->name(), ":y:0");
+ // Add new Cast node with the same op and attrs as the original node
+ auto new_cast_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
+ // Add input and output mappings
+ input_ports->push_back({new_cast_node, 0});
+ output_ports->push_back({new_cast_node, 0});
return Status::OK();
}
};
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
index 776d3179c5..74ce520ce1 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/unpack_vectorizer.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/node_def.pb.h"
-#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
namespace tensorflow {
@@ -23,31 +23,29 @@ namespace vectorization_utils {
class UnpackVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
- if (inputs.size() != 1) {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) override {
+ Status s;
+ if (node.num_inputs() != 1) {
return errors::Internal("Unpack op should only have one input.");
}
- // Add new Unpack node
- NodeDef* new_unpack_node = outer_scope->add_node_def();
- *new_unpack_node = node;
- new_unpack_node->clear_name();
- function_utils::SetUniqueFunctionNodeName(
- strings::StrCat("vectorized/", node.name()), outer_scope,
- new_unpack_node);
+ // Add new Unpack node with the same op and attrs as the original node
+ auto new_unpack_node = outer_scope->AddNode(node.def(), &s);
+ TF_RETURN_IF_ERROR(s);
// Increment "axis" attr by 1:
- (*new_unpack_node->mutable_attr())["axis"].set_i(
- node.attr().at("axis").i() + 1);
- new_unpack_node->set_input(0, inputs[0]);
+ int new_axis = node.def().attr().at("axis").i() + 1;
+ new_unpack_node->AddAttr("axis", new_axis);
- // Add the output mappings to conversion map
- int num = new_unpack_node->attr().at("num").i();
+ // Add the input mappings
+ input_ports->push_back({new_unpack_node, 0});
+
+ // Add the output mappings
+ int num = node.def().attr().at("num").i();
for (int i = 0; i < num; ++i) {
- (*conversion_map)[strings::StrCat(node.name(), ":output:", i)] =
- strings::StrCat(new_unpack_node->name(), ":output:", i);
+ output_ports->push_back({new_unpack_node, i});
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
index d341dbba7d..56eb88c95e 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer.h
@@ -17,30 +17,33 @@ limitations under the License.
#define TENSORFLOW_CORE_GRAPPLER_OPTIMIZERS_DATA_VECTORIZATION_VECTORIZER_H_
#include "tensorflow/core/framework/function.pb.h"
+#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> Port;
+
// Interface for vectorization of TensorFlow operations. See `CastVectorizer`
// for an example.
class Vectorizer {
public:
virtual ~Vectorizer() {}
- // Vectorizes an operation, `node`, by adding operation(s) to `outer_scope`
+ // Vectorizes an operation, `node`, by adding Node(s) to `outer_scope`
// that produce the same vector output(s) as executing `node`'s op
- // on elements of the vector inputs, and adding mappings to `conversion_map`
- // from old output tensor names to new (vectorized) output tensor names.
- // The new node(s) collectively have the same number of inputs and outputs as
- // the node being converted, and use the tensor names in `inputs` as their
- // inputs.
- virtual Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) = 0;
+ // on elements of the vector inputs. The new Node(s) collectively have the
+ // same number of input and output ports as the node being converted.
+ // Adds mappings for the new nodes' input and output ports to `inputs` and
+ // `outputs` respectively, where the i'th Port in inputs/outputs
+ // corresponds to the i'th input/output port of the node to be converted.
+ virtual Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* input_ports,
+ std::vector<Port>* output_ports) = 0;
};
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
index 86e303564b..663ceba027 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry_test.cc
@@ -24,9 +24,9 @@ namespace vectorization_utils {
class TestVectorizer : public Vectorizer {
public:
- Status Vectorize(const NodeDef& node, gtl::ArraySlice<string> inputs,
- FunctionDef* outer_scope,
- std::map<string, string>* conversion_map) override {
+ Status Vectorize(const Node& node, Graph* outer_scope,
+ std::vector<Port>* inputs,
+ std::vector<Port>* outputs) override {
return Status::OK();
}
};
@@ -39,10 +39,12 @@ TEST(TestVectorizer, TestTestVectorizer) {
auto vectorizer = VectorizerRegistry::Global()->Get("test_op");
EXPECT_NE(vectorizer, nullptr);
- FunctionDef function;
- NodeDef node;
- std::map<string, string> conversion_map;
- EXPECT_TRUE(vectorizer->Vectorize(node, {}, &function, &conversion_map).ok());
+ Graph g(OpRegistry::Global());
+ NodeDef node_def;
+ Status s;
+ Node* node = g.AddNode(node_def, &s);
+ std::vector<Port> inputs, outputs;
+ EXPECT_TRUE(vectorizer->Vectorize(*node, &g, &inputs, &outputs).ok());
}
} // namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
index cb56b65985..cea667f668 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.cc
@@ -14,13 +14,17 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/data/vectorization_utils.h"
+#include <memory>
#include "tensorflow/core/grappler/optimizers/data/vectorization/vectorizer_registry.h"
#include "absl/strings/str_join.h"
+#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/types.h"
@@ -36,255 +40,346 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-using function_utils::FunctionDefTensorDesc;
-
namespace {
-void AddMapDefunOutput(FunctionDef* map_defun_fn, NodeDef* map_defun_node,
- const string& output_retval, const DataType t) {
- // Set to unknown shape
- TensorShapeProto tensor_shape_proto;
- PartialTensorShape().AsProto(&tensor_shape_proto);
+// Describes a tensor with its operation Node and output position
+typedef std::pair<Node*, int> TensorDesc;
- function_utils::AddFunctionOutputWithUniqueName(
- "vectorized_out", output_retval, map_defun_fn, t);
+const char* const kRetValOp = "_Retval";
- *(*map_defun_node->mutable_attr())["output_shapes"]
- .mutable_list()
- ->add_shape() = tensor_shape_proto;
- (*map_defun_node->mutable_attr())["output_types"].mutable_list()->add_type(t);
+void ReplaceEdgeSources(const TensorDesc& old_src, const TensorDesc& new_src,
+ Graph* graph) {
+ // NOTE: We need two for loops here because we can't mutate the set of output
+ // edges as we iterate over them.
+ std::vector<const Edge*> edges_to_replace;
+ for (auto edge : old_src.first->out_edges()) {
+ if (edge->src_output() == old_src.second) {
+ edges_to_replace.push_back(edge);
+ }
+ }
+ for (auto edge : edges_to_replace) {
+ graph->AddEdge(new_src.first, new_src.second, edge->dst(),
+ edge->dst_input());
+ graph->RemoveEdge(edge);
+ }
}
-void RemoveMapDefunOutput(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node, int output_position) {
- DCHECK_LT(output_position, map_defun_fn->signature().output_arg_size())
- << "Trying to remove output that doesn't exist. Output number: "
- << output_position;
+Status AddMapDefunOutput(FunctionBody* map_defun_fn, Node* map_defun_node,
+ const TensorDesc& output) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DataType type = output.first->output_type(output.second);
+ int index = map_defun_fn->ret_nodes.size();
- int num_later_outputs =
- map_defun_fn->signature().output_arg_size() - output_position - 1;
+ NodeDef ret_node_def;
+ ret_node_def.set_name("map_out");
+ ret_node_def.set_op(kRetValOp);
+ AddNodeAttr("T", type, &ret_node_def);
+ AddNodeAttr("index", index, &ret_node_def);
- // Remove from map_defun_fn's ret dict and output args
- map_defun_fn->mutable_ret()->erase(
- map_defun_fn->signature().output_arg(output_position).name());
- map_defun_fn->mutable_signature()->mutable_output_arg()->DeleteSubrange(
- output_position, 1);
+ Status s;
+ Node* ret_node = map_defun_fn->graph->AddNode(ret_node_def, &s);
+ TF_RETURN_IF_ERROR(s);
- // Renumber outputs that come after
- for (int i = 0; i < num_later_outputs; ++i) {
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i + 1),
- strings::StrCat(map_defun_node->name(),
- ":output:", output_position + i),
- outer_scope);
- }
- map_defun_node->mutable_attr()
- ->at("output_shapes")
- .mutable_list()
- ->mutable_shape()
- ->DeleteSubrange(output_position, 1);
- map_defun_node->mutable_attr()
- ->at("output_types")
- .mutable_list()
- ->mutable_type()
- ->ExtractSubrange(output_position, 1, nullptr);
+ map_defun_fn->graph->AddEdge(output.first, output.second, ret_node, 0);
+ map_defun_fn->ret_nodes.push_back(ret_node);
+ map_defun_fn->ret_types.push_back(type);
+
+ return s;
}
-int FindOutputToConvert(const FunctionDef& function,
- const std::set<string>& unconvertible,
- FunctionDefTensorDesc* f) {
- for (int i = function.signature().output_arg_size() - 1; i >= 0; --i) {
- const string& ret_key = function.signature().output_arg(i).name();
- *f = FunctionDefTensorDesc(function.ret().at(ret_key));
+void RemoveMapDefunOutput(int output_position, Graph* outer_scope,
+ FunctionBody* map_defun_fn, Node* map_defun_node) {
+ // Note that we don't update MapDefun attrs as we go, only when we are done
+ DCHECK_LT(output_position, map_defun_fn->ret_nodes.size())
+ << "Trying to remove output that doesn't exist. Output number: "
+ << output_position;
+
+ int num_later_outputs = map_defun_fn->ret_nodes.size() - output_position - 1;
- if (unconvertible.find(f->node_name) == unconvertible.end()) {
- return i;
- }
+ // Modify map_defun_fn's signature and remove the output node from its graph
+ map_defun_fn->graph->RemoveNode(map_defun_fn->ret_nodes[output_position]);
+ map_defun_fn->ret_nodes.erase(map_defun_fn->ret_nodes.begin() +
+ output_position);
+ map_defun_fn->ret_types.erase(map_defun_fn->ret_types.begin() +
+ output_position);
+
+ // Renumber the nodes and edges that come after
+ for (int i = 0; i < num_later_outputs; ++i) {
+ ReplaceEdgeSources({map_defun_node, output_position + i + 1},
+ {map_defun_node, output_position + i}, outer_scope);
+ // Each ret node has an "index" attr that has to be updated
+ map_defun_fn->ret_nodes[output_position + i]->AddAttr("index",
+ output_position + i);
}
- return -1;
}
// Helper class that vectorizes the body of a MapDefun node, adding new
// operations to the graph that collectively compute the same value as what
// running the MapDefun function on slices of the input would produce.
-// Each instance of the class encapsulates all the data necessary to vectorize a
-// MapDefun op in place.
+// This class transforms the input FunctionDefs into their corresponding
+// Graph objects and works on the graphs directly, then converts them back
+// to FunctionDefs when GetResult is called.
class Vectorization {
public:
- Vectorization(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node)
- : outer_scope_(outer_scope),
- map_defun_fn_(map_defun_fn),
- map_defun_node_(map_defun_node) {}
+ explicit Vectorization(FunctionDefLibrary* lib)
+ : lib_(lib), lib_def_(OpRegistry::Global(), *lib) {}
- // Repeatedly tries to convert outputs of map_defun_fn_ into new nodes in
- // the outer_scope_, until there are no convertible outputs remaining.
- // This method is idempotent.
- void Vectorize();
+ // Adds the vectorized function and new map_defun_fn to lib, and points
+ // vectorized_function to the former. Returns an error status if
+ // the conversion between FunctionDef -> Graph -> FunctionDef failed anywhere
+ // along the way.
+ Status Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDef** result);
private:
- // Vectorizes the map defun function's output at output_position
- Status ConvertOutput(int output_position, const FunctionDefTensorDesc& desc);
- // Given a descriptor of the original output tensor, gets a string
- // corresponding to the converted output tensor.
- Status ConvertOutputHelper(const FunctionDefTensorDesc& output_desc,
- string* converted);
- Status AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc);
+ // Converts FunctionDefs to Graphs.
+ Status Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node);
+
+ // Converts Graphs back to FunctionDefs and adds them to `lib_`.
+ Status GetResult(FunctionDef** vectorized_function);
+
+ // Repeatedly tries to convert outputs of `map_defun_fn_` into new nodes in
+ // `outer_scope_`, until there are no convertible outputs remaining.
+ void VectorizeHelper();
+
+ // Vectorizes map_defun_fn's output at output_position.
+ Status ConvertOutput(int output_position);
// Adds mappings from node's outputs tensors to converted output tensors,
// creating the necessary new node(s). Generally, the steps to convert an op
// are:
- // 1) Promote the inputs of the op inputs to outputs of the map_defun_fn_,
- // and modify map_defun_node_ attrs accordingly
- // 2) Create new node(s) in outer_scope_ that act on batched input tensors.
+ // 1) Create new node(s) in `outer_scope_` that act on batched input tensors.
// These operations collectively compute the same value as what running
// the original operation on slices of the input tensors would produce.
// For example, a Cast op in MapDefun translates to a Cast op in
- // outer_scope_, since the vectorized version of Cast is itself.
- // 3) Set inputs of new node(s) to the corresponding converted inputs (that
- // are now outputs of map_defun_node_)
- // 4) For each output of the old node, add the mapping of output strings to
- // the conversion map (eg "Cast:y:0" -> "Vectorize/Cast:y:0")
- Status AddConversionMappingFromOp(const NodeDef& node,
- const FunctionDefTensorDesc& output_desc);
-
- // Maps a tensor name to the name of the corresponding vectorized tensor. For
- // example, "Cast:y:0" -> "Vectorize/Cast:y:0"
- std::map<string, string> conversion_map_;
- // Unconvertible node names
- std::set<string> unconvertible_;
-
- FunctionDef* outer_scope_;
- FunctionDef* map_defun_fn_;
- NodeDef* map_defun_node_;
+ // `outer_scope_`, since the vectorized version of Cast is itself.
+ // 2) Promote the inputs of the op inputs to outputs of the
+ // `map_defun_node_` and `map_defun_fn_`.
+ // 3) Add edges between the promoted inputs (that are now outputs of
+ // `map_defun_node`) and the inputs ports of the new node(s).
+ // 4) For each output of the old node, add the mapping of output tensors to
+ // the conversion map.
+ Status AddConversionMapping(Node* op_node);
+
+ // Maps a tensor to the corresponding vectorized tensor. For example,
+ // {"Cast" Node*, 0} -> {"Vectorize/Cast" Node*, 0}
+ std::map<TensorDesc, TensorDesc> conversion_map_;
+
+ // Unconvertible ret nodes
+ std::set<Node*> unconvertible_;
+
+ FunctionDefLibrary* lib_; // Not owned
+ FunctionLibraryDefinition lib_def_;
+ // Note that FunctionBody has a pointer to a Graph object that corresponds
+ // to the function's subgraph, with additional kArgOp and kRetValOp nodes
+ // that denote that function arguments and return values. These nodes have the
+ // attrs "T" for the type, and "index" for the argument / retval index
+ // respectively. FunctionBody also keeps track of arg/ret_nodes and
+ // arg/ret_types, that should be ordered according to argument/output indices.
+ std::unique_ptr<Graph> outer_scope_;
+ std::unique_ptr<FunctionBody> map_defun_fn_;
+ Node* map_defun_node_ = nullptr; // Owned by `outer_scope`
+ Status status_;
};
-Status Vectorization::AddConversionMappingFromOp(
- const NodeDef& node, const FunctionDefTensorDesc& output_desc) {
- for (const string& input_name : node.input()) {
- if (IsControlInput(input_name)) {
+Status Vectorization::AddConversionMapping(Node* op_node) {
+ for (auto edge : op_node->in_edges()) {
+ if (edge->IsControlEdge()) {
return errors::InvalidArgument(
"Vectorizing outputs with control inputs is currently not "
"supported.");
}
}
- // TODO(rachelim): Have some mechanism for registering converters and some
- // uniform, simpler way to represent them.
-
- DataTypeVector types;
- const OpDef* op_def = nullptr;
- TF_RETURN_IF_ERROR(OpRegistry::Global()->LookUpOpDef(node.op(), &op_def));
- TF_RETURN_IF_ERROR(InputTypesForNode(node, *op_def, &types));
-
- std::vector<string> promoted_inputs;
- promoted_inputs.reserve(node.input_size());
- for (int i = 0; i < node.input_size(); ++i) {
- promoted_inputs.push_back(strings::StrCat(
- map_defun_node_->name(),
- ":output:", map_defun_fn_->signature().output_arg_size() + i));
- }
-
- auto vectorizer = VectorizerRegistry::Global()->Get(node.op());
+ auto vectorizer = VectorizerRegistry::Global()->Get(op_node->type_string());
if (vectorizer == nullptr) {
return errors::Unimplemented("No vectorizer registered for op: ",
- node.op());
+ op_node->type_string());
+ }
+ std::vector<Port> input_ports, output_ports;
+ input_ports.reserve(op_node->num_inputs());
+ output_ports.reserve(op_node->num_outputs());
+ TF_RETURN_IF_ERROR(vectorizer->Vectorize(*op_node, outer_scope_.get(),
+ &input_ports, &output_ports));
+
+ std::vector<const Edge*> input_edges;
+ TF_RETURN_IF_ERROR(op_node->input_edges(&input_edges));
+
+ if (op_node->num_outputs() != output_ports.size() ||
+ op_node->num_inputs() != input_ports.size() ||
+ input_edges.size() != input_ports.size()) {
+ return errors::Internal("Vectorizer inputs/outputs don't match.");
}
- TF_RETURN_IF_ERROR(vectorizer->Vectorize(node, promoted_inputs, outer_scope_,
- &conversion_map_));
+ // Promote the inputs of the op to MapDefun outputs and connect the edges
+ // accordingly.
+ for (size_t i = 0; i < op_node->num_inputs(); ++i) {
+ auto edge = input_edges[i];
+ TF_RETURN_IF_ERROR(AddMapDefunOutput(map_defun_fn_.get(), map_defun_node_,
+ {edge->src(), edge->src_output()}));
+ outer_scope_->AddEdge(map_defun_node_, map_defun_fn_->ret_nodes.size() - 1,
+ input_ports[i].first, input_ports[i].second);
+ }
- // If we get here, the conversion was successful, so we promote the inputs
- // of the ops to MapDefun outputs.
- for (int i = 0; i < types.size(); ++i) {
- AddMapDefunOutput(map_defun_fn_, map_defun_node_, node.input(i), types[i]);
+ // Add output mappings.
+ for (size_t i = 0; i < op_node->num_outputs(); ++i) {
+ conversion_map_.insert({{op_node, i}, std::move(output_ports[i])});
}
return Status::OK();
}
-Status Vectorization::AddConversionMappingFromInput(
- const FunctionDefTensorDesc& output_desc) {
- int input_index = function_utils::FindFunctionInputWithName(
- output_desc.node_name, *map_defun_fn_);
- if (input_index == -1) {
- return errors::Internal("Cannot convert non-existent input.");
+Status Vectorization::ConvertOutput(int output_position) {
+ // ret_edge->src() is the actual op that generated the retval, and
+ // ret_edge->dst() is the retval node whose op is "_Retval"
+ const Edge* ret_edge;
+ TF_RETURN_IF_ERROR(
+ map_defun_fn_->ret_nodes[output_position]->input_edge(0, &ret_edge));
+
+ TensorDesc output({ret_edge->src(), ret_edge->src_output()});
+ TensorDesc converted_output;
+ if (auto found = gtl::FindOrNull(conversion_map_, output)) {
+ // It's possible the output already has a mapping, if it comes from a node
+ // that has already been converted.
+ converted_output = *found;
+ } else {
+ TF_RETURN_IF_ERROR(AddConversionMapping(output.first));
+ converted_output = conversion_map_.at(output);
}
- conversion_map_[output_desc.full_str] = map_defun_node_->input(input_index);
+ ReplaceEdgeSources({map_defun_node_, output_position}, converted_output,
+ outer_scope_.get());
+ RemoveMapDefunOutput(output_position, outer_scope_.get(), map_defun_fn_.get(),
+ map_defun_node_);
+
return Status::OK();
}
-Status Vectorization::ConvertOutputHelper(
- const FunctionDefTensorDesc& output_desc, string* converted) {
- // It's possible the output already has a mapping, if it comes from a node
- // that has already been converted.
- if (auto found = gtl::FindOrNull(conversion_map_, output_desc.full_str)) {
- *converted = *found;
- return Status::OK();
+Status Vectorization::Vectorize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node,
+ FunctionDef** result) {
+ TF_RETURN_IF_ERROR(Initialize(outer_scope, map_defun_node));
+ VectorizeHelper();
+ return GetResult(result);
+}
+
+void Vectorization::VectorizeHelper() {
+ while (true) {
+ int output_position = graph_utils::GetFirstElementIndexWithPredicate(
+ [this](Node* n) {
+ return this->unconvertible_.find(n) == this->unconvertible_.end();
+ },
+ map_defun_fn_->ret_nodes);
+
+ // No outputs left to convert
+ if (output_position == -1) break;
+
+ Status s = ConvertOutput(output_position);
+ if (!s.ok()) {
+ Node* output_node = map_defun_fn_->ret_nodes.at(output_position);
+ VLOG(2) << "Could not convert the output at node: "
+ << output_node->DebugString() << "\nError: " << s;
+ unconvertible_.insert(output_node);
+ }
}
- int index = function_utils::FindFunctionNodeWithName(output_desc.node_name,
- *map_defun_fn_);
- if (index == -1) { // The output comes from an input
- TF_RETURN_IF_ERROR(AddConversionMappingFromInput(output_desc));
+ // If we've converted all the outputs of the MapDefun function, we no longer
+ // need the MapDefun node and can delete it.
+ if (map_defun_fn_->ret_nodes.empty()) {
+ outer_scope_->RemoveNode(map_defun_node_);
} else {
- TF_RETURN_IF_ERROR(AddConversionMappingFromOp(
- map_defun_fn_->node_def(index), output_desc));
+ // Update MapDefun node attrs accordingly
+ DCHECK_EQ(map_defun_fn_->ret_types.size(), map_defun_fn_->ret_nodes.size());
+ map_defun_node_->AddAttr(
+ "output_shapes",
+ std::vector<PartialTensorShape>(map_defun_fn_->ret_types.size()));
+ map_defun_node_->AddAttr("output_types", map_defun_fn_->ret_types);
}
- *converted = conversion_map_.at(output_desc.full_str);
- return Status::OK();
}
+Status Vectorization::Initialize(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node) {
+ // Convert outer_scope and map_defun_fn to FunctionBodys so we can
+ // work on Graphs directly.
+ const FunctionDef* map_defun_fn =
+ lib_def_.Find(map_defun_node.attr().at("f").func().name());
+
+ if (map_defun_fn == nullptr) {
+ return errors::NotFound("Could not find function with name ",
+ map_defun_node.attr().at("f").func().name(),
+ " in function library.");
+ }
-Status Vectorization::ConvertOutput(int output_position,
- const FunctionDefTensorDesc& output_desc) {
- string converted_output_name;
- TF_RETURN_IF_ERROR(ConvertOutputHelper(output_desc, &converted_output_name));
+ auto get_func_sig = [this](const string& op, const OpDef** sig) {
+ return this->lib_def_.LookUpOpDef(op, sig);
+ };
+
+ FunctionBody* outer_fn;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(outer_scope, {}, &lib_def_,
+ get_func_sig, &outer_fn));
+ // We don't need outer_fn, just the graph
+ outer_scope_.reset(outer_fn->graph);
+ outer_fn->graph = nullptr;
+ delete outer_fn;
+
+ FunctionBody* tmp;
+ TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*map_defun_fn, {}, &lib_def_,
+ get_func_sig, &tmp));
+ map_defun_fn_.reset(tmp);
+
+ // Find the MapDefun node in outer_scope_
+ int node_id = graph_utils::GetFirstElementIndexWithPredicate(
+ [&map_defun_node](Node* n) { return n->name() == map_defun_node.name(); },
+ outer_scope_->nodes());
+ if (node_id == -1) {
+ return errors::NotFound("Could not find node with name ",
+ map_defun_node.name(), " in outer_scope.");
+ }
+ map_defun_node_ = outer_scope_->FindNodeId(node_id);
+
+ // Add mappings from map_defun_fn_ arg nodes to map_defun_node_ input nodes to
+ // the conversion map
+ for (auto arg_node : map_defun_fn_->arg_nodes) {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(map_defun_node_->input_node(
+ arg_node->attrs().Find("index")->i(), &input_node));
- // Remove the old output and make everything that referenced it point
- // to the new string
- function_utils::ReplaceReferences(
- strings::StrCat(map_defun_node_->name(), ":output:", output_position),
- converted_output_name, outer_scope_);
- RemoveMapDefunOutput(outer_scope_, map_defun_fn_, map_defun_node_,
- output_position);
+ conversion_map_.insert({{arg_node, 0}, {input_node, 0}});
+ }
return Status::OK();
}
-void Vectorization::Vectorize() {
- while (true) {
- FunctionDefTensorDesc desc;
- int output_position =
- FindOutputToConvert(*map_defun_fn_, unconvertible_, &desc);
- if (output_position == -1) break;
+Status Vectorization::GetResult(FunctionDef** vectorized_function) {
+ TF_RETURN_IF_ERROR(status_);
- if (!ConvertOutput(output_position, desc).ok()) {
- unconvertible_.insert(desc.node_name);
- }
- }
+ if (!map_defun_fn_->ret_nodes.empty()) {
+ FunctionDef* map_defun_fn = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("map_defun_fn", lib_, map_defun_fn);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *map_defun_fn_->graph, map_defun_fn->signature().name(), map_defun_fn));
- // If we've converted all the outputs of the MapDefun function, we no longer
- // need the MapDefun node and can delete it.
- if (map_defun_fn_->signature().output_arg_size() == 0) {
- outer_scope_->mutable_node_def()->DeleteSubrange(
- function_utils::FindFunctionNodeWithName(map_defun_node_->name(),
- *outer_scope_),
- 1);
+ AttrValue func_attr;
+ func_attr.mutable_func()->set_name(map_defun_fn->signature().name());
+ map_defun_node_->AddAttr("f", func_attr);
}
- if (!unconvertible_.empty()) {
- VLOG(2) << "The following nodes could not be converted: ["
- << absl::StrJoin(unconvertible_, ", ") << "].";
- }
+ *vectorized_function = lib_->add_function();
+ graph_utils::SetUniqueGraphFunctionName("vectorized_fn", lib_,
+ *vectorized_function);
+ TF_RETURN_IF_ERROR(GraphToFunctionDef(
+ *outer_scope_, (*vectorized_function)->signature().name(),
+ *vectorized_function));
+ return Status::OK();
}
+
} // namespace
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node) {
- Vectorization(outer_scope, map_defun_fn, map_defun_node).Vectorize();
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result) {
+ *result = nullptr;
+ return Vectorization(lib).Vectorize(outer_scope, map_defun_node, result);
}
} // end namespace vectorization_utils
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
index bb405faa77..bd7d390900 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils.h
@@ -24,22 +24,28 @@ namespace tensorflow {
namespace grappler {
namespace vectorization_utils {
-// Given a function, `map_defun_fn`, that is mapped across some input vector
-// elements via a MapDefun operation, `VectorizeMapDefun` attempts to
-// vectorize the MapDefun by "lifting" operations from the `map_defun_fn` to the
-// `outer_scope`; that is, replacing `map_defun_fn` operations with new
-// `outer_scope` operations that produce the same vector output(s) as executing
-// the `map_defun_fn` operations on elements of vector input(s) would. If all
-// `map_defun_fn` operations are successfully lifted, `map_defun_node` is
-// eliminated from `outer_scope` altogether. However, if some operations cannot
-// be lifted, and this vectorization only succeeds partially, `map_defun_node`
-// remains to be used for operations that were not lifted.
+// Given a MapDefun node (`map_defun_node`) in a FunctionDef (`outer_scope`)
+// that maps a function in lib across some input vector elements,
+// `VectorizeMapDefun` attempts to create a vectorized version of `outer_scope`
+// by "lifting" operations from the MapDefun function to the new function
+// (`result`); that is, replacing operations in the MapDefun function with
+// operations that produce the same vector output(s) as executing the original
+// operations on elements of vector input(s) would. If all operations in the
+// MapDefun function are successfully lifted, `result` has no MapDefun node
+// altogether. However, if some operations cannot be lifted, and this
+// vectorization only succeeds partially, a MapDefun node remains in `result` to
+// be used for operations that were not lifted, and the modified MapDefun
+// function is added to `lib`. The newly vectorized function `result` is also
+// added to `lib`.
+//
+// Returns Status::OK() if the vectorization is completely or partially
+// successful. Otherwise, returns an error, and sets `result` to nullptr.
//
// Example:
// If the input to the `VectorizeMapDefun` function is a MapDefun
// whose `map_defun_fn` performs the Cast operation, the vectorization will
// eliminate the MapDefun. This is because the Cast operation supports
-// any tensor shape and can thus be lifted to the `outer_scope`.
+// any tensor shape and can thus be lifted to `result`.
//
// Before:
//
@@ -68,7 +74,7 @@ namespace vectorization_utils {
//
// After:
//
-// outer_scope +------+
+// result +------+
// +---------------+ Arg0 +---------+
// | +---+--+ |
// | | |
@@ -80,8 +86,9 @@ namespace vectorization_utils {
// +---------------+ Ret0 +---------+
// +------+
//
-void VectorizeMapDefun(FunctionDef* outer_scope, FunctionDef* map_defun_fn,
- NodeDef* map_defun_node);
+Status VectorizeMapDefun(const FunctionDef& outer_scope,
+ const NodeDef& map_defun_node, FunctionDefLibrary* lib,
+ FunctionDef** result);
} // end namespace vectorization_utils
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
index e129fa9237..1ff62217dd 100644
--- a/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
+++ b/tensorflow/core/grappler/optimizers/data/vectorization_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/grappler/optimizers/data/function_utils.h"
+#include "tensorflow/core/grappler/optimizers/data/graph_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
@@ -60,6 +61,11 @@ NodeDef* AddMapDefunNode(const string& name, const std::vector<string>& inputs,
return node;
}
+string GetRetval(const FunctionDef& function_def, int index) {
+ return function_def.ret().at(
+ function_def.signature().output_arg(index).name());
+}
+
// TODO(rachelim): Use FunctionDefHelper::Create instead
FunctionDef CreateFunction(
StringPiece name, const std::vector<std::pair<string, DataType>>& inputs,
@@ -85,7 +91,6 @@ FunctionDef CreateFunction(
return func;
}
-TEST(FunctionDefInputDescTest, ConstructedCorrectly) {}
// Before:
//
@@ -133,10 +138,15 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- EXPECT_EQ(outer.ret().at("mapdefun"), "ret0");
- EXPECT_EQ(outer.ret().at("mapdefun_0"), "ret1");
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ EXPECT_EQ(GetRetval(*vectorized, 0), "ret0");
+ EXPECT_EQ(GetRetval(*vectorized, 1), "ret1");
}
// Before:
@@ -149,12 +159,12 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// | +-----------+ Arg0 +---+ Arg1 +----+ |
// | | +---+--+ +---+--+ | |
// | | | | | |
-// | | +------+ | +---v--+ | |
-// | | |Const | | | Op0 | | |
-// | | +---v--+ | +---+--+ | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
// | | | | | | |
// | | | +---v--+ +---v--+ | |
-// | | +---| XOp1 | | XOp2 | | |
+// | | +---| XOp1 | | Cast | | |
// | | +---+--+ +---+--+ | |
// | | | | | |
// | | MapDefun +---v--+ +---v--+ | |
@@ -165,23 +175,50 @@ TEST(VectorizeMapDefunTest, VectorizeDefunNoOps) {
// +---------------+ Ret0 +---+ Ret1 +--------+
// +------+ +------+
//
-// where XOp1 and XOp2 are not convertible.
+// where XOp1 is not convertible.
//
// After:
//
-// No change because the ops are not convertible.
+//
+// +------+ +------+
+// +---------------+ Arg0 +---+ Arg1 +--------+
+// | +---+--+ +---+--+ |
+// | | | |
+// | +---v--+ | |
+// | +-----------+ Arg0 +-+ | |
+// | | +---+--+ | | |
+// | | | | | |
+// | | +------+ | | | |
+// | | |Const | | | | |
+// | | +---v--+ | | | |
+// | | | | | | |
+// | | | +---v--+ | +---v--+ |
+// | | +---| XOp1 | | | Cast | |
+// | | +---+--+ | +---+--+ |
+// | | | | | |
+// | | MapDefun +---v--+ | | |
+// | +-----------+ Ret0 +-+ | |
+// | +---+--+ | |
+// | | | |
+// | +---v--+ +---v--+ |
+// +---------------+ Ret0 +---+ Ret1 +--------+
+// +------+ +------+
//
TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}, {"arg1", DT_INT32}},
{{"ret0", DT_INT32}, {"ret1", DT_INT32}},
- {{"ret0", "XOp1:output:0"}, {"ret1", "XOp2:output:0"}});
+ {{"ret0", "MatMul:product:0"}, {"ret1", "Cast:y:0"}});
+ // TODO(rachelim): If we ever write a converter for MatMul, we have to
+ // change this test.
NodeDef* x_op1 =
- function_utils::AddNode("XOp1", "XOp1", {"const", "arg0"}, {}, &inner);
+ function_utils::AddNode("MatMul", "MatMul", {"arg0", "arg0"}, {}, &inner);
CHECK_NOTNULL(x_op1);
+ graph_transforms::SetNodeAttr("T", DT_INT32, x_op1);
- NodeDef* x_op2 = function_utils::AddNode("XOp2", "XOp2", {"op1"}, {}, &inner);
- CHECK_NOTNULL(x_op2);
+ NodeDef* cast_node =
+ AddCastNode("Cast", {"arg1"}, DT_INT32, DT_INT32, false, &inner);
+ CHECK_NOTNULL(cast_node);
FunctionDef outer = CreateFunction(
"outer_function", {{"x", DT_INT32}, {"y", DT_INT32}},
@@ -193,12 +230,22 @@ TEST(VectorizeMapDefunTest, VectorizeDefunUnconvertible) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
- // They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
- EXPECT_TRUE(FunctionDefsEqual(inner_copy, inner));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+
+ auto map_defun_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("MapDefun", *vectorized));
+ // The Cast node should be converted just fine.
+ EXPECT_EQ(GetRetval(*vectorized, 1), "Cast:y:0");
+
+ // The inner function should only have one retval.
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+ EXPECT_EQ(map_defun_fn->signature().output_arg_size(), 1);
}
// Before:
@@ -257,14 +304,19 @@ TEST(VectorizeMapDefunTest, VectorizeDefunSimpleCast) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -330,16 +382,21 @@ TEST(VectorizeMapDefunTest, VectorizeDefunCastUsedTwice) {
{{}, {}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(cast_node.name(), ":y:0"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -411,21 +468,26 @@ TEST(VectorizeMapDefunTest, VectorizeDefunOpWithMultipleOutputs) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), "x");
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 1);
+ EXPECT_EQ(vectorized->node_def_size(), 1);
}
// Before:
@@ -486,7 +548,7 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{"ret1", "MyUnstack:output:1"},
{"ret2", "MyUnstack:output:2"}});
NodeDef* cast_op =
- AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT64, false, &inner);
+ AddCastNode("Cast", {"arg0"}, DT_INT32, DT_INT32, false, &inner);
CHECK_NOTNULL(cast_op);
NodeDef* unstack_op =
AddUnstackNode("MyUnstack", {"Cast:y:0"}, DT_INT32, 0, 3, &inner);
@@ -505,25 +567,30 @@ TEST(VectorizeMapDefunTest, VectorizeDefunChainedConvertibleOps) {
{{1}, {1}, {1}}, inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- VectorizeMapDefun(&outer, &inner, map_defun);
- EXPECT_TRUE(!function_utils::ContainsFunctionNodeWithOp("MapDefun", outer));
- const NodeDef& cast_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Cast", outer));
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
+ EXPECT_TRUE(
+ !function_utils::ContainsFunctionNodeWithOp("MapDefun", *vectorized));
+ const NodeDef& cast_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *vectorized));
EXPECT_EQ(cast_node.input(0), "x");
- const NodeDef& unpack_node =
- outer.node_def(function_utils::FindFunctionNodeWithOp("Unpack", outer));
+ const NodeDef& unpack_node = vectorized->node_def(
+ function_utils::FindFunctionNodeWithOp("Unpack", *vectorized));
EXPECT_EQ(unpack_node.input(0), strings::StrCat(cast_node.name(), ":y:0"));
EXPECT_EQ(unpack_node.attr().at("axis").i(), 1);
EXPECT_EQ(unpack_node.attr().at("T").type(), DT_INT32);
EXPECT_EQ(unpack_node.attr().at("num").i(), 3);
- EXPECT_EQ(outer.ret().at("mapdefun"),
+ EXPECT_EQ(GetRetval(*vectorized, 0),
strings::StrCat(unpack_node.name(), ":output:0"));
- EXPECT_EQ(outer.ret().at("mapdefun_0"),
+ EXPECT_EQ(GetRetval(*vectorized, 1),
strings::StrCat(unpack_node.name(), ":output:1"));
- EXPECT_EQ(outer.ret().at("mapdefun_1"),
+ EXPECT_EQ(GetRetval(*vectorized, 2),
strings::StrCat(unpack_node.name(), ":output:2"));
- EXPECT_EQ(outer.node_def_size(), 2);
+ EXPECT_EQ(vectorized->node_def_size(), 2);
}
// Before:
@@ -561,9 +628,11 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
FunctionDef inner =
CreateFunction("inner_function", {{"arg0", DT_INT32}},
{{"ret0", DT_INT64}}, {{"ret0", "Cast:y:0"}});
- // The attrs aren't relevant
- NodeDef* print_op =
- function_utils::AddNode("Print", "Print", {"arg0", "arg0"}, {}, &inner);
+ NodeDef* print_op = function_utils::AddNode(
+ "Print", "Print", {"arg0", "arg0"}, {/*attrs*/}, &inner);
+ graph_transforms::SetNodeAttr("T", DT_INT32, print_op);
+ graph_transforms::SetNodeAttr("U", gtl::ArraySlice<DataType>({DT_INT32}),
+ print_op);
CHECK_NOTNULL(print_op);
NodeDef* cast_op = AddCastNode("Cast", {"arg0", "^Print"}, DT_INT32, DT_INT64,
false, &inner);
@@ -578,11 +647,27 @@ TEST(VectorizeMapDefunTest, VectorizeDefunWithControlInputs) {
inner.signature().name(), &outer);
CHECK_NOTNULL(map_defun);
- FunctionDef outer_copy(outer);
- FunctionDef inner_copy(inner);
- VectorizeMapDefun(&outer, &inner, map_defun);
+ FunctionDefLibrary lib;
+ *lib.add_function() = outer;
+ *lib.add_function() = inner;
+ FunctionDef* vectorized;
+ EXPECT_TRUE(VectorizeMapDefun(outer, *map_defun, &lib, &vectorized).ok());
// They should be unchanged
- EXPECT_TRUE(FunctionDefsEqual(outer_copy, outer));
+ // We check this somewhat manually as the names of nodes may have changed
+ EXPECT_EQ(vectorized->node_def_size(), 1);
+ const NodeDef& map_defun_node = vectorized->node_def(0);
+ EXPECT_EQ(map_defun_node.op(), "MapDefun");
+ FunctionLibraryDefinition lib_def(OpRegistry::Global(), lib);
+ const FunctionDef* map_defun_fn =
+ lib_def.Find(map_defun_node.attr().at("f").func().name());
+
+ const NodeDef& print_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Print", *map_defun_fn));
+ const NodeDef& cast_node = map_defun_fn->node_def(
+ function_utils::FindFunctionNodeWithOp("Cast", *map_defun_fn));
+ string control_input = strings::StrCat("^", print_node.name());
+ EXPECT_TRUE(cast_node.input(0) == control_input ||
+ cast_node.input(1) == control_input);
}
// TODO(rachelim): More test cases when we get around to implementing them:
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index e18a5f21d2..c3d70a1fdf 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -37,6 +37,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
@@ -115,6 +116,9 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
+ if (cfg_.disable_meta_optimizer()) {
+ return Status::OK();
+ }
if (!cfg_.disable_model_pruning()) {
optimizers->push_back(MakeUnique<ModelPruner>());
}
@@ -135,7 +139,7 @@ Status MetaOptimizer::InitializeOptimizers(
if (cfg_.remapping() != RewriterConfig::OFF) {
optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
}
- if (cfg_.pin_to_host_optimization() != RewriterConfig::OFF) {
+ if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
optimizers->push_back(MakeUnique<PinToHostOptimizer>());
}
if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
@@ -410,6 +414,15 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
FunctionLibraryDefinition flib(OpRegistry::Global(),
optimized_graph->library());
+ // Find functions for which we might need to compute a gradient at runtime.
+ gtl::FlatSet<string> differentiable_functions;
+ for (const NodeDef& node : optimized_graph->node()) {
+ if (IsSymbolicGradient(node)) {
+ const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
+ if (f_attr) differentiable_functions.insert(f_attr->func().name());
+ }
+ }
+
// Optimize each function only once.
std::unordered_set<string> optimized_funcs;
bool optimize_function_library = true;
@@ -425,6 +438,8 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
// Skip parametrized functions (function type or body is defined only at
// function call time by caller node attributes).
+ // They should be specialized to their instantiation type parameters by
+ // the function optimizer, before we can optimize function body.
if (IsParametrized(func)) continue;
VLOG(3) << "Optimize function: function=" << func_name;
@@ -439,6 +454,13 @@ Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
func, flib, item.graph.versions().producer(), &func_item));
+ // If we need to compute the gradient of optimized function at runtime, we
+ // can't perform non-differentiable rewrites.
+ if (differentiable_functions.find(func_name) !=
+ differentiable_functions.end()) {
+ func_item.allowed_optimizations.non_differentiable_rewrites = false;
+ }
+
// Optimize function body graph.
GraphDef optimized_func_graph;
TF_RETURN_IF_ERROR(
@@ -489,6 +511,9 @@ void MetaOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
}
bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
+ if (cfg.disable_meta_optimizer()) {
+ return false;
+ }
return !cfg.disable_model_pruning() ||
cfg.layout_optimizer() != RewriterConfig::OFF ||
cfg.function_optimization() != RewriterConfig::OFF ||
@@ -502,7 +527,7 @@ bool MetaOptimizerEnabled(const RewriterConfig& cfg) {
cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
cfg.debug_stripper() == RewriterConfig::ON ||
cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
- cfg.pin_to_host_optimization() != RewriterConfig::OFF ||
+ cfg.pin_to_host_optimization() == RewriterConfig::ON ||
!cfg.optimizers().empty() || !cfg.custom_optimizers().empty();
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index c477c4d4b1..3f3f43382f 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -25,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@@ -82,6 +83,48 @@ class TestOptimizerWithParams : public TestOptimizer {
REGISTER_GRAPH_OPTIMIZER(TestOptimizerWithParams);
+// Record various properties of the GrapplerItems passed for optimization.
+class GrapplerItemPropertiesAccumulator : public CustomGraphOptimizer {
+ public:
+ static void SetAllowedOptimizations(
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations) {
+ allowed_optimizations_ = allowed_optimizations;
+ }
+ static void ResetAllowedOptimizations() { allowed_optimizations_ = nullptr; }
+
+ GrapplerItemPropertiesAccumulator() {}
+ string name() const override {
+ return "grappler_item_properties_accumulator";
+ }
+
+ Status Init(
+ const tensorflow::RewriterConfig_CustomGraphOptimizer* config) override {
+ return Status::OK();
+ }
+
+ Status Optimize(Cluster* cluster, const GrapplerItem& item,
+ GraphDef* optimized_graph) override {
+ *optimized_graph = item.graph;
+ if (allowed_optimizations_) {
+ allowed_optimizations_->insert({item.id, item.allowed_optimizations});
+ }
+ return Status::OK();
+ }
+
+ void Feedback(Cluster* cluster, const GrapplerItem& item,
+ const GraphDef& optimized_graph, double result) override {}
+
+ private:
+ static gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ allowed_optimizations_;
+};
+
+gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>*
+ GrapplerItemPropertiesAccumulator::allowed_optimizations_;
+
+REGISTER_GRAPH_OPTIMIZER(GrapplerItemPropertiesAccumulator);
+
class MetaOptimizerTest : public GrapplerTest {};
TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
@@ -335,6 +378,89 @@ TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
}
+TEST_F(MetaOptimizerTest, OptimizeFunctionLibraryWithRestrictions) {
+ using test::function::NDef;
+ using FDH = FunctionDefHelper;
+
+ // We will record what type of optimizations meta optimizer allows for each
+ // GrapplerItem (main graph and graphs for each function).
+ gtl::FlatMap<string, GrapplerItem::AllowedOptimizations>
+ allowed_optimizations;
+ GrapplerItemPropertiesAccumulator::SetAllowedOptimizations(
+ &allowed_optimizations);
+
+ // Just record properties of optimized Grappler items.
+ RewriterConfig rewriter_config;
+ rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
+ rewriter_config.add_optimizers("GrapplerItemPropertiesAccumulator");
+ rewriter_config.set_min_graph_nodes(-1);
+
+ MetaOptimizer optimizer(nullptr, rewriter_config);
+
+ // Define simple function library with two identical mul functions.
+ FunctionDef mul_func_1 = FunctionDefHelper::Create(
+ "MyMul1", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ FunctionDef mul_func_2 = FunctionDefHelper::Create(
+ "MyMul2", {"x:float", "y:float"}, {"z:float"}, {},
+ {{{"mul"}, "Mul", {"x", "y"}, {}}},
+ /* Mapping between function returns and function node outputs. */
+ {{"z", "mul:z:0"}});
+
+ // Tensorflow graph:
+ //
+ // x0 = tf.Placeholder(tf.float);
+ // x1 = tf.Placeholder(tf.float);
+ // dy = tf.Placeholder(tf.float);
+ //
+ // mul_1 = MyMul1(x0, x1);
+ // mul_2 = MyMul2(x0, x1);
+ // dx = SymbolicGradient({x0, x1, dy}, f=MyMul2)
+ GrapplerItem item;
+ item.id = "main";
+ item.graph = test::function::GDef(
+ {NDef("x0", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("x1", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ NDef("dy", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
+ // Calls into function library
+ NDef("mul_1", "MyMul1", {"x0", "x1"}, {}, kDevice),
+ NDef("mul_2", "MyMul2", {"x0", "x1"}, {}, kDevice),
+ // Symbolic gradient of a MyMul2
+ NDef("dx", "SymbolicGradient", {"x0", "x1", "dy"},
+ {{"f", FDH::FunctionRef("MyMul2", {})},
+ {"Tin", DataTypeSlice{DT_FLOAT}},
+ {"Tout", DataTypeSlice{DT_FLOAT, DT_FLOAT}}},
+ kDevice)},
+ // FunctionLib
+ {mul_func_1, mul_func_2});
+ item.fetch = {"mul_1", "mul_2", "dx"};
+
+ GraphDef output;
+ TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
+
+ // Our custom optimizer must be called for the main graph and for the two
+ // functions.
+ ASSERT_EQ(allowed_optimizations.size(), 3);
+
+ auto allowed_optimizations_main =
+ gtl::FindOrNull(allowed_optimizations, "main");
+ ASSERT_NE(allowed_optimizations_main, nullptr);
+ EXPECT_TRUE(allowed_optimizations_main->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_1 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul1");
+ ASSERT_NE(allowed_optimizations_my_mul_1, nullptr);
+ EXPECT_TRUE(allowed_optimizations_my_mul_1->non_differentiable_rewrites);
+
+ auto allowed_optimizations_my_mul_2 =
+ gtl::FindOrNull(allowed_optimizations, "MyMul2");
+ ASSERT_NE(allowed_optimizations_my_mul_2, nullptr);
+ EXPECT_FALSE(allowed_optimizations_my_mul_2->non_differentiable_rewrites);
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
index 2190d38937..89eb76046e 100644
--- a/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/pin_to_host_optimizer.cc
@@ -169,7 +169,13 @@ bool IsTPUGraphDef(const GraphDef& def) {
}
// All the nodes that should be blacklisted and not swapped.
-bool IsBlacklisted(const NodeDef& node) { return IsCollective(node); }
+bool IsBlacklisted(const NodeDef& node) {
+ return
+ // Collective ops should not be swapped.
+ IsCollective(node) ||
+ // NoOp breaks perf regression tests (probably due to group dependencies).
+ IsNoOp(node);
+}
} // end namespace internal
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
diff --git a/tensorflow/core/grappler/utils/functions.cc b/tensorflow/core/grappler/utils/functions.cc
index a428aea7f5..6861fb423c 100644
--- a/tensorflow/core/grappler/utils/functions.cc
+++ b/tensorflow/core/grappler/utils/functions.cc
@@ -41,7 +41,8 @@ Status RegisterFunctionBodyOutputs(const OpRegistrationData& registration,
tensorflow::NameRangeMap outputs_range_map;
TF_RETURN_IF_ERROR(tensorflow::NameRangesForNode(
node, registration.op_def, nullptr, &outputs_range_map));
- connectivity->RegisterFunctionBodyOutputs(node.name(), outputs_range_map);
+ connectivity->RegisterFunctionBodyOutputs(node.name(),
+ std::move(outputs_range_map));
return Status::OK();
}
@@ -75,20 +76,22 @@ Status ResolveFunctionBodyNodeAttrPlaceholders(
} // namespace
void GrapplerFunctionConnectivity::RegisterInputArgExpansion(
- const InputArgExpansion& input_arg_expansion) {
- const auto& input_name = input_arg_expansion.input_name;
+ InputArgExpansion input_arg_expansion) {
+ string input_name = input_arg_expansion.input_name;
const auto& placeholders = input_arg_expansion.placeholders;
- input_arg_expansions_.emplace(input_name, input_arg_expansion);
+
for (int i = 0; i < placeholders.size(); ++i) {
const string& placeholder = input_arg_expansion.placeholders[i];
- input_arg_placeholders_.emplace(
- placeholder, InputArgPlaceholder{input_name, /*position=*/i});
+ input_arg_placeholders_.insert(
+ {placeholder, InputArgPlaceholder{input_name, /*position=*/i}});
}
+ input_arg_expansions_.insert(
+ {std::move(input_name), std::move(input_arg_expansion)});
}
void GrapplerFunctionConnectivity::RegisterFunctionBodyOutputs(
- const string& node_name, const tensorflow::NameRangeMap& outputs) {
- function_body_outputs_[node_name] = outputs;
+ const string& node_name, tensorflow::NameRangeMap&& outputs) {
+ function_body_outputs_[node_name] = std::move(outputs);
}
Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
@@ -174,11 +177,12 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
const auto& output_range = output->second;
if (position == -1) {
+ graph_def_inputs->reserve(graph_def_inputs->size() +
+ output_range.second - output_range.first);
// If position is not defined expand node output range
for (int i = output_range.first; i < output_range.second; ++i) {
- i == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", i));
+ graph_def_inputs->push_back(
+ i == 0 ? node_name : strings::StrCat(node_name, ":", i));
}
} else {
if (position > (output_range.second - output_range.first)) {
@@ -187,9 +191,8 @@ Status GrapplerFunctionConnectivity::ExpandFunctionDefInput(
" position: ", position, " (out of range)");
}
int pos = output_range.first + position;
- pos == 0 ? graph_def_inputs->push_back(node_name)
- : graph_def_inputs->push_back(
- strings::StrCat(node_name, ":", pos));
+ graph_def_inputs->push_back(
+ pos == 0 ? node_name : strings::StrCat(node_name, ":", pos));
}
return Status::OK();
@@ -211,8 +214,8 @@ Status GrapplerFunctionConnectivity::ExpandNodeInputs(
}
function_body_node->clear_input();
- for (const string& expanded_input : expanded_inputs)
- function_body_node->add_input(expanded_input);
+ for (string& expanded_input : expanded_inputs)
+ function_body_node->add_input(std::move(expanded_input));
return Status::OK();
}
@@ -323,7 +326,7 @@ GrapplerFunctionItem::GrapplerFunctionItem(
// Fill the feed nodes with input placeholders.
for (const InputArgExpansion& input_arg : input_arg_expansions_) {
for (const string& placeholder : input_arg.placeholders) {
- feed.emplace_back(placeholder, Tensor());
+ feed.push_back({placeholder, Tensor()});
input_arg_placeholders_.insert(placeholder);
}
}
@@ -460,7 +463,7 @@ Status InstantiationBodyParameters(
auto it = func_instantiation_attr.find(placeholder);
if (it != func_instantiation_attr.end()) {
- body_parameters->emplace(placeholder, it->second);
+ body_parameters->insert({placeholder, it->second});
} else {
return errors::InvalidArgument("Can't resolve placeholder: ",
placeholder);
@@ -498,10 +501,6 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
// GraphDef input format (name[:position])
GrapplerFunctionConnectivity connectivity;
- std::vector<InputArgExpansion> inputs;
- std::vector<OutputArgExpansion> outputs;
- std::vector<string> keep_nodes;
-
// Function body shares the library with the graph that instantiated it.
GraphDef function_body;
*function_body.mutable_library() = flib.ToProto();
@@ -518,6 +517,9 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
}
}
+ std::vector<InputArgExpansion> inputs;
+ inputs.reserve(signature.input_arg_size());
+
// For each input argument create a placeholder in function body.
for (const OpDef::ArgDef& input : signature.input_arg()) {
if (!input.type_list_attr().empty() || !input.number_attr().empty()) {
@@ -542,9 +544,10 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
/*is_ref*/ input.is_ref(),
/*placeholders=*/{input.name()}};
connectivity.RegisterInputArgExpansion(input_expansion);
- inputs.push_back(input_expansion);
+ inputs.push_back(std::move(input_expansion));
}
+ std::vector<string> keep_nodes;
// Add all function nodes to the function body
for (const NodeDef& func_def_node : func.node_def()) {
NodeDef* new_node = function_body.add_node();
@@ -572,6 +575,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
TF_RETURN_IF_ERROR(connectivity.ExpandNodeInputs(&node));
}
+ std::vector<OutputArgExpansion> outputs;
+ outputs.reserve(signature.output_arg_size());
// Add function outputs
for (const OpDef::ArgDef& out : signature.output_arg()) {
std::vector<string> output_tensors;
@@ -589,8 +594,8 @@ Status MakeGrapplerFunctionItem(const FunctionDef& func,
OutputArgExpansion output{/*output_name=*/out.name(),
/*data_type=*/output_data_type,
/*is_ref=*/out.is_ref(),
- /*output_tensors=*/output_tensors};
- outputs.push_back(output);
+ /*output_tensors=*/std::move(output_tensors)};
+ outputs.push_back(std::move(output));
}
bool is_stateful = signature.is_stateful();
diff --git a/tensorflow/core/grappler/utils/functions.h b/tensorflow/core/grappler/utils/functions.h
index 733caf325f..ef944ced09 100644
--- a/tensorflow/core/grappler/utils/functions.h
+++ b/tensorflow/core/grappler/utils/functions.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <memory>
#include <string>
+#include <unordered_map>
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
@@ -70,9 +71,9 @@ struct OutputArgExpansion {
// and fold it back when doing backward conversion.
class GrapplerFunctionConnectivity {
public:
- void RegisterInputArgExpansion(const InputArgExpansion& input_arg_expansion);
+ void RegisterInputArgExpansion(InputArgExpansion input_arg_expansion);
void RegisterFunctionBodyOutputs(const string& node_name,
- const tensorflow::NameRangeMap& outputs);
+ tensorflow::NameRangeMap&& outputs);
// Expand input encoded in FunctionDef format (name[:output][:position]) into
// multiple inputs in GraphDef format (name[:position]).
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 0b8e9ec527..9439ab332c 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -1197,8 +1197,10 @@ tf_cc_test(
tf_cc_test(
name = "example_parsing_ops_test",
- size = "large",
+ size = "medium",
srcs = ["example_parsing_ops_test.cc"],
+ shard_count = 4,
+ tags = ["optonly"],
deps = [
":example_parsing_ops",
":ops_testutil",
@@ -4049,11 +4051,6 @@ cc_library(
)
SPARSE_DEPS = [
- ":bounds_check",
- ":cwise_op",
- ":fill_functor",
- ":scatter_functor",
- "//third_party/eigen3",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:sparse_ops_op_lib",
@@ -4086,7 +4083,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_cross_op",
prefix = "sparse_cross_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4098,13 +4097,19 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_dense_binary_op_shared",
prefix = "sparse_dense_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_sparse_binary_op_shared",
prefix = "sparse_sparse_binary_op_shared",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":cwise_op",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4136,7 +4141,9 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_softmax",
prefix = "sparse_softmax",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
@@ -4148,25 +4155,37 @@ tf_kernel_library(
tf_kernel_library(
name = "sparse_tensor_dense_add_op",
prefix = "sparse_tensor_dense_add_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":scatter_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_tensor_dense_matmul_op",
prefix = "sparse_tensor_dense_matmul_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ ":fill_functor",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_to_dense_op",
prefix = "sparse_to_dense_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
name = "sparse_xent_op",
prefix = "sparse_xent_op",
- deps = SPARSE_DEPS,
+ deps = SPARSE_DEPS + [
+ ":bounds_check",
+ "//third_party/eigen3",
+ ],
)
tf_kernel_library(
diff --git a/tensorflow/core/kernels/batching_util/BUILD b/tensorflow/core/kernels/batching_util/BUILD
index 039b0db144..0d53240330 100644
--- a/tensorflow/core/kernels/batching_util/BUILD
+++ b/tensorflow/core/kernels/batching_util/BUILD
@@ -12,11 +12,6 @@ cc_library(
name = "periodic_function_dynamic",
srcs = ["periodic_function.cc"],
hdrs = ["periodic_function.h"],
- visibility = [
- "//learning/serving:__subpackages__",
- "//tensorflow:internal",
- "//tensorflow_serving:__subpackages__",
- ],
deps = [
"//tensorflow/core:framework_headers_lib",
"//tensorflow/core:protos_all_cc",
@@ -25,11 +20,6 @@ cc_library(
cc_library(
name = "periodic_function",
- visibility = [
- "//learning/serving:__subpackages__",
- "//tensorflow:internal",
- "//tensorflow_serving:__subpackages__",
- ],
deps = [
":periodic_function_dynamic",
"//tensorflow/core:lib",
@@ -198,11 +188,6 @@ cc_library(
testonly = 1,
srcs = ["fake_clock_env.cc"],
hdrs = ["fake_clock_env.h"],
- visibility = [
- "//learning/serving:__subpackages__",
- "//tensorflow:internal",
- "//tensorflow_serving:__subpackages__",
- ],
deps = [
"//tensorflow/core:lib",
"//tensorflow/core:tensorflow",
diff --git a/tensorflow/core/kernels/collective_ops.cc b/tensorflow/core/kernels/collective_ops.cc
index fa959b5a0e..82e2913b64 100644
--- a/tensorflow/core/kernels/collective_ops.cc
+++ b/tensorflow/core/kernels/collective_ops.cc
@@ -132,7 +132,6 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
"Failed to get CollectiveExecutor from OpKernelContext for Op ",
col_params_.name),
done);
- col_params_.instance.shape = c->input(0).shape();
// Allocate output on the first pass through this function. This must be
// done immediately, while we're still in the executor thread. Otherwise
// the memory is not guaranteed to be unused by any concurrently executing
@@ -144,6 +143,7 @@ class CollectiveReduceOpKernel : public CollectiveOpKernel {
c->forward_input_or_allocate_output(
{0}, 0, c->input(0).shape(), &output),
done);
+ col_params_.instance.shape = c->input(0).shape();
}
if (!CanProceedWithCompute(c, col_exec, done)) return;
auto actual_done = [c, col_exec, done](const Status& s) {
diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD
index 87efdff789..6333853cdf 100644
--- a/tensorflow/core/kernels/data/BUILD
+++ b/tensorflow/core/kernels/data/BUILD
@@ -765,6 +765,7 @@ tf_kernel_library(
":window_dataset_op",
":writer_ops",
":zip_dataset_op",
+ "//tensorflow/core/kernels/data/experimental:dataset_kernels",
],
)
diff --git a/tensorflow/contrib/data/kernels/BUILD b/tensorflow/core/kernels/data/experimental/BUILD
index ec6cb37193..43406db3ed 100644
--- a/tensorflow/contrib/data/kernels/BUILD
+++ b/tensorflow/core/kernels/data/experimental/BUILD
@@ -1,22 +1,26 @@
# Description:
-# Contains kernels for datasets and iterators.
+# Contains experimental kernels for datasets and iterators.
package(default_visibility = ["//tensorflow:internal"])
licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_kernel_library",
+)
+
cc_library(
name = "indexed_dataset_headers",
hdrs = ["indexed_dataset.h"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "indexed_dataset",
srcs = [
"identity_indexed_dataset.cc",
@@ -24,103 +28,102 @@ cc_library(
],
deps = [
":indexed_dataset_headers",
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "prefetching_kernels",
srcs = ["prefetching_kernels.cc"],
deps = [
- "//tensorflow/core:core_cpu_headers_lib",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "directed_interleave_dataset_op",
srcs = ["directed_interleave_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "csv_dataset_op",
srcs = ["csv_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "ignore_errors_dataset_op",
srcs = ["ignore_errors_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "lmdb_dataset_op",
srcs = ["lmdb_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
"@lmdb",
- "@protobuf_archive//:protobuf_headers",
],
)
-cc_library(
+tf_kernel_library(
name = "threadpool_dataset_op",
srcs = ["threadpool_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "unique_dataset_op",
srcs = ["unique_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "assert_next_dataset_op",
srcs = ["assert_next_dataset_op.cc"],
deps = [
- "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:experimental_dataset_ops_op_lib",
+ "//tensorflow/core:framework",
"//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
- alwayslink = 1,
)
-cc_library(
+tf_kernel_library(
name = "dataset_kernels",
deps = [
":assert_next_dataset_op",
@@ -132,8 +135,5 @@ cc_library(
":prefetching_kernels",
":threadpool_dataset_op",
":unique_dataset_op",
- "//tensorflow/core:framework_headers_lib",
- "//third_party/eigen3",
- "@protobuf_archive//:protobuf_headers",
],
)
diff --git a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
index c19a609780..3511cca0f5 100644
--- a/tensorflow/contrib/data/kernels/assert_next_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/assert_next_dataset_op.cc
@@ -147,8 +147,9 @@ class AssertNextDatasetOp : public UnaryDatasetOpKernel {
std::vector<PartialTensorShape> output_shapes_;
};
-REGISTER_KERNEL_BUILDER(Name("AssertNextDataset").Device(DEVICE_CPU),
- AssertNextDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalAssertNextDataset").Device(DEVICE_CPU),
+ AssertNextDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/csv_dataset_op.cc b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
index 21ec50fb6b..7451ca4cb1 100644
--- a/tensorflow/contrib/data/kernels/csv_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/csv_dataset_op.cc
@@ -852,7 +852,8 @@ class CSVDatasetOp : public DatasetOpKernel {
}; // class CSVDatasetOp
// Register the kernel implementation for CSVDataset.
-REGISTER_KERNEL_BUILDER(Name("CSVDataset").Device(DEVICE_CPU), CSVDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalCSVDataset").Device(DEVICE_CPU),
+ CSVDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
index a5321620bf..c47a9099c4 100644
--- a/tensorflow/contrib/data/kernels/directed_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/directed_interleave_dataset_op.cc
@@ -272,8 +272,9 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("DirectedInterleaveDataset").Device(DEVICE_CPU),
- DirectedInterleaveDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalDirectedInterleaveDataset").Device(DEVICE_CPU),
+ DirectedInterleaveDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
index c3cb45dbf7..2141f118ca 100644
--- a/tensorflow/contrib/data/kernels/identity_indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/identity_indexed_dataset.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/lib/core/errors.h"
namespace tensorflow {
@@ -147,8 +147,9 @@ class IdentityIndexedDatasetOp : public IndexedDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IdentityIndexedDataset").Device(DEVICE_CPU),
- IdentityIndexedDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIdentityIndexedDataset").Device(DEVICE_CPU),
+ IdentityIndexedDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
index beec344534..b34377c642 100644
--- a/tensorflow/contrib/data/kernels/ignore_errors_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/ignore_errors_dataset_op.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
-#include "tensorflow/core/lib/random/random.h"
namespace tensorflow {
namespace data {
@@ -133,8 +132,9 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("IgnoreErrorsDataset").Device(DEVICE_CPU),
- IgnoreErrorsDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIgnoreErrorsDataset").Device(DEVICE_CPU),
+ IgnoreErrorsDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.cc b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
index ced8ab0d60..75ea462f40 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.cc
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/contrib/data/kernels/indexed_dataset.h"
+#include "tensorflow/core/kernels/data/experimental/indexed_dataset.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -361,12 +361,14 @@ class IndexedDatasetGet : public OpKernel {
};
REGISTER_KERNEL_BUILDER(
- Name("MaterializedIndexDatasetHandle").Device(DEVICE_CPU),
+ Name("ExperimentalMaterializedIndexDatasetHandle").Device(DEVICE_CPU),
MaterializedHandleOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetMaterialize").Device(DEVICE_CPU),
- MaterializeDatasetOp);
-REGISTER_KERNEL_BUILDER(Name("IndexedDatasetGet").Device(DEVICE_CPU),
- IndexedDatasetGet);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetMaterialize").Device(DEVICE_CPU),
+ MaterializeDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIndexedDatasetGet").Device(DEVICE_CPU),
+ IndexedDatasetGet);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/indexed_dataset.h b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
index 7aa2d3fdbc..27a8360cbc 100644
--- a/tensorflow/contrib/data/kernels/indexed_dataset.h
+++ b/tensorflow/core/kernels/data/experimental/indexed_dataset.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
-#define TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#ifndef TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
+#define TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -116,4 +116,4 @@ Status StoreIndexedDatasetInVariantTensor(IndexedDataset* dataset,
} // namespace data
} // namespace tensorflow
-#endif // TENSORFLOW_CONTRIB_DATA_KERNELS_INDEXED_DATASET_H_
+#endif // TENSORFLOW_CORE_KERNELS_DATA_EXPERIMENTAL_INDEXED_DATASET_H_
diff --git a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
index d233c1f8ec..8a88d32f0c 100644
--- a/tensorflow/contrib/data/kernels/lmdb_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/lmdb_dataset_op.cc
@@ -210,7 +210,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("LMDBDataset").Device(DEVICE_CPU), LMDBDatasetOp);
+REGISTER_KERNEL_BUILDER(Name("ExperimentalLMDBDataset").Device(DEVICE_CPU),
+ LMDBDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/prefetching_kernels.cc b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
index 96f1dd0059..2c6179d9f5 100644
--- a/tensorflow/contrib/data/kernels/prefetching_kernels.cc
+++ b/tensorflow/core/kernels/data/experimental/prefetching_kernels.cc
@@ -338,20 +338,20 @@ class FunctionBufferResourceHandleOp : public OpKernel {
DataTypeVector output_types_;
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_CPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_GPU)
.HostMemory("resource")
.HostMemory("string_arg")
.HostMemory("target_device"),
FunctionBufferResourceHandleOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResource")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResource")
.Device(DEVICE_SYCL)
.HostMemory("resource")
.HostMemory("string_arg")
@@ -403,16 +403,16 @@ class FunctionBufferingResourceGetNextOp : public AsyncOpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceGetNext")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceGetNext")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceGetNextOp);
@@ -440,16 +440,16 @@ class FunctionBufferingResourceResetOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_CPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_GPU)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
#if TENSORFLOW_USE_SYCL
-REGISTER_KERNEL_BUILDER(Name("FunctionBufferingResourceReset")
+REGISTER_KERNEL_BUILDER(Name("ExperimentalFunctionBufferingResourceReset")
.Device(DEVICE_SYCL)
.HostMemory("function_buffer_resource"),
FunctionBufferingResourceResetOp);
@@ -473,8 +473,9 @@ class IteratorGetDeviceOp : public OpKernel {
}
};
-REGISTER_KERNEL_BUILDER(Name("IteratorGetDevice").Device(DEVICE_CPU),
- IteratorGetDeviceOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalIteratorGetDevice").Device(DEVICE_CPU),
+ IteratorGetDeviceOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
index 30fa97a636..c80493d3a1 100644
--- a/tensorflow/contrib/data/kernels/threadpool_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/threadpool_dataset_op.cc
@@ -209,10 +209,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolHandle").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalThreadPoolHandle").Device(DEVICE_CPU),
ThreadPoolHandleOp);
-REGISTER_KERNEL_BUILDER(Name("ThreadPoolDataset").Device(DEVICE_CPU),
- ThreadPoolDatasetOp);
+REGISTER_KERNEL_BUILDER(
+ Name("ExperimentalThreadPoolDataset").Device(DEVICE_CPU),
+ ThreadPoolDatasetOp);
} // namespace
} // namespace data
diff --git a/tensorflow/contrib/data/kernels/unique_dataset_op.cc b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
index 57fc5697a4..cd612e0eb2 100644
--- a/tensorflow/contrib/data/kernels/unique_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/unique_dataset_op.cc
@@ -199,8 +199,9 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
HANDLE_TYPE(DT_INT64);
HANDLE_TYPE(DT_STRING);
default:
- LOG(FATAL) << "UniqueDataset unhandled data type: "
- << DataTypeString(lhs.dtype());
+ DCHECK(false) << "UniqueDataset unhandled data type: "
+ << DataTypeString(lhs.dtype());
+ return false;
}
}
};
@@ -215,7 +216,7 @@ class UniqueDatasetOp : public UnaryDatasetOpKernel {
};
};
-REGISTER_KERNEL_BUILDER(Name("UniqueDataset").Device(DEVICE_CPU),
+REGISTER_KERNEL_BUILDER(Name("ExperimentalUniqueDataset").Device(DEVICE_CPU),
UniqueDatasetOp);
} // namespace
diff --git a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
index 2bbf4af664..bf08970560 100644
--- a/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/map_and_batch_dataset_op.cc
@@ -37,6 +37,8 @@ namespace {
// See documentation in ../ops/dataset_ops.cc for a high-level
// description of the following op.
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit MapAndBatchDatasetOp(OpKernelConstruction* ctx)
@@ -185,29 +187,31 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
AddConstantParameter(ctx, "batch_size", dataset()->batch_size_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- port::NumSchedulableCPUs() /* max */, &cond_var_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
dataset()->input_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -219,27 +223,27 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
bool* end_of_sequence) override {
std::shared_ptr<BatchResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (batch_results_.empty() ||
batch_results_.front()->num_calls > 0) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, batch_results_.front());
batch_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
return ProcessResult(ctx, result, out_tensors, end_of_sequence);
}
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -255,7 +259,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
TF_RETURN_IF_ERROR(
reader->ReadScalar(full_name("call_counter"), &call_counter_));
@@ -296,7 +300,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
void Callback(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<BatchResult>& result,
const std::shared_ptr<std::vector<Tensor>>& return_values,
- int64 offset, const Status& status) LOCKS_EXCLUDED(mu_) {
+ int64 offset, const Status& status) LOCKS_EXCLUDED(*mu_) {
result->UpdateStatus(status);
if (status.ok()) {
EnsureOutputAllocated(ctx, result, return_values);
@@ -332,16 +336,16 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void CallCompleted(const std::shared_ptr<BatchResult>& result)
- LOCKS_EXCLUDED(mu_) {
- mutex_lock l(mu_);
+ LOCKS_EXCLUDED(*mu_) {
+ mutex_lock l(*mu_);
num_calls_--;
result->num_calls--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
void CallFunction(std::shared_ptr<IteratorContext> ctx,
const std::shared_ptr<BatchResult>& result,
- int64 offset) LOCKS_EXCLUDED(mu_) {
+ int64 offset) LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
bool end_of_input;
@@ -398,7 +402,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -474,14 +478,14 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
std::vector<std::pair<std::shared_ptr<BatchResult>, int64>> new_calls;
RecordStart(ctx.get());
auto stop_cleanup =
gtl::MakeCleanup([this, &ctx]() { RecordStop(ctx.get()); });
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
int64 max_batch_results =
(num_parallel_calls + dataset()->batch_size_ - 1) /
dataset()->batch_size_;
@@ -492,10 +496,10 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -522,7 +526,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadBatchResult(IteratorContext* ctx, IteratorStateReader* reader,
- size_t index) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ size_t index) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
batch_results_.emplace_back(new BatchResult(dataset()->batch_size_));
std::shared_ptr<BatchResult> result = batch_results_.back();
string prefix = strings::StrCat("batch_results_", index);
@@ -567,7 +571,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status ReadStatus(IteratorStateReader* reader, const string& prefix,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name(strings::StrCat(prefix, "_code")), &code_int));
@@ -585,7 +589,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteBatchResult(IteratorStateWriter* writer, size_t index)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
std::shared_ptr<BatchResult> result = batch_results_[index];
string prefix = strings::StrCat("batch_results_", index);
mutex_lock l(result->mu);
@@ -626,7 +630,7 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
}
Status WriteStatus(IteratorStateWriter* writer, const string& prefix,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(full_name(strings::StrCat(prefix, "_code")),
static_cast<int64>(status.code())));
@@ -640,24 +644,24 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the callback threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the callback threads. In particular, the runner thread should only
- // schedule new calls when the number of in-flight calls is less than the
- // user specified level of parallelism and there are slots available in
- // the `batch_results_` buffer.
- condition_variable cond_var_;
+ // schedule new calls when the number of in-flight calls is less than
+ // `num_parallel_calls_->value` and there are slots available in the
+ // `batch_results_` buffer.
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls for this batch.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
// Counts the total number of calls.
- int64 call_counter_ GUARDED_BY(mu_) = 0;
+ int64 call_counter_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the (intermediate) batch results.
- std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ std::deque<std::shared_ptr<BatchResult>> batch_results_ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
index 5f143967d9..d909b9e9d3 100644
--- a/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
+++ b/tensorflow/core/kernels/data/multi_device_iterator_ops.cc
@@ -134,19 +134,17 @@ class MultiDeviceIterator : public ResourceBase {
void Reset() LOCKS_EXCLUDED(mu_) {
{
mutex_lock l(mu_);
- if (background_thread_finished_) {
- return;
- }
-
- cancelled_ = true;
- // Wake up the background thread.
- for (int i = 0; i < size_; ++i) {
- buffer_[i].cond_var.notify_all();
- }
+ if (!background_thread_finished_) {
+ cancelled_ = true;
+ // Wake up the background thread.
+ for (int i = 0; i < size_; ++i) {
+ buffer_[i].cond_var.notify_all();
+ }
- // Make sure background thread has finished first.
- while (!background_thread_finished_) {
- shutdown_cond_var_.wait(l);
+ // Make sure background thread has finished first.
+ while (!background_thread_finished_) {
+ shutdown_cond_var_.wait(l);
+ }
}
}
RunPendingCallbacks();
@@ -182,7 +180,7 @@ class MultiDeviceIterator : public ResourceBase {
buffer_[shard_num].cond_var.notify_all();
}
} else {
- if (background_thread_finished_) {
+ if (end_of_iterator_) {
produced_output = true;
elem.end_of_sequence = true;
} else {
@@ -219,8 +217,12 @@ class MultiDeviceIterator : public ResourceBase {
while (!buffer_[i].callbacks.empty()) {
if (buffer_[i].data.empty()) {
HostBufferElement elem;
- elem.status =
- errors::Cancelled("Cancelled and buffer not filled.");
+ if (end_of_iterator_) {
+ elem.end_of_sequence = true;
+ } else {
+ elem.status =
+ errors::Cancelled("Cancelled and buffer not filled.");
+ }
cancellation_elements.push_back(std::move(elem));
} else {
cancellation_elements.push_back(
@@ -293,6 +295,7 @@ class MultiDeviceIterator : public ResourceBase {
{
mutex_lock l(mu_);
background_thread_finished_ = true;
+ end_of_iterator_ = true;
shutdown_cond_var_.notify_all();
}
RunPendingCallbacks();
@@ -312,6 +315,7 @@ class MultiDeviceIterator : public ResourceBase {
std::unique_ptr<Thread> background_thread_ GUARDED_BY(mu_);
bool background_thread_finished_ GUARDED_BY(mu_) = false;
bool background_thread_started_ GUARDED_BY(mu_) = false;
+ bool end_of_iterator_ GUARDED_BY(mu_) = false;
bool cancelled_ GUARDED_BY(mu_) = false;
condition_variable shutdown_cond_var_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/kernels/data/optimize_dataset_op.cc b/tensorflow/core/kernels/data/optimize_dataset_op.cc
index d5b725eac9..1cb7caa738 100644
--- a/tensorflow/core/kernels/data/optimize_dataset_op.cc
+++ b/tensorflow/core/kernels/data/optimize_dataset_op.cc
@@ -154,12 +154,8 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
: DatasetIterator<Dataset>(params) {}
Status Initialize(IteratorContext* ctx) override {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ IteratorContext::Params params = ctx->params();
params.lib = dataset()->lib_;
- params.allocator_getter = ctx->allocator_getter();
return dataset()->optimized_input_->MakeIterator(
IteratorContext(params), prefix(), &input_impl_);
}
@@ -167,14 +163,10 @@ class OptimizeDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
- IteratorContext::Params params;
- params.env = ctx->env();
- params.runner = *(ctx->runner());
- params.stats_aggregator_getter = ctx->stats_aggregator_getter();
+ IteratorContext::Params params = ctx->params();
params.lib = dataset()->lib_;
- params.allocator_getter = ctx->allocator_getter();
- IteratorContext iter_ctx(params);
- return input_impl_->GetNext(&iter_ctx, out_tensors, end_of_sequence);
+ return input_impl_->GetNext(IteratorContext(params), out_tensors,
+ end_of_sequence);
}
protected:
diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
index 2e6e0465f7..6b6b3d6ab9 100644
--- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc
@@ -1084,6 +1084,9 @@ REGISTER_KERNEL_BUILDER(Name("ParallelInterleaveDataset").Device(DEVICE_CPU),
// The above design choices were made with automated optimizations in mind,
// isolating the degree of parallelism as the single tunable knob of this
// implementation.
+//
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit ParallelInterleaveDatasetV2Op(OpKernelConstruction* ctx)
@@ -1214,7 +1217,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
- num_parallel_calls_(params.dataset->num_parallel_calls_),
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ params.dataset->num_parallel_calls_, mu_, cond_var_)),
args_list_(params.dataset->cycle_length_),
current_elements_(params.dataset->cycle_length_),
element_in_use_(params.dataset->cycle_length_, false),
@@ -1224,25 +1230,24 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
false /* low_latency_hint */)) {}
~Iterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
- AddTunableParameter(ctx, "parallelism",
- &num_parallel_calls_ /* value */, 1 /* min */,
- dataset()->cycle_length_ /* max */, &cond_var_);
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ dataset()->cycle_length_);
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
AddConstantParameter(ctx, "cycle_length", dataset()->cycle_length_);
TF_RETURN_IF_ERROR(
@@ -1256,12 +1261,12 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
std::shared_ptr<InvocationResult> result;
do {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty() &&
(!end_of_input_ || num_open_ > 0)) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
if (!invocation_results_.empty()) {
@@ -1271,7 +1276,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
*end_of_sequence = true;
return Status::OK();
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -1287,10 +1292,10 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
@@ -1328,7 +1333,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
@@ -1381,7 +1386,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> new_ctx(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -1398,7 +1403,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void FetchOutputs(
const std::shared_ptr<IteratorContext>& ctx, int64 cycle_index,
const std::vector<std::shared_ptr<InvocationResult>>& results)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
bool end_of_input = false;
@@ -1421,14 +1426,14 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
if (end_of_input) {
current_elements_[cycle_index].reset();
}
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
element_in_use_[cycle_index] = false;
num_calls_--;
if (end_of_input) {
args_list_[cycle_index].clear();
num_open_--;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
// Method responsible for 1) creating iterators out of input elements, 2)
@@ -1439,20 +1444,20 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
void RunnerThread(const std::shared_ptr<IteratorContext>& ctx) {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
return element_in_use_[cycle_index_] ||
- num_calls_ >= num_parallel_calls_ ||
+ num_calls_ >= num_parallel_calls_->value ||
invocation_results_.size() >=
dataset()->cycle_length_ * dataset()->block_length_;
};
while (true) {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait until this thread is cancelled, the end of input has been
// reached, or the cycle element at the `cycle_index_` position is
// not in use and there is space in the `invocation_results_` queue.
while (!cancelled_ && (!end_of_input_ || num_open_ > 0) && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
@@ -1506,13 +1511,13 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
cycle_index_ = (cycle_index_ + 1) % dataset()->cycle_length_;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
const Status& status)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -1523,7 +1528,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -1550,7 +1555,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
}
Status WriteCurrentElements(IteratorStateWriter* writer)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (current_elements_[idx]) {
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
@@ -1569,7 +1574,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
Status ReadCurrentElements(IteratorContext* ctx,
IteratorStateReader* reader)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
for (int idx = 0; idx < current_elements_.size(); idx++) {
if (reader->Contains(
full_name(strings::StrCat("args_size[", idx, "]")))) {
@@ -1597,7 +1602,7 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// Used for coordination between the main thread, the runner thread, and
// the worker threads.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread, the runner thread, and
// the worker threads. In particular, the runner thread should only
@@ -1605,45 +1610,45 @@ class ParallelInterleaveDatasetV2Op : public UnaryDatasetOpKernel {
// user specified level of parallelism, there are slots available in the
// `invocation_results_` buffer, the current cycle element is not in use,
// and there are elements left to be fetched.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Iterator for input elements.
- std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(*mu_);
// Identifies current cycle element.
int64 cycle_index_ = 0;
// Arguments for creating an iterator for cycle elements.
- std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(mu_);
+ std::vector<std::vector<Tensor>> args_list_ GUARDED_BY(*mu_);
// Iterators for the current cycle elements. Concurrent access is
// protected by `element_in_use_`.
std::vector<std::unique_ptr<IteratorBase>> current_elements_;
// Identifies cycle elements that are in use by worker threads.
- std::vector<bool> element_in_use_ GUARDED_BY(mu_);
+ std::vector<bool> element_in_use_ GUARDED_BY(*mu_);
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
+ GUARDED_BY(*mu_);
// Identifies whether end of input has been reached.
- bool end_of_input_ GUARDED_BY(mu_) = false;
+ bool end_of_input_ GUARDED_BY(*mu_) = false;
// Identifies the number of open iterators.
- int64 num_open_ GUARDED_BY(mu_) = 0;
+ int64 num_open_ GUARDED_BY(*mu_) = 0;
// Identifies the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<thread::ThreadPool> thread_pool_;
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
// Identifies whether background activity should be cancelled.
- bool cancelled_ GUARDED_BY(mu_) = false;
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
const DatasetBase* const input_;
diff --git a/tensorflow/core/kernels/data/parallel_map_iterator.cc b/tensorflow/core/kernels/data/parallel_map_iterator.cc
index ee20249bfe..13bd4b6036 100644
--- a/tensorflow/core/kernels/data/parallel_map_iterator.cc
+++ b/tensorflow/core/kernels/data/parallel_map_iterator.cc
@@ -27,6 +27,8 @@ namespace tensorflow {
namespace data {
namespace {
+// TODO(b/116852688): Make coordination between the performance model and this
+// transformation more robust.
class ParallelMapIterator : public DatasetBaseIterator {
public:
explicit ParallelMapIterator(
@@ -38,30 +40,32 @@ class ParallelMapIterator : public DatasetBaseIterator {
input_dataset_(input_dataset),
init_func_(std::move(init_func)),
map_func_(std::move(map_func)),
- num_parallel_calls_(num_parallel_calls) {}
+ mu_(std::make_shared<mutex>()),
+ cond_var_(std::make_shared<condition_variable>()),
+ num_parallel_calls_(std::make_shared<model::SharedState>(
+ num_parallel_calls, mu_, cond_var_)) {}
~ParallelMapIterator() override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Cancel the runner thread.
cancelled_ = true;
- cond_var_.notify_all();
+ cond_var_->notify_all();
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
}
Status Initialize(IteratorContext* ctx) override {
- mutex_lock l(mu_);
- if (num_parallel_calls_ == kAutoTune) {
- num_parallel_calls_ = 1;
+ mutex_lock l(*mu_);
+ if (num_parallel_calls_->value == kAutoTune) {
+ num_parallel_calls_->value = 1;
// TODO(jsimsa): Surface the number of threads used by `ctx->runner()` and
// use it here for the maximum.
- AddTunableParameter(ctx, "parallelism", &num_parallel_calls_ /* value */,
- 1 /* min */, port::NumSchedulableCPUs() /* max */,
- &cond_var_);
+ AddTunableParameter(ctx, "parallelism", num_parallel_calls_, 1,
+ port::NumSchedulableCPUs());
} else {
- AddConstantParameter(ctx, "parallelism", num_parallel_calls_);
+ AddConstantParameter(ctx, "parallelism", num_parallel_calls_->value);
}
TF_RETURN_IF_ERROR(
input_dataset_->MakeIterator(ctx, prefix(), &input_impl_));
@@ -75,16 +79,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
bool* end_of_sequence) override {
std::shared_ptr<InvocationResult> result;
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
EnsureRunnerThreadStarted(ctx);
while (invocation_results_.empty()) {
RecordStop(ctx);
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx);
}
std::swap(result, invocation_results_.front());
invocation_results_.pop_front();
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
RecordStop(ctx);
result->notification.WaitForNotification();
@@ -94,28 +98,27 @@ class ParallelMapIterator : public DatasetBaseIterator {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
// Wait for all in-flight calls to complete.
while (num_calls_ > 0) {
- cond_var_.wait(l);
+ cond_var_->wait(l);
}
CHECK_EQ(num_calls_, 0);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("invocation_results.size"),
invocation_results_.size()));
for (size_t i = 0; i < invocation_results_.size(); i++) {
- std::shared_ptr<InvocationResult> result = invocation_results_[i];
- TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result->status));
+ const auto& result = *(invocation_results_[i]);
+ TF_RETURN_IF_ERROR(WriteStatusLocked(writer, i, result.status));
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(strings::StrCat("invocation_results[", i, "].size")),
- result->return_values.size()));
- for (size_t j = 0; j < result->return_values.size(); j++) {
- TF_RETURN_IF_ERROR(
- writer->WriteTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- result->return_values[j]));
+ result.return_values.size()));
+ for (size_t j = 0; j < result.return_values.size(); j++) {
+ TF_RETURN_IF_ERROR(writer->WriteTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ result.return_values[j]));
}
- if (result->end_of_input) {
+ if (result.end_of_input) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")),
@@ -127,15 +130,15 @@ class ParallelMapIterator : public DatasetBaseIterator {
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
int64 invocation_results_size;
TF_RETURN_IF_ERROR(reader->ReadScalar(
full_name("invocation_results.size"), &invocation_results_size));
for (size_t i = 0; i < invocation_results_size; i++) {
- std::shared_ptr<InvocationResult> result(new InvocationResult());
- invocation_results_.push_back(result);
- TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result->status));
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
+ auto& result = *invocation_results_.back();
+ TF_RETURN_IF_ERROR(ReadStatusLocked(reader, i, &result.status));
size_t num_return_values;
{
int64 size;
@@ -151,17 +154,16 @@ class ParallelMapIterator : public DatasetBaseIterator {
": ", size, " is not a valid value of type size_t."));
}
}
- result->return_values.reserve(num_return_values);
+ result.return_values.reserve(num_return_values);
for (size_t j = 0; j < num_return_values; j++) {
- result->return_values.emplace_back();
- TF_RETURN_IF_ERROR(
- reader->ReadTensor(full_name(strings::StrCat(
- "invocation_results[", i, "][", j, "]")),
- &result->return_values.back()));
+ result.return_values.emplace_back();
+ TF_RETURN_IF_ERROR(reader->ReadTensor(
+ full_name(strings::StrCat("invocation_results[", i, "][", j, "]")),
+ &result.return_values.back()));
}
- result->end_of_input = reader->Contains(full_name(
+ result.end_of_input = reader->Contains(full_name(
strings::StrCat("invocation_results[", i, "].end_of_input")));
- result->notification.Notify();
+ result.notification.Notify();
}
return Status::OK();
}
@@ -175,7 +177,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
};
void EnsureRunnerThreadStarted(IteratorContext* ctx)
- EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
if (!runner_thread_) {
std::shared_ptr<IteratorContext> ctx_copy(new IteratorContext(*ctx));
runner_thread_.reset(ctx->env()->StartThread(
@@ -185,18 +187,18 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
void CallCompleted(const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
num_calls_--;
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
result->notification.Notify();
}
void CallFunction(const std::shared_ptr<IteratorContext>& ctx,
const std::shared_ptr<InvocationResult>& result)
- LOCKS_EXCLUDED(mu_) {
+ LOCKS_EXCLUDED(*mu_) {
// Get the next input element.
std::vector<Tensor> input_element;
result->status =
@@ -239,29 +241,29 @@ class ParallelMapIterator : public DatasetBaseIterator {
RecordStart(ctx.get());
auto cleanup = gtl::MakeCleanup([this, ctx] { RecordStop(ctx.get()); });
std::vector<std::shared_ptr<InvocationResult>> new_calls;
- new_calls.reserve(num_parallel_calls_);
- auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(mu_) -> bool {
- int64 num_parallel_calls = num_parallel_calls_;
+ new_calls.reserve(num_parallel_calls_->value);
+ auto busy = [this]() EXCLUSIVE_LOCKS_REQUIRED(*mu_) -> bool {
+ int64 num_parallel_calls = num_parallel_calls_->value;
return num_calls_ >= num_parallel_calls ||
invocation_results_.size() >= num_parallel_calls;
};
while (true) {
{
- mutex_lock l(mu_);
+ mutex_lock l(*mu_);
while (!cancelled_ && busy()) {
RecordStop(ctx.get());
- cond_var_.wait(l);
+ cond_var_->wait(l);
RecordStart(ctx.get());
}
if (cancelled_) {
return;
}
while (!busy()) {
- invocation_results_.emplace_back(new InvocationResult());
+ invocation_results_.push_back(std::make_shared<InvocationResult>());
new_calls.push_back(invocation_results_.back());
num_calls_++;
}
- cond_var_.notify_all();
+ cond_var_->notify_all();
}
for (const auto& call : new_calls) {
CallFunction(ctx, call);
@@ -271,7 +273,8 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status WriteStatusLocked(IteratorStateWriter* writer, size_t index,
- const Status& status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ const Status& status)
+ EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
TF_RETURN_IF_ERROR(
writer->WriteScalar(CodeKey(index), static_cast<int64>(status.code())));
if (!status.ok()) {
@@ -282,7 +285,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
}
Status ReadStatusLocked(IteratorStateReader* reader, size_t index,
- Status* status) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ Status* status) EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
int64 code_int;
TF_RETURN_IF_ERROR(reader->ReadScalar(CodeKey(index), &code_int));
error::Code code = static_cast<error::Code>(code_int);
@@ -312,23 +315,23 @@ class ParallelMapIterator : public DatasetBaseIterator {
const std::function<Status(IteratorContext*)> init_func_;
const ParallelMapIteratorFunction map_func_;
// Used for coordination between the main thread and the runner thread.
- mutex mu_;
+ const std::shared_ptr<mutex> mu_;
// Used for coordination between the main thread and the runner thread. In
// particular, the runner thread should only schedule new calls when the
// number of in-flight calls is less than the user specified level of
// parallelism and there are slots available in the `invocation_results_`
// buffer.
- condition_variable cond_var_;
+ const std::shared_ptr<condition_variable> cond_var_;
// Identifies the maximum number of parallel calls.
- std::atomic<int64> num_parallel_calls_;
+ const std::shared_ptr<model::SharedState> num_parallel_calls_;
// Counts the number of outstanding calls.
- int64 num_calls_ GUARDED_BY(mu_) = 0;
+ int64 num_calls_ GUARDED_BY(*mu_) = 0;
std::unique_ptr<IteratorBase> input_impl_;
// Buffer for storing the invocation results.
std::deque<std::shared_ptr<InvocationResult>> invocation_results_
- GUARDED_BY(mu_);
- std::unique_ptr<Thread> runner_thread_ GUARDED_BY(mu_);
- bool cancelled_ GUARDED_BY(mu_) = false;
+ GUARDED_BY(*mu_);
+ std::unique_ptr<Thread> runner_thread_ GUARDED_BY(*mu_);
+ bool cancelled_ GUARDED_BY(*mu_) = false;
};
} // namespace
diff --git a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
index 7e528a71be..c8abfb9eb5 100644
--- a/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
+++ b/tensorflow/core/kernels/data/stats_aggregator_dataset_op.cc
@@ -118,16 +118,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
protected:
Status SaveInternal(IteratorStateWriter* writer) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
- mutex_lock l(mu_);
- TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
- return Status::OK();
+ return errors::Unimplemented(dataset()->DebugString(),
+ " does not support checkpointing");
}
private:
diff --git a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
index 52157ed5fb..f406ad2ab5 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_filter_ops.cc
@@ -853,7 +853,7 @@ class MklConvCustomBackpropFilterOp
// MKL DNN allocates large buffers when a conv gradient filter primtive is
// created. So we don't cache conv backward primitives when the env
- // variable TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is set to true.
+ // variable TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is set to true.
bool do_not_cache = MklPrimitiveFactory<T>::IsPrimitiveMemOptEnabled();
conv_bwd_filter = MklConvBwdFilterPrimitiveFactory<T>::Get(
convBwdFilterDims, do_not_cache);
diff --git a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
index c38c9cc27c..a501ce2c93 100644
--- a/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_grad_input_ops.cc
@@ -713,7 +713,7 @@ class MklConvCustomBackpropInputOp : public MklConvBackpropCommonOp<Device, T> {
TFPaddingToMklDnnPadding(this->padding_));
// We don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true and if primitve descriptor
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true and if primitve descriptor
// includes potentialy large buffers. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
diff --git a/tensorflow/core/kernels/mkl_conv_ops.cc b/tensorflow/core/kernels/mkl_conv_ops.cc
index 184e0cb003..b332edad0a 100644
--- a/tensorflow/core/kernels/mkl_conv_ops.cc
+++ b/tensorflow/core/kernels/mkl_conv_ops.cc
@@ -901,7 +901,7 @@ class MklConvOp : public OpKernel {
// In some cases, primitve descriptor includes potentialy large buffers,
// we don't cache those primitves if the env variable
- // TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE is true. MKL DNN allocates buffers
+ // TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE is true. MKL DNN allocates buffers
// in the following cases
// 1. Legacy CPU without AVX512/AVX2, or
// 2. 1x1 convolution with stride != 1
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 427044ca67..23d76986bf 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -172,17 +172,21 @@ REGISTER_KERNEL_BUILDER(
.Device(DEVICE_GPU) \
.HostMemory("resource") \
.TypeConstraint<type>("dtype"), \
- ResourceHandleOp<Var>) \
- REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp") \
- .Device(DEVICE_GPU) \
- .HostMemory("resources") \
- .TypeConstraint<type>("dtypes"), \
- ResourceHandlesOp<Var>)
-
+ ResourceHandleOp<Var>)
TF_CALL_GPU_ALL_TYPES(REGISTER_GPU_KERNELS);
TF_CALL_int64(REGISTER_GPU_KERNELS);
TF_CALL_variant(REGISTER_GPU_KERNELS);
#undef REGISTER_GPU_KERNELS
+
+REGISTER_KERNEL_BUILDER(Name("_VarHandlesOp")
+ .Device(DEVICE_GPU)
+ .HostMemory("resources")
+ .TypeConstraint("dtypes",
+ {DT_INT64, DT_COMPLEX64,
+ DT_COMPLEX128, DT_HALF, DT_FLOAT,
+ DT_DOUBLE, DT_BOOL, DT_VARIANT}),
+ ResourceHandlesOp<Var>);
+
#endif // GOOGLE_CUDA
template <typename T>
diff --git a/tensorflow/core/kernels/slice_op.cc b/tensorflow/core/kernels/slice_op.cc
index 97f77e45b6..a006c69297 100644
--- a/tensorflow/core/kernels/slice_op.cc
+++ b/tensorflow/core/kernels/slice_op.cc
@@ -228,191 +228,6 @@ class SliceOp : public OpKernel {
}
};
-#ifdef INTEL_MKL
-template <typename Device, typename T>
-class MklSliceOp : public OpKernel {
- public:
- explicit MklSliceOp(OpKernelConstruction* context) : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- TensorShape output_shape;
- gtl::InlinedVector<int64, 4> begin;
- gtl::InlinedVector<int64, 4> size;
- Tensor* result = nullptr;
- bool done = false;
- SharedSliceCommonCases<T>(context, &output_shape, &begin, &size, &result,
- &done);
- if (!context->status().ok() || done == true) return;
-
- const Tensor& input = context->input(0);
- const int input_dims = input.dims();
-
- if (output_shape.num_elements() > 0) {
- if (std::is_same<Device, CPUDevice>::value && input_dims == 2 &&
- DataTypeCanUseMemcpy(DataTypeToEnum<T>::v())) {
- auto input = context->input(0).tensor<T, 2>();
- auto output = result->tensor<T, 2>();
- // TODO(agarwal): Consider multi-threading this loop for cases where
- // size[0] is very large.
- for (int i = 0; i < size[0]; ++i) {
- const int64 row = begin[0] + i;
- if (i + 1 < size[0]) {
- port::prefetch<port::PREFETCH_HINT_T0>(&output(i + 1, 0));
- port::prefetch<port::PREFETCH_HINT_T0>(&input(row + 1, begin[1]));
- }
- memcpy(&output(i, 0), &input(row, begin[1]), size[1] * sizeof(T));
- }
- return;
- }
-#define HANDLE_DIM(NDIM) \
- if (input_dims == NDIM) { \
- HandleCase<NDIM>(context, begin, size, result); \
- return; \
- }
-
- HANDLE_DIM(1);
- HANDLE_DIM(2);
- HANDLE_DIM(3);
- HANDLE_DIM(4);
- HANDLE_DIM(5);
- HANDLE_DIM(6);
- HANDLE_DIM(7);
-
-#undef HANDLE_DIM
-
- OP_REQUIRES(
- context, false,
- errors::Unimplemented("SliceOp : Unhandled input dimensions"));
- }
- }
-
- private:
- // Helper function for DoesSliceShapeDifferInOnly1D. Checks if the following
- // criteria matches for slice_dim: if indices for slice are 0 in all dims
- // except slice_dim and if sizes of all the dimensions of the slice are same
- // as the sizes of all the dimensions of the input except slice_dim, then
- // returns True. Otherwise, returns False.
- bool DoesSliceShapeDifferInOnly1DHelper(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (dim != slice_dim &&
- (begin[dim] != 0 || size[dim] != input_shape.dim_size(dim))) {
- return false;
- }
- }
- return true;
- }
-
- // Is 'input' tensor being sliced over a single dimension out of 4?
- //
- // This check is applicable in the context of Slice of a 4-D tensor in
- // NHWC or NCHW format over channel dimension.
- //
- // If indices for slice are 0 in all dims except one dimension and if sizes of
- // all dimensions of slice are same as sizes of all dimensions of inputs
- // except that dimension, then we are slicing over a single dimension.
- //
- // Returns True if Slicing over a single dimension, and sets slice_dim
- // to the number of the dimension that satisfies criteria.
- bool DoesSliceShapeDifferInOnly1D(const TensorShape& input_shape,
- const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size,
- int* slice_dim) {
- for (int dim = 0; dim < 4; dim++) {
- if (DoesSliceShapeDifferInOnly1DHelper(input_shape, begin, size, dim)) {
- *slice_dim = dim;
- return true;
- }
- }
- return false;
- }
-
- template <int NDIM>
- void HandleCase(OpKernelContext* context, const gtl::ArraySlice<int64>& begin,
- const gtl::ArraySlice<int64>& size, Tensor* result) {
- int slice_dim = -1;
- TensorShape in_shape = context->input(0).shape();
- // Special case for handling 4-D tensor slice when shape of the slice
- // differs from the input tensor in only 1 out of 4 dimensions.
- // This case arises in the context of Slice of 4-D tensor in NHWC or NCHW
- // format over channel dimension.
- if (NDIM == 4 &&
- DoesSliceShapeDifferInOnly1D(in_shape, begin, size, &slice_dim)) {
- size_t in_strides[4] = {
- (size_t)in_shape.dim_size(1) * in_shape.dim_size(2) *
- in_shape.dim_size(3),
- (size_t)in_shape.dim_size(2) * in_shape.dim_size(3),
- (size_t)in_shape.dim_size(3), (size_t)1};
-
- size_t out_strides[4] = {(size_t)size[1] * size[2] * size[3],
- (size_t)size[2] * size[3], (size_t)size[3],
- (size_t)1};
-
- T* in_buf = const_cast<T*>(
- const_cast<const T*>(context->input(0).flat<T>().data()));
- T* op_buf = result->flat<T>().data();
-
- if (slice_dim == 1) {
- /* data format = NCHW */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
- // For NCHW, H and W will be contiguous. So we can copy
- // both with one memcpy.
- memcpy(static_cast<void*>(op1), static_cast<void*>(ip1),
- sizeof(T) * in_strides[1]);
- }
- }
- return;
- } else if (slice_dim == 3) {
- /* data_format = NHWC */
-
-#pragma omp parallel for
- for (ssize_t d0 = begin[0]; d0 < begin[0] + size[0]; d0++) {
- T* ip = in_buf + (d0 * in_strides[0]);
- T* op = op_buf + ((d0 - begin[0]) * out_strides[0]);
-#pragma omp parallel for
- for (ssize_t d1 = begin[1]; d1 < begin[1] + size[1]; d1++) {
- T* ip1 = ip + (d1 * in_strides[1]);
- T* op1 = op + ((d1 - begin[1]) * out_strides[1]);
-#pragma omp parallel for
- for (ssize_t d2 = begin[2]; d2 < begin[2] + size[2]; d2++) {
- T* ip2 = ip1 + (d2 * in_strides[2]);
- T* ip3 = ip2 + begin[3];
- T* op2 = op1 + ((d2 - begin[2]) * out_strides[2]);
- T* op3 = op2;
- memcpy(static_cast<void*>(op3), static_cast<void*>(ip3),
- sizeof(T) * size[3]);
- }
- }
- }
- return;
- }
- // slice_dim is not 1 or 3, then we fallback to Eigen implementation.
- }
-
- Eigen::DSizes<Eigen::DenseIndex, NDIM> indices;
- Eigen::DSizes<Eigen::DenseIndex, NDIM> sizes;
- for (int i = 0; i < NDIM; ++i) {
- indices[i] = begin[i];
- sizes[i] = size[i];
- }
-
- functor::Slice<Device, T, NDIM>()(
- context->eigen_device<Device>(), result->tensor<T, NDIM>(),
- context->input(0).tensor<T, NDIM>(), indices, sizes);
- }
-};
-#endif // INTEL_MKL
-
// Forward declarations of the functor specializations for declared in the
// sharded source files.
namespace functor {
@@ -440,15 +255,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
#undef DECLARE_CPU_SPEC
} // namespace functor
-#if defined(INTEL_MKL) && defined(ENABLE_MKL)
-#define REGISTER_SLICE(type) \
- REGISTER_KERNEL_BUILDER(Name("Slice") \
- .Device(DEVICE_CPU) \
- .TypeConstraint<type>("T") \
- .HostMemory("begin") \
- .HostMemory("size"), \
- MklSliceOp<CPUDevice, type>)
-#else
#define REGISTER_SLICE(type) \
REGISTER_KERNEL_BUILDER(Name("Slice") \
.Device(DEVICE_CPU) \
@@ -456,7 +262,6 @@ TF_CALL_ALL_TYPES(DECLARE_FOR_N);
.HostMemory("begin") \
.HostMemory("size"), \
SliceOp<CPUDevice, type>)
-#endif // INTEL_MKL && ENABLE_MKL
TF_CALL_POD_STRING_TYPES(REGISTER_SLICE);
TF_CALL_QUANTIZED_TYPES(REGISTER_SLICE);
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index f0575de4d9..3e8a4c5b72 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -149,7 +149,7 @@ class StridedSliceOp : public OpKernel {
// NDIM and T
if (is_simple_slice && std::is_same<Device, CPUDevice>::value &&
input_dims == 2 && processing_shape.dims() == 2 &&
- final_shape.dims() == 2) {
+ final_shape.dims() == 2 && new_axis_mask == 0) {
MemCpyFunctor<T> functor;
if (functor.Copy(input, begin, end, result)) {
return;
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 32ce31cf23..e46cbc863d 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -21532,6 +21532,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -76383,6 +76798,39 @@ op {
is_stateful: true
}
op {
+ name: "While"
+ input_arg {
+ name: "input"
+ type_list_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "cond"
+ type: "func"
+ }
+ attr {
+ name: "body"
+ type: "func"
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
+ is_stateful: true
+}
+op {
name: "WholeFileReader"
output_arg {
name: "reader_handle"
diff --git a/tensorflow/contrib/data/ops/dataset_ops.cc b/tensorflow/core/ops/experimental_dataset_ops.cc
index d1a771f005..f6bd5dce26 100644
--- a/tensorflow/contrib/data/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/experimental_dataset_ops.cc
@@ -17,24 +17,16 @@ limitations under the License.
namespace tensorflow {
-REGISTER_OP("DirectedInterleaveDataset")
+REGISTER_OP("ExperimentalDirectedInterleaveDataset")
.Input("selector_input_dataset: variant")
.Input("data_input_datasets: N * variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
.Attr("N: int >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-A substitute for `InterleaveDataset` on a fixed list of `N` datasets.
-
-selector_input_dataset: A dataset of scalar `DT_INT64` elements that determines
- which of the `N` data inputs should produce the next output element.
-data_input_datasets: `N` datasets with the same type that will be interleaved
- according to the values of `selector_input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("CSVDataset")
+REGISTER_OP("ExperimentalCSVDataset")
.Input("filenames: string")
.Input("compression_type: string")
.Input("buffer_size: int64")
@@ -76,35 +68,26 @@ REGISTER_OP("CSVDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("IgnoreErrorsDataset")
+REGISTER_OP("ExperimentalIgnoreErrorsDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the elements of `input_dataset` ignoring errors.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("UniqueDataset")
+REGISTER_OP("ExperimentalUniqueDataset")
.Input("input_dataset: variant")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that contains the unique elements of `input_dataset`.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("IteratorGetDevice")
+REGISTER_OP("ExperimentalIteratorGetDevice")
.Input("resource: resource")
.Output("device: string")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Returns the name of the device on which `resource` has been placed.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("FunctionBufferingResource")
+REGISTER_OP("ExperimentalFunctionBufferingResource")
.Input("string_arg: string")
.Input("target_device: string")
.Output("resource: resource")
@@ -113,77 +96,36 @@ REGISTER_OP("FunctionBufferingResource")
.Attr("f: func")
.Attr("buffer_size: int")
.Attr("output_types: list(type)")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Creates a resource that fills up a buffer by making function calls.
-
-string_arg: String argument to the function call.
-target_device: Target device to execute the function on.
-resource: Handle to the resource created.
-f: Function to be executed.
-buffer_size: Size of the buffer.
-container: If non-empty, this resource is placed in the given container.
- Otherwise, a default container is used.
-shared_name: If non-empty, this resource will be shared under the given name
- across multiple sessions.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceGetNext")
+ .SetShapeFn(shape_inference::UnknownShape);
+
+REGISTER_OP("ExperimentalFunctionBufferingResourceGetNext")
.Input("function_buffer_resource: resource")
.Attr("output_types: list(type)")
.Output("output: output_types")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Gets the next element from a FunctionBufferingResource.
+ .SetShapeFn(shape_inference::UnknownShape);
-function_buffer_resource: The FunctionBufferingResource handle.
-output: A list of return values.
-output_types: The type list for the return values.
-)doc");
-
-REGISTER_OP("FunctionBufferingResourceReset")
+REGISTER_OP("ExperimentalFunctionBufferingResourceReset")
.Input("function_buffer_resource: resource")
- .SetShapeFn(shape_inference::UnknownShape)
- .Doc(R"doc(
-Resets the FunctionBufferingResource.
-
-function_buffer_resource: The FunctionBufferingResource handle.
-)doc");
+ .SetShapeFn(shape_inference::UnknownShape);
-REGISTER_OP("ThreadPoolDataset")
+REGISTER_OP("ExperimentalThreadPoolDataset")
.Input("input_dataset: variant")
.Input("thread_pool: resource")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
.Attr("output_shapes: list(shape) >= 1")
- .SetShapeFn(shape_inference::ScalarShape)
- .Doc(R"doc(
-Creates a dataset that uses a custom thread pool to compute `input_dataset`.
-
-handle: A resource produced by the ThreadPoolHandle op.
-)doc");
+ .SetShapeFn(shape_inference::ScalarShape);
-REGISTER_OP("ThreadPoolHandle")
+REGISTER_OP("ExperimentalThreadPoolHandle")
.Output("handle: resource")
.SetShapeFn(shape_inference::ScalarShape)
.Attr("num_threads: int")
.Attr("max_intra_op_parallelism: int = 1")
.Attr("display_name: string")
.Attr("container: string = ''")
- .Attr("shared_name: string = ''")
- .Doc(R"doc(
-Creates a custom thread pool with the given number of threads.
-
-handle: A resource that can be consumed by one or more ThreadPoolDataset ops.
-num_threads: The number of threads in the thread pool.
-max_intra_op_parallelism: The maximum degree of parallelism to use within
- operations that execute on this threadpool.
-display_name: A human-readable name for the threads that may be visible in
- some visualizations.
-)doc");
-
-REGISTER_OP("AssertNextDataset")
+ .Attr("shared_name: string = ''");
+
+REGISTER_OP("ExperimentalAssertNextDataset")
.Input("input_dataset: variant")
.Input("transformations: string")
.Output("handle: variant")
@@ -196,7 +138,7 @@ REGISTER_OP("AssertNextDataset")
return shape_inference::ScalarShape(c);
});
-REGISTER_OP("LMDBDataset")
+REGISTER_OP("ExperimentalLMDBDataset")
.Input("filenames: string")
.Output("handle: variant")
.Attr("output_types: list(type) >= 1")
@@ -205,4 +147,61 @@ REGISTER_OP("LMDBDataset")
// stateful to inhibit constant folding.
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("ExperimentalIdentityIndexedDataset")
+ .Input("size: uint64")
+ .Output("handle: variant")
+ .SetIsStateful()
+ .SetShapeFn(
+ shape_inference::ScalarShape); // TODO(saeta): check input shapes.
+
+///////////////////////////////////////////////////////////////////////////////
+// IndexedDataset Internals
+///////////////////////////////////////////////////////////////////////////////
+
+// Creates the handle.
+REGISTER_OP("ExperimentalMaterializedIndexDatasetHandle")
+ .Output("handle: resource")
+ .Attr("container: string")
+ .Attr("shared_name: string")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(shape_inference::ScalarShape);
+
+// Actually materialize the materialize handle.
+REGISTER_OP("ExperimentalIndexedDatasetMaterialize")
+ .Input("dataset: variant")
+ .Input("materialized: resource")
+ .SetShapeFn(shape_inference::NoOutputs);
+
+namespace {
+
+Status GetShapeFn(shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle unused;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as `output_types` (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ return Status::OK();
+}
+
+} // namespace
+
+REGISTER_OP("ExperimentalIndexedDatasetGet")
+ .Input("materialized: resource")
+ .Input("index: uint64")
+ .Output("components: output_types")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetShapeFn(GetShapeFn);
+
} // namespace tensorflow
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index bda4a75c5d..fed3fa22ed 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -150,10 +150,29 @@ REGISTER_OP("While")
.Attr("T: list(type) >= 0")
.Attr("cond: func")
.Attr("body: func")
+ .Attr("output_shapes: list(shape) = []")
.SetIsStateful()
.SetShapeFn([](shape_inference::InferenceContext* c) {
- for (int i = 0; i < c->num_outputs(); ++i) {
- c->set_output(i, c->input(i));
+ std::vector<PartialTensorShape> output_shapes;
+ TF_RETURN_IF_ERROR(c->GetAttr("output_shapes", &output_shapes));
+ // If `output_shapes` attr is set use that as the shapes of the outputs
+ // else use the input shapes.
+ if (!output_shapes.empty()) {
+ if (output_shapes.size() != c->num_outputs()) {
+ return errors::InvalidArgument(
+ "`output_shapes` must be the same length as num outputs (",
+ output_shapes.size(), " vs. ", c->num_outputs());
+ }
+ for (size_t i = 0; i < output_shapes.size(); ++i) {
+ shape_inference::ShapeHandle output_shape_handle;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(
+ output_shapes[i], &output_shape_handle));
+ c->set_output(static_cast<int>(i), output_shape_handle);
+ }
+ } else {
+ for (int i = 0; i < c->num_outputs(); ++i) {
+ c->set_output(i, c->input(i));
+ }
}
return Status::OK();
});
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index 02a7f8d717..0e9f939ab4 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -10039,6 +10039,421 @@ op {
}
}
op {
+ name: "ExperimentalAssertNextDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "transformations"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalCSVDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "compression_type"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "buffer_size"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "header"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "field_delim"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "use_quote_delim"
+ type: DT_BOOL
+ }
+ input_arg {
+ name: "na_value"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "select_cols"
+ type: DT_INT64
+ }
+ input_arg {
+ name: "record_defaults"
+ type_list_attr: "output_types"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_STRING
+ }
+ }
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalDirectedInterleaveDataset"
+ input_arg {
+ name: "selector_input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "data_input_datasets"
+ type: DT_VARIANT
+ number_attr: "N"
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "N"
+ type: "int"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalFunctionBufferingResource"
+ input_arg {
+ name: "string_arg"
+ type: DT_STRING
+ }
+ input_arg {
+ name: "target_device"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+ attr {
+ name: "buffer_size"
+ type: "int"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceGetNext"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalFunctionBufferingResourceReset"
+ input_arg {
+ name: "function_buffer_resource"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIdentityIndexedDataset"
+ input_arg {
+ name: "size"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIgnoreErrorsDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
+ name: "ExperimentalIndexedDatasetGet"
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "index"
+ type: DT_UINT64
+ }
+ output_arg {
+ name: "components"
+ type_list_attr: "output_types"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIndexedDatasetMaterialize"
+ input_arg {
+ name: "dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "materialized"
+ type: DT_RESOURCE
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalIteratorGetDevice"
+ input_arg {
+ name: "resource"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "device"
+ type: DT_STRING
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalLMDBDataset"
+ input_arg {
+ name: "filenames"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalMaterializedIndexDatasetHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "container"
+ type: "string"
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ input_arg {
+ name: "thread_pool"
+ type: DT_RESOURCE
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalThreadPoolHandle"
+ output_arg {
+ name: "handle"
+ type: DT_RESOURCE
+ }
+ attr {
+ name: "num_threads"
+ type: "int"
+ }
+ attr {
+ name: "max_intra_op_parallelism"
+ type: "int"
+ default_value {
+ i: 1
+ }
+ }
+ attr {
+ name: "display_name"
+ type: "string"
+ }
+ attr {
+ name: "container"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ attr {
+ name: "shared_name"
+ type: "string"
+ default_value {
+ s: ""
+ }
+ }
+ is_stateful: true
+}
+op {
+ name: "ExperimentalUniqueDataset"
+ input_arg {
+ name: "input_dataset"
+ type: DT_VARIANT
+ }
+ output_arg {
+ name: "handle"
+ type: DT_VARIANT
+ }
+ attr {
+ name: "output_types"
+ type: "list(type)"
+ has_minimum: true
+ minimum: 1
+ }
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ has_minimum: true
+ minimum: 1
+ }
+}
+op {
name: "Expm1"
input_arg {
name: "x"
@@ -36520,6 +36935,14 @@ op {
name: "body"
type: "func"
}
+ attr {
+ name: "output_shapes"
+ type: "list(shape)"
+ default_value {
+ list {
+ }
+ }
+ }
is_stateful: true
}
op {
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
index f41b83ac34..affb68ebbb 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.cc
@@ -17,7 +17,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/platform/cloud/curl_http_request.h"
-#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -25,21 +24,14 @@ namespace {
// The URL to retrieve metadata when running in Google Compute Engine.
constexpr char kGceMetadataBaseUrl[] = "http://metadata/computeMetadata/v1/";
-// The default initial delay between retries with exponential backoff.
-constexpr int kInitialRetryDelayUsec = 500000; // 0.5 sec
} // namespace
ComputeEngineMetadataClient::ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory)
- : ComputeEngineMetadataClient(std::move(http_request_factory),
- kInitialRetryDelayUsec) {}
-
-ComputeEngineMetadataClient::ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec)
+ const RetryConfig& config)
: http_request_factory_(std::move(http_request_factory)),
- initial_retry_delay_usec_(initial_retry_delay_usec) {}
+ retry_config_(config) {}
Status ComputeEngineMetadataClient::GetMetadata(
const string& path, std::vector<char>* response_buffer) {
@@ -52,8 +44,7 @@ Status ComputeEngineMetadataClient::GetMetadata(
return Status::OK();
};
- return RetryingUtils::CallWithRetries(get_metadata_from_gce,
- initial_retry_delay_usec_);
+ return RetryingUtils::CallWithRetries(get_metadata_from_gce, retry_config_);
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
index 534ccf30b2..7f060327da 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/cloud/http_request.h"
+#include "tensorflow/core/platform/cloud/retrying_utils.h"
namespace tensorflow {
@@ -31,10 +32,11 @@ namespace tensorflow {
class ComputeEngineMetadataClient {
public:
explicit ComputeEngineMetadataClient(
- std::shared_ptr<HttpRequest::Factory> http_request_factory);
- ComputeEngineMetadataClient(
std::shared_ptr<HttpRequest::Factory> http_request_factory,
- int64 initial_retry_delay_usec);
+ const RetryConfig& config = RetryConfig(
+ 10000, /* init_delay_time_us = 1 ms */
+ 1000000 /* max_delay_time_us = 1 s */
+ ));
virtual ~ComputeEngineMetadataClient() {}
/// \brief Get the metadata value for a given attribute of the metadata
@@ -54,7 +56,7 @@ class ComputeEngineMetadataClient {
private:
std::shared_ptr<HttpRequest::Factory> http_request_factory_;
- const int64 initial_retry_delay_usec_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(ComputeEngineMetadataClient);
};
diff --git a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
index 4c41ccaa0e..e891b4a5e9 100644
--- a/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_metadata_client_test.cc
@@ -30,7 +30,8 @@ TEST(ComputeEngineMetadataClientTest, GetMetadata) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
@@ -56,7 +57,8 @@ TEST(ComputeEngineMetadataClientTest, RetryOnFailure) {
std::shared_ptr<HttpRequest::Factory> http_factory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- ComputeEngineMetadataClient client(http_factory, 0);
+ ComputeEngineMetadataClient client(http_factory,
+ RetryConfig(0 /* init_delay_time_us */));
std::vector<char> result;
TF_EXPECT_OK(
diff --git a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
index f7477eca23..476e4f9c1f 100644
--- a/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
+++ b/tensorflow/core/platform/cloud/compute_engine_zone_provider_test.cc
@@ -34,8 +34,8 @@ TEST_F(ComputeEngineZoneProviderTest, GetZone) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
@@ -55,8 +55,8 @@ TEST_F(ComputeEngineZoneProviderTest, InvalidZoneString) {
auto httpRequestFactory = std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadata_client =
- std::make_shared<ComputeEngineMetadataClient>(httpRequestFactory, 0);
+ auto metadata_client = std::make_shared<ComputeEngineMetadataClient>(
+ httpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
ComputeEngineZoneProvider provider(metadata_client);
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 83ea8539ed..c61b68aeeb 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -333,14 +333,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
// TODO: to make it safer, outfile_ should be constructed from an FD
if (GetTmpFilename(&tmp_content_filename_).ok()) {
outfile_.open(tmp_content_filename_,
@@ -357,14 +357,14 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem* filesystem, const string& tmp_content_filename,
GcsFileSystem::TimeoutConfig* timeouts,
std::function<void()> file_cache_erase,
- int64 initial_retry_delay_usec)
+ RetryConfig retry_config)
: bucket_(bucket),
object_(object),
filesystem_(filesystem),
timeouts_(timeouts),
file_cache_erase_(std::move(file_cache_erase)),
sync_needed_(true),
- initial_retry_delay_usec_(initial_retry_delay_usec) {
+ retry_config_(retry_config) {
tmp_content_filename_ = tmp_content_filename;
outfile_.open(tmp_content_filename_,
std::ofstream::binary | std::ofstream::app);
@@ -441,7 +441,7 @@ class GcsWritableFile : public WritableFile {
first_attempt = false;
return UploadToSession(session_uri, already_uploaded);
},
- initial_retry_delay_usec_);
+ retry_config_);
if (upload_status.code() == errors::Code::NOT_FOUND) {
// GCS docs recommend retrying the whole upload. We're relying on the
// RetryingFileSystem to retry the Sync() call.
@@ -586,7 +586,7 @@ class GcsWritableFile : public WritableFile {
GcsFileSystem::TimeoutConfig* timeouts_;
std::function<void()> file_cache_erase_;
bool sync_needed_; // whether there is buffered data that needs to be synced
- int64 initial_retry_delay_usec_;
+ RetryConfig retry_config_;
};
class GcsReadOnlyMemoryRegion : public ReadOnlyMemoryRegion {
@@ -791,7 +791,7 @@ GcsFileSystem::GcsFileSystem(
std::unique_ptr<ZoneProvider> zone_provider, size_t block_size,
size_t max_bytes, uint64 max_staleness, uint64 stat_cache_max_age,
size_t stat_cache_max_entries, uint64 matching_paths_cache_max_age,
- size_t matching_paths_cache_max_entries, int64 initial_retry_delay_usec,
+ size_t matching_paths_cache_max_entries, RetryConfig retry_config,
TimeoutConfig timeouts, const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header)
: auth_provider_(std::move(auth_provider)),
@@ -806,7 +806,7 @@ GcsFileSystem::GcsFileSystem(
kCacheNeverExpire, kBucketLocationCacheMaxEntries)),
allowed_locations_(allowed_locations),
timeouts_(timeouts),
- initial_retry_delay_usec_(initial_retry_delay_usec),
+ retry_config_(retry_config),
additional_header_(additional_header) {}
Status GcsFileSystem::NewRandomAccessFile(
@@ -941,7 +941,7 @@ Status GcsFileSystem::NewWritableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
[this, fname]() { ClearFileCaches(fname); },
- initial_retry_delay_usec_));
+ retry_config_));
return Status::OK();
}
@@ -981,7 +981,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(
bucket, object, this, old_content_filename, &timeouts_,
- [this, fname]() { ClearFileCaches(fname); }, initial_retry_delay_usec_));
+ [this, fname]() { ClearFileCaches(fname); }, retry_config_));
return Status::OK();
}
@@ -1534,7 +1534,7 @@ Status GcsFileSystem::RenameObject(const string& src, const string& target) {
// on the server side, we can't just retry the whole RenameFile operation
// because the source object is already gone.
return RetryingUtils::DeleteWithRetries(
- [this, &src]() { return DeleteFile(src); }, initial_retry_delay_usec_);
+ [this, &src]() { return DeleteFile(src); }, retry_config_);
}
Status GcsFileSystem::IsDirectory(const string& fname) {
@@ -1590,8 +1590,7 @@ Status GcsFileSystem::DeleteRecursively(const string& dirname,
// and therefore RetryingFileSystem won't pay attention to the failures,
// we need to make sure these failures are properly retried.
const auto& delete_file_status = RetryingUtils::DeleteWithRetries(
- [this, &full_path]() { return DeleteFile(full_path); },
- initial_retry_delay_usec_);
+ [this, &full_path]() { return DeleteFile(full_path); }, retry_config_);
if (!delete_file_status.ok()) {
if (IsDirectory(full_path).ok()) {
// The object is a directory marker.
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 71db707687..d0840a3046 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -93,7 +93,7 @@ class GcsFileSystem : public FileSystem {
uint64 stat_cache_max_age, size_t stat_cache_max_entries,
uint64 matching_paths_cache_max_age,
size_t matching_paths_cache_max_entries,
- int64 initial_retry_delay_usec, TimeoutConfig timeouts,
+ RetryConfig retry_config, TimeoutConfig timeouts,
const std::unordered_set<string>& allowed_locations,
std::pair<const string, const string>* additional_header);
@@ -332,7 +332,7 @@ class GcsFileSystem : public FileSystem {
GcsStatsInterface* stats_ = nullptr; // Not owned.
/// The initial delay for exponential backoffs when retrying failed calls.
- const int64 initial_retry_delay_usec_ = 1000000L;
+ RetryConfig retry_config_;
// Additional header material to be transmitted with all GCS requests
std::unique_ptr<std::pair<const string, const string>> additional_header_;
@@ -344,7 +344,8 @@ class GcsFileSystem : public FileSystem {
class RetryingGcsFileSystem : public RetryingFileSystem<GcsFileSystem> {
public:
RetryingGcsFileSystem()
- : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem)) {}
+ : RetryingFileSystem(std::unique_ptr<GcsFileSystem>(new GcsFileSystem),
+ RetryConfig(100000 /* init_delay_time_us */)) {}
};
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index 14376ad339..702802b185 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -24,6 +24,8 @@ namespace tensorflow {
namespace {
static GcsFileSystem::TimeoutConfig kTestTimeoutConfig(5, 1, 10, 20, 30);
+static RetryConfig kTestRetryConfig(0 /* init_delay_time_us */);
+
// Default (empty) constraint config
static std::unordered_set<string>* kAllowedLocationsDefault =
new std::unordered_set<string>();
@@ -62,16 +64,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache) {
"Range: 6-11\n"
"Timeouts: 5 1 20\n",
"6789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -108,9 +110,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -150,9 +152,9 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithLocationConstraintCaching) {
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
@@ -191,9 +193,9 @@ TEST(GcsFileSystemTest,
0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
0 /* stat cache max age */, 0 /* stat cache max entries */,
0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */,
- 0 /* initial retry delay */, kTestTimeoutConfig,
- *kAllowedLocationsAuto, nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsAuto,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(tensorflow::errors::FailedPrecondition(
@@ -216,16 +218,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_NoBlockCache_DifferentN) {
"Range: 3-12\n"
"Timeouts: 5 1 20\n",
"3456789")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
TF_EXPECT_OK(fs.NewRandomAccessFile("gs://bucket/random_access.txt", &file));
@@ -283,7 +285,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -372,7 +374,7 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -414,17 +416,17 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_WithBlockCache_MaxStaleness) {
"Range: 8-15\n"
"Timeouts: 5 1 20\n",
"89abcdef")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 16 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 16 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
char scratch[100];
StringPiece result;
// There should only be two HTTP requests issued to GCS even though we iterate
@@ -492,7 +494,7 @@ TEST(GcsFileSystemTest,
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 9 /* block size */,
18 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -513,17 +515,17 @@ TEST(GcsFileSystemTest,
TEST(GcsFileSystemTest, NewRandomAccessFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
- 0 /* read ahead bytes */, 0 /* max bytes */, 0 /* max staleness */,
- 0 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* read ahead bytes */, 0 /* max bytes */,
+ 0 /* max staleness */, 0 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<RandomAccessFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -547,16 +549,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_InconsistentRead) {
"012")});
// Set stat_cache_max_age to 1000s so that StatCache could work.
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 1e3 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 1e3 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Stat the file first so that the file stats are cached.
FileStatistics stat;
@@ -621,7 +623,7 @@ TEST(GcsFileSystemTest, NewWritableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
8 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -703,16 +705,16 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceeds) {
"Timeouts: 5 1 30\n"
"Put body: t2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -773,17 +775,17 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadSucceedsOnGetStatus) {
"Range: 0-7\n"
"Timeouts: 5 1 20\n",
"01234567")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 8 /* block size */,
- 8 /* max bytes */, 3600 /* max staleness */,
- 3600 /* stat cache max age */, 0 /* stat cache max entries */,
- 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 8 /* block size */, 8 /* max bytes */,
+ 3600 /* max staleness */, 3600 /* stat cache max age */,
+ 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Pull the file's first block into the cache. This will trigger the first
// HTTP request to GCS.
std::unique_ptr<RandomAccessFile> rfile;
@@ -867,9 +869,9 @@ TEST(GcsFileSystemTest, NewWritableFile_ResumeUploadAllAttemptsFail) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 2 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ 0 /* matching paths cache max entries */,
+ RetryConfig(2 /* .init_delay_time_us */), kTestTimeoutConfig,
+ *kAllowedLocationsDefault, nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -918,16 +920,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
"Timeouts: 5 1 30\n"
"Put body: content1,content2\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
TF_EXPECT_OK(fs.NewWritableFile("gs://bucket/path/writeable.txt", &file));
@@ -948,16 +950,16 @@ TEST(GcsFileSystemTest, NewWritableFile_UploadReturns410) {
TEST(GcsFileSystemTest, NewWritableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1013,7 +1015,7 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 32 /* block size */,
32 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1041,16 +1043,16 @@ TEST(GcsFileSystemTest, NewAppendableFile) {
TEST(GcsFileSystemTest, NewAppendableFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<WritableFile> file;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1075,16 +1077,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
"Range: 0-",
content.size() - 1, "\n", "Timeouts: 5 1 20\n"),
content)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile(
@@ -1096,16 +1098,16 @@ TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile) {
TEST(GcsFileSystemTest, NewReadOnlyMemoryRegionFromFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<ReadOnlyMemoryRegion> region;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1120,16 +1122,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/file1.txt"));
}
@@ -1150,16 +1152,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsFolder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket/path/subfolder"));
}
@@ -1176,16 +1178,16 @@ TEST(GcsFileSystemTest, FileExists_YesAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"size\": \"100\"}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.FileExists("gs://bucket1"));
TF_EXPECT_OK(fs.FileExists("gs://bucket1/"));
@@ -1206,16 +1208,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsObjectOrFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": []}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::NOT_FOUND,
fs.FileExists("gs://bucket/path/file1.txt").code());
@@ -1233,16 +1235,16 @@ TEST(GcsFileSystemTest, FileExists_NotAsBucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.FileExists("gs://bucket2/").code());
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1279,7 +1281,7 @@ TEST(GcsFileSystemTest, FileExists_StatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1306,7 +1308,7 @@ TEST(GcsFileSystemTest, FileExists_DirectoryMark) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1322,16 +1324,16 @@ TEST(GcsFileSystemTest, GetChildren_NoItems) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1350,16 +1352,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1379,16 +1381,16 @@ TEST(GcsFileSystemTest, GetChildren_SelfDirectoryMarker) {
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1407,16 +1409,16 @@ TEST(GcsFileSystemTest, GetChildren_ThreeFiles_NoSlash) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/file3.txt\" }],"
"\"prefixes\": [\"path/subpath/\"]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1432,16 +1434,16 @@ TEST(GcsFileSystemTest, GetChildren_Root) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket-a-b-c", &children));
@@ -1457,16 +1459,16 @@ TEST(GcsFileSystemTest, GetChildren_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path/", &children));
@@ -1498,16 +1500,16 @@ TEST(GcsFileSystemTest, GetChildren_Pagination) {
" { \"name\": \"path/file4.txt\" },"
" { \"name\": \"path/file5.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> children;
TF_EXPECT_OK(fs.GetChildren("gs://bucket/path", &children));
@@ -1525,16 +1527,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_NoWildcard) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(
@@ -1553,16 +1555,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_BucketAndWildcard) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/*/*", &result));
@@ -1582,16 +1584,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_Matches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file2.txt", &result));
@@ -1608,16 +1610,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_SelfDirectoryMarker) {
"{\"items\": [ "
" { \"name\": \"path/\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*", &result));
@@ -1634,16 +1636,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://bucket/path/*/file3.txt", &result));
@@ -1652,16 +1654,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_FolderAndWildcard_NoMatches) {
TEST(GcsFileSystemTest, GetMatchingPaths_OnlyWildcard) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::vector<string> result;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -1686,16 +1688,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache) {
" { \"name\": \"path/file1.txt\" },"
" { \"name\": \"path/subpath/file2.txt\" },"
" { \"name\": \"path/file3.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// Repeated calls to fs.GetMatchingPaths on these patterns should not lead to
// any additional HTTP requests to GCS.
@@ -1729,16 +1731,16 @@ TEST(GcsFileSystemTest, GetMatchingPaths_Cache_Flush) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/subpath/file2.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 3600 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 3600 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
// This loop should trigger the first HTTP request to GCS.
for (int i = 0; i < 10; i++) {
@@ -1800,7 +1802,7 @@ TEST(GcsFileSystemTest, DeleteFile) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1821,16 +1823,16 @@ TEST(GcsFileSystemTest, DeleteFile) {
TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
fs.DeleteFile("gs://bucket/").code());
@@ -1871,7 +1873,7 @@ TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
16 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -1894,16 +1896,16 @@ TEST(GcsFileSystemTest, DeleteDir_Empty) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1923,16 +1925,16 @@ TEST(GcsFileSystemTest, DeleteDir_OnlyDirMarkerLeft) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket/path/"));
}
@@ -1943,16 +1945,16 @@ TEST(GcsFileSystemTest, DeleteDir_BucketOnly) {
"name%2CnextPageToken&maxResults=2\nAuth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.DeleteDir("gs://bucket"));
}
@@ -1965,16 +1967,16 @@ TEST(GcsFileSystemTest, DeleteDir_NonEmpty) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"path/file1.txt\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.DeleteDir("gs://bucket/path/").code());
@@ -1988,16 +1990,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://bucket/file.txt", &size));
@@ -2006,16 +2008,16 @@ TEST(GcsFileSystemTest, GetFileSize) {
TEST(GcsFileSystemTest, GetFileSize_NoObjectName) {
std::vector<HttpRequest*> requests;
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
uint64 size;
EXPECT_EQ(errors::Code::INVALID_ARGUMENT,
@@ -2092,16 +2094,16 @@ TEST(GcsFileSystemTest, RenameFile_Folder) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.RenameFile("gs://bucket/path1", "gs://bucket/path2/"));
}
@@ -2191,7 +2193,7 @@ TEST(GcsFileSystemTest, RenameFile_Object) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 16 /* block size */,
64 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial read of the source and destination files to load their
@@ -2272,7 +2274,7 @@ TEST(GcsFileSystemTest, RenameFile_Object_FlushTargetStatCache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// Do an initial stat of the destination file to load their contents into the
@@ -2332,16 +2334,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_DeletionRetried) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(
fs.RenameFile("gs://bucket/path/src.txt", "gs://bucket/path/dst.txt"));
@@ -2374,16 +2376,16 @@ TEST(GcsFileSystemTest, RenameFile_Object_Incomplete) {
"Post: yes\n"
"Timeouts: 5 1 10\n",
"{\"done\": false}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(
errors::Code::UNIMPLEMENTED,
@@ -2399,16 +2401,16 @@ TEST(GcsFileSystemTest, Stat_Object) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat));
@@ -2433,16 +2435,16 @@ TEST(GcsFileSystemTest, Stat_Folder) {
"Timeouts: 5 1 10\n",
"{\"items\": [ "
" { \"name\": \"subfolder/\" }]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/subfolder", &stat));
@@ -2466,16 +2468,16 @@ TEST(GcsFileSystemTest, Stat_ObjectOrFolderNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/path", &stat).code());
@@ -2487,16 +2489,16 @@ TEST(GcsFileSystemTest, Stat_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/", &stat));
@@ -2511,16 +2513,16 @@ TEST(GcsFileSystemTest, Stat_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
EXPECT_EQ(error::Code::NOT_FOUND, fs.Stat("gs://bucket/", &stat).code());
@@ -2556,7 +2558,7 @@ TEST(GcsFileSystemTest, Stat_Cache) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
@@ -2598,7 +2600,7 @@ TEST(GcsFileSystemTest, Stat_Cache_Flush) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 3600 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
nullptr /* gcs additional header */);
// There should be a single HTTP request to GCS for fs.Stat in this loop.
@@ -2628,16 +2630,16 @@ TEST(GcsFileSystemTest, Stat_FilenameEndingWithSlash) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"5\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("gs://bucket/dir/", &stat));
@@ -2660,16 +2662,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2691,16 +2693,16 @@ TEST(GcsFileSystemTest, IsDirectory_NotDirectoryButObject) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::FAILED_PRECONDITION,
fs.IsDirectory("gs://bucket/file.txt").code());
@@ -2722,16 +2724,16 @@ TEST(GcsFileSystemTest, IsDirectory_Yes) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{\"items\": [{\"name\": \"subfolder/\"}]}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/subfolder/"));
@@ -2749,16 +2751,16 @@ TEST(GcsFileSystemTest, IsDirectory_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.IsDirectory("gs://bucket"));
TF_EXPECT_OK(fs.IsDirectory("gs://bucket/"));
@@ -2770,16 +2772,16 @@ TEST(GcsFileSystemTest, IsDirectory_BucketNotFound) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
EXPECT_EQ(error::Code::NOT_FOUND, fs.IsDirectory("gs://bucket/").code());
}
@@ -2812,16 +2814,16 @@ TEST(GcsFileSystemTest, CreateDir_Folder) {
"Timeouts: 5 1 30\n"
"Put body: \n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket/subpath/"));
@@ -2839,16 +2841,16 @@ TEST(GcsFileSystemTest, CreateDir_Bucket) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TF_EXPECT_OK(fs.CreateDir("gs://bucket/"));
TF_EXPECT_OK(fs.CreateDir("gs://bucket"));
@@ -2911,16 +2913,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_Ok) {
"Timeouts: 5 1 10\n"
"Delete: yes\n",
"")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3004,16 +3006,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_DeletionErrors) {
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(fs.DeleteRecursively("gs://bucket/path", &undeleted_files,
@@ -3039,16 +3041,16 @@ TEST(GcsFileSystemTest, DeleteRecursively_NotAFolder) {
"Auth Token: fake_token\n"
"Timeouts: 5 1 10\n",
"", errors::NotFound("404"), 404)});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
int64 undeleted_files, undeleted_dirs;
EXPECT_EQ(error::Code::NOT_FOUND,
@@ -3130,7 +3132,7 @@ TEST(GcsFileSystemTest, AdditionalRequestHeaderTest) {
std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
kTestTimeoutConfig, *kAllowedLocationsDefault,
add_header /* gcs additional header */);
@@ -3199,16 +3201,16 @@ TEST(GcsFileSystemTest, CreateHttpRequest) {
"Auth Token: fake_token\n"
"Header Hello: world\n",
"{}")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
std::unique_ptr<HttpRequest> request;
TF_EXPECT_OK(fs.CreateHttpRequest(&request));
@@ -3262,16 +3264,16 @@ TEST(GcsFileSystemTest, Stat_StatsRecording) {
"Timeouts: 5 1 10\n",
strings::StrCat("{\"size\": \"1010\",\"generation\": \"1\","
"\"updated\": \"2016-04-29T23:15:24.896Z\"}"))});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
@@ -3289,16 +3291,16 @@ TEST(GcsFileSystemTest, NewRandomAccessFile_StatsRecording) {
"Range: 0-5\n"
"Timeouts: 5 1 20\n",
"012345")});
- GcsFileSystem fs(
- std::unique_ptr<AuthProvider>(new FakeAuthProvider),
- std::unique_ptr<HttpRequest::Factory>(
- new FakeHttpRequestFactory(&requests)),
- std::unique_ptr<ZoneProvider>(new FakeZoneProvider), 0 /* block size */,
- 0 /* max bytes */, 0 /* max staleness */, 0 /* stat cache max age */,
- 0 /* stat cache max entries */, 0 /* matching paths cache max age */,
- 0 /* matching paths cache max entries */, 0 /* initial retry delay */,
- kTestTimeoutConfig, *kAllowedLocationsDefault,
- nullptr /* gcs additional header */);
+ GcsFileSystem fs(std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ std::unique_ptr<ZoneProvider>(new FakeZoneProvider),
+ 0 /* block size */, 0 /* max bytes */, 0 /* max staleness */,
+ 0 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, kTestRetryConfig,
+ kTestTimeoutConfig, *kAllowedLocationsDefault,
+ nullptr /* gcs additional header */);
TestGcsStats stats;
fs.SetStats(&stats);
diff --git a/tensorflow/core/platform/cloud/google_auth_provider_test.cc b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
index 07b88a880f..ec31c5ee8c 100644
--- a/tensorflow/core/platform/cloud/google_auth_provider_test.cc
+++ b/tensorflow/core/platform/cloud/google_auth_provider_test.cc
@@ -93,8 +93,8 @@ TEST_F(GoogleAuthProviderTest, EnvironmentVariable_Caching) {
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
oauth_client->return_token = "fake-token";
@@ -129,8 +129,8 @@ TEST_F(GoogleAuthProviderTest, GCloudRefreshToken) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -178,8 +178,8 @@ TEST_F(GoogleAuthProviderTest, RunningOnGCE) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -206,8 +206,8 @@ TEST_F(GoogleAuthProviderTest, OverrideForTesting) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&empty_requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
@@ -228,8 +228,8 @@ TEST_F(GoogleAuthProviderTest, NothingAvailable) {
FakeEnv env;
std::shared_ptr<HttpRequest::Factory> fakeHttpRequestFactory =
std::make_shared<FakeHttpRequestFactory>(&requests);
- auto metadataClient =
- std::make_shared<ComputeEngineMetadataClient>(fakeHttpRequestFactory, 0);
+ auto metadataClient = std::make_shared<ComputeEngineMetadataClient>(
+ fakeHttpRequestFactory, RetryConfig(0 /* init_delay_time_us */));
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
metadataClient, &env);
diff --git a/tensorflow/core/platform/cloud/retrying_file_system.h b/tensorflow/core/platform/cloud/retrying_file_system.h
index 941ab7ad65..5ce6670dc7 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system.h
+++ b/tensorflow/core/platform/cloud/retrying_file_system.h
@@ -34,9 +34,9 @@ template <typename Underlying>
class RetryingFileSystem : public FileSystem {
public:
RetryingFileSystem(std::unique_ptr<Underlying> base_file_system,
- int64 delay_microseconds = 1000000)
+ const RetryConfig& retry_config)
: base_file_system_(std::move(base_file_system)),
- initial_delay_microseconds_(delay_microseconds) {}
+ retry_config_(retry_config) {}
Status NewRandomAccessFile(
const string& filename,
@@ -55,7 +55,7 @@ class RetryingFileSystem : public FileSystem {
Status FileExists(const string& fname) override {
return RetryingUtils::CallWithRetries(
[this, &fname]() { return base_file_system_->FileExists(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetChildren(const string& dir, std::vector<string>* result) override {
@@ -63,7 +63,7 @@ class RetryingFileSystem : public FileSystem {
[this, &dir, result]() {
return base_file_system_->GetChildren(dir, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetMatchingPaths(const string& pattern,
@@ -72,31 +72,31 @@ class RetryingFileSystem : public FileSystem {
[this, &pattern, result]() {
return base_file_system_->GetMatchingPaths(pattern, result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status Stat(const string& fname, FileStatistics* stat) override {
return RetryingUtils::CallWithRetries(
[this, &fname, stat]() { return base_file_system_->Stat(fname, stat); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteFile(const string& fname) override {
return RetryingUtils::DeleteWithRetries(
[this, &fname]() { return base_file_system_->DeleteFile(fname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status CreateDir(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->CreateDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteDir(const string& dirname) override {
return RetryingUtils::DeleteWithRetries(
[this, &dirname]() { return base_file_system_->DeleteDir(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status GetFileSize(const string& fname, uint64* file_size) override {
@@ -104,7 +104,7 @@ class RetryingFileSystem : public FileSystem {
[this, &fname, file_size]() {
return base_file_system_->GetFileSize(fname, file_size);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status RenameFile(const string& src, const string& target) override {
@@ -112,13 +112,13 @@ class RetryingFileSystem : public FileSystem {
[this, &src, &target]() {
return base_file_system_->RenameFile(src, target);
},
- initial_delay_microseconds_);
+ retry_config_);
}
Status IsDirectory(const string& dirname) override {
return RetryingUtils::CallWithRetries(
[this, &dirname]() { return base_file_system_->IsDirectory(dirname); },
- initial_delay_microseconds_);
+ retry_config_);
}
Status DeleteRecursively(const string& dirname, int64* undeleted_files,
@@ -128,7 +128,7 @@ class RetryingFileSystem : public FileSystem {
return base_file_system_->DeleteRecursively(dirname, undeleted_files,
undeleted_dirs);
},
- initial_delay_microseconds_);
+ retry_config_);
}
void FlushCaches() override { base_file_system_->FlushCaches(); }
@@ -137,7 +137,7 @@ class RetryingFileSystem : public FileSystem {
private:
std::unique_ptr<Underlying> base_file_system_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
TF_DISALLOW_COPY_AND_ASSIGN(RetryingFileSystem);
};
@@ -147,9 +147,8 @@ namespace retrying_internals {
class RetryingRandomAccessFile : public RandomAccessFile {
public:
RetryingRandomAccessFile(std::unique_ptr<RandomAccessFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
Status Read(uint64 offset, size_t n, StringPiece* result,
char* scratch) const override {
@@ -157,20 +156,19 @@ class RetryingRandomAccessFile : public RandomAccessFile {
[this, offset, n, result, scratch]() {
return base_file_->Read(offset, n, result, scratch);
},
- initial_delay_microseconds_);
+ retry_config_);
}
private:
std::unique_ptr<RandomAccessFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
class RetryingWritableFile : public WritableFile {
public:
RetryingWritableFile(std::unique_ptr<WritableFile> base_file,
- int64 delay_microseconds)
- : base_file_(std::move(base_file)),
- initial_delay_microseconds_(delay_microseconds) {}
+ const RetryConfig& retry_config)
+ : base_file_(std::move(base_file)), retry_config_(retry_config) {}
~RetryingWritableFile() override {
// Makes sure the retrying version of Close() is called in the destructor.
@@ -179,25 +177,24 @@ class RetryingWritableFile : public WritableFile {
Status Append(StringPiece data) override {
return RetryingUtils::CallWithRetries(
- [this, &data]() { return base_file_->Append(data); },
- initial_delay_microseconds_);
+ [this, &data]() { return base_file_->Append(data); }, retry_config_);
}
Status Close() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Close(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Close(); }, retry_config_);
}
Status Flush() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Flush(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Flush(); }, retry_config_);
}
Status Sync() override {
return RetryingUtils::CallWithRetries(
- [this]() { return base_file_->Sync(); }, initial_delay_microseconds_);
+ [this]() { return base_file_->Sync(); }, retry_config_);
}
private:
std::unique_ptr<WritableFile> base_file_;
- const int64 initial_delay_microseconds_;
+ const RetryConfig retry_config_;
};
} // namespace retrying_internals
@@ -210,9 +207,9 @@ Status RetryingFileSystem<Underlying>::NewRandomAccessFile(
[this, &filename, &base_file]() {
return base_file_system_->NewRandomAccessFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingRandomAccessFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -224,9 +221,9 @@ Status RetryingFileSystem<Underlying>::NewWritableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewWritableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -238,9 +235,9 @@ Status RetryingFileSystem<Underlying>::NewAppendableFile(
[this, &filename, &base_file]() {
return base_file_system_->NewAppendableFile(filename, &base_file);
},
- initial_delay_microseconds_));
+ retry_config_));
result->reset(new retrying_internals::RetryingWritableFile(
- std::move(base_file), initial_delay_microseconds_));
+ std::move(base_file), retry_config_));
return Status::OK();
}
@@ -252,7 +249,7 @@ Status RetryingFileSystem<Underlying>::NewReadOnlyMemoryRegionFromFile(
return base_file_system_->NewReadOnlyMemoryRegionFromFile(filename,
result);
},
- initial_delay_microseconds_);
+ retry_config_);
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/retrying_file_system_test.cc b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
index 5910fef1d2..868eea096c 100644
--- a/tensorflow/core/platform/cloud/retrying_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_file_system_test.cc
@@ -184,7 +184,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -211,7 +212,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -235,7 +237,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -265,7 +268,8 @@ TEST(RetryingFileSystemTest, NewRandomAccessFile_NoRetriesForSomeErrors) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->random_access_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped random access file.
std::unique_ptr<RandomAccessFile> random_access_file;
@@ -291,7 +295,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_ImmediateSuccess) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -317,7 +322,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -343,7 +349,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_SuccessWith3rdTry_ViaDestructor) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -368,7 +375,8 @@ TEST(RetryingFileSystemTest, NewAppendableFile_SuccessWith3rdTry) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped appendable file.
std::unique_ptr<WritableFile> writable_file;
@@ -391,7 +399,8 @@ TEST(RetryingFileSystemTest, NewWritableFile_AllRetriesFailed) {
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
base_fs->writable_file_to_return = std::move(base_file);
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
// Retrieve the wrapped writable file.
std::unique_ptr<WritableFile> writable_file;
@@ -412,7 +421,8 @@ TEST(RetryingFileSystemTest,
std::make_tuple("NewReadOnlyMemoryRegionFromFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
TF_EXPECT_OK(fs.NewReadOnlyMemoryRegionFromFile("filename.txt", &result));
@@ -423,7 +433,8 @@ TEST(RetryingFileSystemTest, NewReadOnlyMemoryRegionFromFile_AllRetriesFailed) {
CreateRetriableErrors("NewReadOnlyMemoryRegionFromFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::unique_ptr<ReadOnlyMemoryRegion> result;
const auto& status =
@@ -440,7 +451,8 @@ TEST(RetryingFileSystemTest, GetChildren_SuccessWith2ndTry) {
std::make_tuple("GetChildren", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetChildren("gs://path", &result));
@@ -450,7 +462,8 @@ TEST(RetryingFileSystemTest, GetChildren_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetChildren", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetChildren("gs://path", &result);
@@ -466,7 +479,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_SuccessWith2ndTry) {
std::make_tuple("GetMatchingPaths", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.GetMatchingPaths("gs://path/dir", &result));
@@ -477,7 +491,8 @@ TEST(RetryingFileSystemTest, GetMatchingPaths_AllRetriesFailed) {
CreateRetriableErrors("GetMatchingPaths", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.GetMatchingPaths("gs://path/dir", &result);
@@ -492,7 +507,8 @@ TEST(RetryingFileSystemTest, DeleteFile_SuccessWith2ndTry) {
std::make_tuple("DeleteFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteFile("gs://path/file.txt"));
@@ -502,7 +518,8 @@ TEST(RetryingFileSystemTest, DeleteFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteFile("gs://path/file.txt");
@@ -517,7 +534,8 @@ TEST(RetryingFileSystemTest, CreateDir_SuccessWith2ndTry) {
std::make_tuple("CreateDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.CreateDir("gs://path/newdir"));
@@ -527,7 +545,8 @@ TEST(RetryingFileSystemTest, CreateDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("CreateDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.CreateDir("gs://path/newdir");
@@ -542,7 +561,8 @@ TEST(RetryingFileSystemTest, DeleteDir_SuccessWith2ndTry) {
std::make_tuple("DeleteDir", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
TF_EXPECT_OK(fs.DeleteDir("gs://path/dir"));
@@ -552,7 +572,8 @@ TEST(RetryingFileSystemTest, DeleteDir_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("DeleteDir", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
std::vector<string> result;
const auto& status = fs.DeleteDir("gs://path/dir");
@@ -568,7 +589,8 @@ TEST(RetryingFileSystemTest, GetFileSize_SuccessWith2ndTry) {
std::make_tuple("GetFileSize", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
TF_EXPECT_OK(fs.GetFileSize("gs://path/file.txt", &size));
@@ -578,7 +600,8 @@ TEST(RetryingFileSystemTest, GetFileSize_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("GetFileSize", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
uint64 size;
const auto& status = fs.GetFileSize("gs://path/file.txt", &size);
@@ -593,7 +616,8 @@ TEST(RetryingFileSystemTest, RenameFile_SuccessWith2ndTry) {
std::make_tuple("RenameFile", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.RenameFile("old_name", "new_name"));
}
@@ -602,7 +626,8 @@ TEST(RetryingFileSystemTest, RenameFile_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("RenameFile", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.RenameFile("old_name", "new_name");
EXPECT_TRUE(
@@ -616,7 +641,8 @@ TEST(RetryingFileSystemTest, Stat_SuccessWith2ndTry) {
std::make_tuple("Stat", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
TF_EXPECT_OK(fs.Stat("file_name", &stat));
@@ -626,7 +652,8 @@ TEST(RetryingFileSystemTest, Stat_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("Stat", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
FileStatistics stat;
const auto& status = fs.Stat("file_name", &stat);
@@ -639,7 +666,8 @@ TEST(RetryingFileSystemTest, FileExists_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("FileExists", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.FileExists("file_name");
EXPECT_TRUE(
@@ -653,7 +681,8 @@ TEST(RetryingFileSystemTest, FileExists_SuccessWith2ndTry) {
std::make_tuple("FileExists", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.FileExists("gs://path/dir"));
}
@@ -665,7 +694,8 @@ TEST(RetryingFileSystemTest, IsDirectory_SuccessWith2ndTry) {
std::make_tuple("IsDirectory", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
TF_EXPECT_OK(fs.IsDirectory("gs://path/dir"));
}
@@ -674,7 +704,8 @@ TEST(RetryingFileSystemTest, IsDirectory_AllRetriesFailed) {
ExpectedCalls expected_fs_calls = CreateRetriableErrors("IsDirectory", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
const auto& status = fs.IsDirectory("gs://path/dir");
EXPECT_TRUE(
@@ -689,7 +720,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_SuccessWith2ndTry) {
std::make_tuple("DeleteRecursively", Status::OK())});
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
TF_EXPECT_OK(
@@ -701,7 +733,8 @@ TEST(RetryingFileSystemTest, DeleteRecursively_AllRetriesFailed) {
CreateRetriableErrors("DeleteRecursively", 11);
std::unique_ptr<MockFileSystem> base_fs(
new MockFileSystem(expected_fs_calls));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
int64 undeleted_files, undeleted_dirs;
const auto& status =
@@ -715,7 +748,8 @@ TEST(RetryingFileSystemTest, FlushCaches) {
ExpectedCalls none;
bool flushed = false;
std::unique_ptr<MockFileSystem> base_fs(new MockFileSystem(none, &flushed));
- RetryingFileSystem<MockFileSystem> fs(std::move(base_fs), 0);
+ RetryingFileSystem<MockFileSystem> fs(
+ std::move(base_fs), RetryConfig(0 /* init_delay_time_us */));
fs.FlushCaches();
EXPECT_TRUE(flushed);
}
diff --git a/tensorflow/core/platform/cloud/retrying_utils.cc b/tensorflow/core/platform/cloud/retrying_utils.cc
index d2df422024..cb0aecdd35 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils.cc
@@ -23,11 +23,6 @@ namespace tensorflow {
namespace {
-// In case of failure, every call will be retried kMaxRetries times.
-constexpr int kMaxRetries = 10;
-// Maximum backoff time in microseconds.
-constexpr int64 kMaximumBackoffMicroseconds = 32000000; // 32 seconds.
-
bool IsRetriable(error::Code code) {
switch (code) {
case error::UNAVAILABLE:
@@ -43,40 +38,41 @@ bool IsRetriable(error::Code code) {
} // namespace
Status RetryingUtils::CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds) {
- return CallWithRetries(f, initial_delay_microseconds, [](int64 micros) {
- return Env::Default()->SleepForMicroseconds(micros);
- });
+ const RetryConfig& config) {
+ return CallWithRetries(
+ f,
+ [](int64 micros) { return Env::Default()->SleepForMicroseconds(micros); },
+ config);
}
Status RetryingUtils::CallWithRetries(
- const std::function<Status()>& f, const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec) {
+ const std::function<Status()>& f,
+ const std::function<void(int64)>& sleep_usec, const RetryConfig& config) {
int retries = 0;
while (true) {
auto status = f();
if (!IsRetriable(status.code())) {
return status;
}
- if (retries >= kMaxRetries) {
+ if (retries >= config.max_retries) {
// Return AbortedError, so that it doesn't get retried again somewhere
// at a higher level.
return Status(
error::ABORTED,
strings::StrCat(
- "All ", kMaxRetries,
+ "All ", config.max_retries,
" retry attempts failed. The last failure: ", status.ToString()));
}
int64 delay_micros = 0;
- if (initial_delay_microseconds > 0) {
+ if (config.init_delay_time_us > 0) {
const int64 random_micros = random::New64() % 1000000;
- delay_micros = std::min(initial_delay_microseconds << retries,
- kMaximumBackoffMicroseconds) +
+ delay_micros = std::min(config.init_delay_time_us << retries,
+ config.max_delay_time_us) +
random_micros;
}
LOG(INFO) << "The operation failed and will be automatically retried in "
<< (delay_micros / 1000000.0) << " seconds (attempt "
- << (retries + 1) << " out of " << kMaxRetries
+ << (retries + 1) << " out of " << config.max_retries
<< "), caused by: " << status.ToString();
sleep_usec(delay_micros);
retries++;
@@ -84,8 +80,7 @@ Status RetryingUtils::CallWithRetries(
}
Status RetryingUtils::DeleteWithRetries(
- const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds) {
+ const std::function<Status()>& delete_func, const RetryConfig& config) {
bool is_retried = false;
return RetryingUtils::CallWithRetries(
[delete_func, &is_retried]() {
@@ -96,7 +91,7 @@ Status RetryingUtils::DeleteWithRetries(
is_retried = true;
return status;
},
- initial_delay_microseconds);
+ config);
}
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/retrying_utils.h b/tensorflow/core/platform/cloud/retrying_utils.h
index 546b8d1c4a..1a7ce1b122 100644
--- a/tensorflow/core/platform/cloud/retrying_utils.h
+++ b/tensorflow/core/platform/cloud/retrying_utils.h
@@ -21,6 +21,26 @@ limitations under the License.
namespace tensorflow {
+// Default time before reporting failure: ~100 seconds.
+struct RetryConfig {
+ RetryConfig(int64 init_delay_time_us = 100 * 1000,
+ int64 max_delay_time_us = 32 * 1000 * 1000,
+ int max_retries = 10) {
+ this->init_delay_time_us = init_delay_time_us;
+ this->max_delay_time_us = max_delay_time_us;
+ this->max_retries = max_retries;
+ }
+
+ // In case of failure, every call will be retried max_retries times.
+ int max_retries;
+
+ // Initial backoff time
+ int64 init_delay_time_us;
+
+ // Maximum backoff time in microseconds.
+ int64 max_delay_time_us;
+};
+
class RetryingUtils {
public:
/// \brief Retries the function in case of failure with exponential backoff.
@@ -31,18 +51,19 @@ class RetryingUtils {
/// retries.
/// If all retries failed, returns the last error status.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
+
/// sleep_usec is a function that sleeps for the given number of microseconds.
static Status CallWithRetries(const std::function<Status()>& f,
- const int64 initial_delay_microseconds,
- const std::function<void(int64)>& sleep_usec);
+ const std::function<void(int64)>& sleep_usec,
+ const RetryConfig& config);
/// \brief A retrying wrapper for a function that deletes a resource.
///
/// The function takes care of the scenario when a delete operation
/// returns a failure but succeeds under the hood: if a retry returns
/// NOT_FOUND, the whole operation is considered a success.
static Status DeleteWithRetries(const std::function<Status()>& delete_func,
- const int64 initial_delay_microseconds);
+ const RetryConfig& config);
};
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/retrying_utils_test.cc b/tensorflow/core/platform/cloud/retrying_utils_test.cc
index 1b6527618a..75fe8a98f4 100644
--- a/tensorflow/core/platform/cloud/retrying_utils_test.cc
+++ b/tensorflow/core/platform/cloud/retrying_utils_test.cc
@@ -30,7 +30,8 @@ TEST(RetryingUtilsTest, CallWithRetries_RetryDelays) {
};
std::function<Status()> f = []() { return errors::Unavailable("Failed."); };
- const auto& status = RetryingUtils::CallWithRetries(f, 500000L, sleep);
+ const auto& status = RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(500000 /* init_delay_time_us */));
EXPECT_EQ(errors::Code::ABORTED, status.code());
EXPECT_TRUE(str_util::StrContains(
status.error_message(),
@@ -60,8 +61,10 @@ TEST(RetryingUtilsTest, CallWithRetries_NotFoundIsNotRetried) {
results.erase(results.begin());
return result;
};
- EXPECT_EQ(errors::Code::NOT_FOUND,
- RetryingUtils::CallWithRetries(f, 0).code());
+ EXPECT_EQ(
+ errors::Code::NOT_FOUND,
+ RetryingUtils::CallWithRetries(f, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
@@ -74,7 +77,8 @@ TEST(RetryingUtilsTest, CallWithRetries_ImmediateSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 1.0, sleep));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, sleep, RetryConfig(1L /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
@@ -86,7 +90,8 @@ TEST(RetryingUtilsTest, CallWithRetries_EventualSuccess) {
results.erase(results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::CallWithRetries(f, 0));
+ TF_EXPECT_OK(RetryingUtils::CallWithRetries(
+ f, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
@@ -96,7 +101,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_ImmediateSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
@@ -106,7 +112,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_EventualSuccess) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
@@ -118,7 +125,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_PermissionDeniedNotRetried) {
return result;
};
EXPECT_EQ(errors::Code::PERMISSION_DENIED,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
@@ -129,7 +138,8 @@ TEST(RetryingUtilsTest, DeleteWithRetries_SuccessThroughFileNotFound) {
delete_results.erase(delete_results.begin());
return result;
};
- TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(delete_func, 0));
+ TF_EXPECT_OK(RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */)));
}
TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
@@ -140,7 +150,9 @@ TEST(RetryingUtilsTest, DeleteWithRetries_FirstNotFoundReturnedAsIs) {
return result;
};
EXPECT_EQ(error::NOT_FOUND,
- RetryingUtils::DeleteWithRetries(delete_func, 0).code());
+ RetryingUtils::DeleteWithRetries(
+ delete_func, RetryConfig(0 /* init_delay_time_us */))
+ .code());
}
} // namespace
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 3b14757945..d884c1aa7c 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -615,11 +615,7 @@ def tf_kernel_tests_linkstatic():
def tf_additional_lib_defines():
"""Additional defines needed to build TF libraries."""
- return select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["TENSORFLOW_USE_JEMALLOC"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["TENSORFLOW_USE_JEMALLOC"],
- "//conditions:default": [],
- })
+ return []
def tf_additional_lib_deps():
"""Additional dependencies needed to build TF libraries."""
@@ -631,13 +627,7 @@ def tf_additional_lib_deps():
] + if_static(
["@nsync//:nsync_cpp"],
["@nsync//:nsync_headers"],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_ppc64le_dynamic": ["@jemalloc//:jemalloc_headers"],
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- })
+ )
def tf_additional_core_deps():
return select({
@@ -725,11 +715,7 @@ def tf_additional_binary_deps():
"//tensorflow/stream_executor:cuda_platform",
"//tensorflow/core/platform/default/build_config:cuda",
],
- ) + select({
- "//tensorflow:with_jemalloc_linux_x86_64": ["@jemalloc//:jemalloc_impl"],
- "//tensorflow:with_jemalloc_linux_ppc64le": ["@jemalloc//:jemalloc_impl"],
- "//conditions:default": [],
- }) + [
+ ) + [
# TODO(allenl): Split these out into their own shared objects (they are
# here because they are shared between contrib/ op shared objects and
# core).
diff --git a/tensorflow/core/platform/posix/port.cc b/tensorflow/core/platform/posix/port.cc
index b46b9927cd..acdd7798ea 100644
--- a/tensorflow/core/platform/posix/port.cc
+++ b/tensorflow/core/platform/posix/port.cc
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
#include "absl/base/internal/sysinfo.h"
#include "tensorflow/core/platform/cpu_info.h"
@@ -101,11 +97,7 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
// memory aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return Malloc(size);
-#ifdef TENSORFLOW_USE_JEMALLOC
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
-#else
int err = posix_memalign(&ptr, minimum_alignment, size);
-#endif
if (err != 0) {
return nullptr;
} else {
@@ -116,29 +108,11 @@ void* AlignedMalloc(size_t size, int minimum_alignment) {
void AlignedFree(void* aligned_memory) { Free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(ptr);
-#else
- free(ptr);
-#endif
-}
+void Free(void* ptr) { free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
@@ -146,9 +120,7 @@ void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
void NUMAFree(void* ptr, size_t size) { Free(ptr); }
-int NUMAGetMemAffinity(const void* addr) {
- return kNUMANoAffinity;
-}
+int NUMAGetMemAffinity(const void* addr) { return kNUMANoAffinity; }
void MallocExtension_ReleaseToSystem(std::size_t num_bytes) {
// No-op.
diff --git a/tensorflow/core/platform/windows/port.cc b/tensorflow/core/platform/windows/port.cc
index 5375f56372..911ea1902f 100644
--- a/tensorflow/core/platform/windows/port.cc
+++ b/tensorflow/core/platform/windows/port.cc
@@ -13,10 +13,6 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifdef TENSORFLOW_USE_JEMALLOC
-#include "jemalloc/jemalloc.h"
-#endif
-
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
@@ -70,55 +66,16 @@ void NUMASetThreadNodeAffinity(int node) {}
int NUMAGetThreadNodeAffinity() { return kNUMANoAffinity; }
void* AlignedMalloc(size_t size, int minimum_alignment) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- void* ptr = NULL;
- // posix_memalign requires that the requested alignment be at least
- // sizeof(void*). In this case, fall back on malloc which should return
- // memory aligned to at least the size of a pointer.
- const int required_alignment = sizeof(void*);
- if (minimum_alignment < required_alignment) return Malloc(size);
- int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
- if (err != 0) {
- return NULL;
- } else {
- return ptr;
- }
-#else
return _aligned_malloc(size, minimum_alignment);
-#endif
}
-void AlignedFree(void* aligned_memory) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- jemalloc_free(aligned_memory);
-#else
- _aligned_free(aligned_memory);
-#endif
-}
+void AlignedFree(void* aligned_memory) { _aligned_free(aligned_memory); }
-void* Malloc(size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_malloc(size);
-#else
- return malloc(size);
-#endif
-}
+void* Malloc(size_t size) { return malloc(size); }
-void* Realloc(void* ptr, size_t size) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_realloc(ptr, size);
-#else
- return realloc(ptr, size);
-#endif
-}
+void* Realloc(void* ptr, size_t size) { return realloc(ptr, size); }
-void Free(void* ptr) {
-#ifdef TENSORFLOW_USE_JEMALLOC
- return jemalloc_free(ptr);
-#else
- return free(ptr);
-#endif
-}
+void Free(void* ptr) { return free(ptr); }
void* NUMAMalloc(int node, size_t size, int minimum_alignment) {
return AlignedMalloc(size, minimum_alignment);
diff --git a/tensorflow/core/profiler/BUILD b/tensorflow/core/profiler/BUILD
index af034bdd7d..2bf371276e 100644
--- a/tensorflow/core/profiler/BUILD
+++ b/tensorflow/core/profiler/BUILD
@@ -40,7 +40,6 @@ tf_proto_library(
name = "protos_all",
srcs = glob(["**/*.proto"]),
cc_api_version = 2,
- java_api_version = 2,
protodeps = tf_additional_all_protos(),
visibility = ["//visibility:public"],
)
diff --git a/tensorflow/core/protobuf/config.proto b/tensorflow/core/protobuf/config.proto
index 85cd02350a..104ab039cb 100644
--- a/tensorflow/core/protobuf/config.proto
+++ b/tensorflow/core/protobuf/config.proto
@@ -453,6 +453,11 @@ message RunOptions {
// same group_key value (in a distributed computation where tasks
// run disjoint graphs).
int64 collective_graph_key = 1;
+ // If true, then operations (using the inter-op pool) across all
+ // session::run() calls will be centrally scheduled, optimizing for (median
+ // and tail) latency.
+ // Consider using this option for CPU-bound workloads like inference.
+ bool use_run_handler_pool = 2;
};
Experimental experimental = 8;
diff --git a/tensorflow/core/protobuf/rewriter_config.proto b/tensorflow/core/protobuf/rewriter_config.proto
index 482178a540..8c31468ff5 100644
--- a/tensorflow/core/protobuf/rewriter_config.proto
+++ b/tensorflow/core/protobuf/rewriter_config.proto
@@ -75,8 +75,10 @@ message RewriterConfig {
// Try to allocate some independent Op outputs contiguously in order to
// merge or eliminate downstream Ops (off by default).
Toggle scoped_allocator_optimization = 15;
- // Force small ops onto the CPU (default is ON).
+ // Force small ops onto the CPU (default is OFF).
Toggle pin_to_host_optimization = 18;
+ // Disable the entire meta optimizer (off by default).
+ bool disable_meta_optimizer = 19;
// Controls how many times we run the optimizers in meta optimizer (default
// is once).
diff --git a/tensorflow/core/util/mkl_util.h b/tensorflow/core/util/mkl_util.h
index cf7ffd8149..04aaea4f89 100644
--- a/tensorflow/core/util/mkl_util.h
+++ b/tensorflow/core/util/mkl_util.h
@@ -2039,8 +2039,8 @@ class MklPrimitiveFactory {
/// Fuction to check whether primitive memory optimization is enabled
static inline bool IsPrimitiveMemOptEnabled() {
bool is_primitive_mem_opt_enabled = true;
- TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITVE_MEMUSE", true,
- &is_primitive_mem_opt_enabled));
+ TF_CHECK_OK(ReadBoolFromEnvVar("TF_MKL_OPTIMIZE_PRIMITIVE_MEMUSE", true,
+ &is_primitive_mem_opt_enabled));
return is_primitive_mem_opt_enabled;
}
@@ -2095,9 +2095,8 @@ static inline memory::format get_desired_format(int channel,
fmt_desired = is_2d ? memory::format::nChw16c : memory::format::nCdhw16c;
} else if (port::TestCPUFeature(port::CPUFeature::AVX2) &&
(channel % 8) == 0) {
- fmt_desired = is_2d
- ? memory::format::nChw8c
- : memory::format::ncdhw; //not support avx2 for 3d yet.
+ fmt_desired = is_2d ? memory::format::nChw8c
+ : memory::format::ncdhw; // no avx2 support for 3d yet.
} else {
fmt_desired = is_2d ? memory::format::nchw : memory::format::ncdhw;
}
@@ -2209,7 +2208,8 @@ inline primitive FindOrCreateReorder(const memory* from, const memory* to) {
// utility function to determine if it is conv 1x1 and stride != 1
// for purpose of temporarily disabling primitive reuse
-inline bool IsConv1x1StrideNot1(memory::dims filter_dims, memory::dims strides) {
+inline bool IsConv1x1StrideNot1(memory::dims filter_dims,
+ memory::dims strides) {
if (filter_dims.size() != 4 || strides.size() != 2) return false;
return ((filter_dims[2] == 1) && (filter_dims[3] == 1) &&
diff --git a/tensorflow/core/util/tensor_bundle/BUILD b/tensorflow/core/util/tensor_bundle/BUILD
index 4d4db86df2..f40ec9b752 100644
--- a/tensorflow/core/util/tensor_bundle/BUILD
+++ b/tensorflow/core/util/tensor_bundle/BUILD
@@ -65,6 +65,10 @@ tf_cc_test(
name = "tensor_bundle_test",
srcs = ["tensor_bundle_test.cc"],
data = glob(["testdata/**"]),
+ tags = [
+ "nomsan",
+ "notsan",
+ ],
deps = [
":tensor_bundle",
"//tensorflow/core:framework",
diff --git a/tensorflow/docs_src/BUILD b/tensorflow/docs_src/BUILD
deleted file mode 100644
index 34bf7b6a11..0000000000
--- a/tensorflow/docs_src/BUILD
+++ /dev/null
@@ -1,14 +0,0 @@
-# Files used to generate TensorFlow docs.
-
-licenses(["notice"]) # Apache 2.0
-
-package(
- default_visibility = ["//tensorflow:internal"],
-)
-
-exports_files(["LICENSE"])
-
-filegroup(
- name = "docs_src",
- data = glob(["**/*.md"]),
-)
diff --git a/tensorflow/docs_src/__init__.py b/tensorflow/docs_src/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
--- a/tensorflow/docs_src/__init__.py
+++ /dev/null
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
new file mode 100644
index 0000000000..96d269bec4
--- /dev/null
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -0,0 +1,2426 @@
+# Operation Semantics
+
+The following describes the semantics of operations defined in the
+[`XlaBuilder`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+interface. Typically, these operations map one-to-one to operations defined in
+the RPC interface in
+[`xla_data.proto`](https://www.tensorflow.org/code/tensorflow/compiler/xla/xla_data.proto).
+
+A note on nomenclature: the generalized data type XLA deals with is an
+N-dimensional array holding elements of some uniform type (such as 32-bit
+float). Throughout the documentation, *array* is used to denote an
+arbitrary-dimensional array. For convenience, special cases have more specific
+and familiar names; for example a *vector* is a 1-dimensional array and a
+*matrix* is a 2-dimensional array.
+
+## AllToAll
+
+See also
+[`XlaBuilder::AllToAll`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Alltoall is a collective operation that sends data from all cores to all cores.
+It has two phases:
+
+1. the scatter phase. On each core, the operand is split into `split_count`
+ number of blocks along the `split_dimensions`, and the blocks are scattered
+ to all cores, e.g., the ith block is send to the ith core.
+2. the gather phase. Each core concatenates the received blocks along the
+ `concat_dimension`.
+
+The participating cores can be configured by:
+
+- `replica_groups`: each ReplicaGroup contains a list of replica id. If empty,
+ all replicas belong to one group in the order of 0 - (n-1). Alltoall will be
+ applied within subgroups in the specified order. For example, replica
+ groups = {{1,2,3},{4,5,0}} means, an Alltoall will be applied within replica
+ 1, 2, 3, and in the gather phase, the received blocks will be concatenated
+ in the order of 1, 2, 3; another Alltoall will be applied within replica 4,
+ 5, 0, and the concatenation order is 4, 5, 0.
+
+Prerequisites:
+
+- The dimension size of the operand on the split_dimension is divisible by
+ split_count.
+- The operand's shape is not tuple.
+
+<b> `AllToAll(operand, split_dimension, concat_dimension, split_count,
+replica_groups)` </b>
+
+
+| Arguments | Type | Semantics |
+| ------------------ | --------------------- | ------------------------------- |
+| `operand` | `XlaOp` | n dimensional input array |
+| `split_dimension` | `int64` | A value in the interval `[0, |
+: : : n)` that names the dimension :
+: : : along which the operand is :
+: : : split :
+| `concat_dimension` | `int64` | a value in the interval `[0, |
+: : : n)` that names the dimension :
+: : : along which the split blocks :
+: : : are concatenated :
+| `split_count` | `int64` | the number of cores that |
+: : : participate this operation. If :
+: : : `replica_groups` is empty, this :
+: : : should be the number of :
+: : : replicas; otherwise, this :
+: : : should be equal to the number :
+: : : of replicas in each group. :
+| `replica_groups` | `ReplicaGroup` vector | each group contains a list of |
+: : : replica id. :
+
+Below shows an example of Alltoall.
+
+```
+XlaBuilder b("alltoall");
+auto x = Parameter(&b, 0, ShapeUtil::MakeShape(F32, {4, 16}), "x");
+AllToAll(x, /*split_dimension=*/1, /*concat_dimension=*/0, /*split_count=*/4);
+```
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/xla/ops_alltoall.png">
+</div>
+
+In this example, there are 4 cores participating the Alltoall. On each core, the
+operand is split into 4 parts along dimension 0, so each part has shape
+f32[4,4]. The 4 parts are scattered to all cores. Then each core concatenates
+the received parts along dimension 1, in the order or core 0-4. So the output on
+each core has shape f32[16,4].
+
+## BatchNormGrad
+
+See also
+[`XlaBuilder::BatchNormGrad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
+for a detailed description of the algorithm.
+
+Calculates gradients of batch norm.
+
+<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>
+
+| Arguments | Type | Semantics |
+| --------------- | ----------------------- | -------------------------------- |
+| `operand` | `XlaOp` | n dimensional array to be |
+: : : normalized (x) :
+| `scale` | `XlaOp` | 1 dimensional array |
+: : : (\\(\gamma\\)) :
+| `mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
+| `variance` | `XlaOp` | 1 dimensional array |
+: : : (\\(\sigma^2\\)) :
+| `grad_output` | `XlaOp` | Gradients passed to |
+: : : `BatchNormTraining` :
+: : : (\\( \nabla y\\)) :
+| `epsilon` | `float` | Epsilon value (\\(\epsilon\\)) |
+| `feature_index` | `int64` | Index to feature dimension in |
+: : : `operand` :
+
+For each feature in the feature dimension (`feature_index` is the index for the
+feature dimension in `operand`), the operation calculates the gradients with
+respect to `operand`, `offset` and `scale` across all the other dimensions. The
+`feature_index` must be a valid index for the feature dimension in `operand`.
+
+The three gradients are defined by the following formulas (assuming a
+4-dimensional tensor as `operand` and with feature dimension index \\(l\\),
+batch size `m` and spatial sizes `w` and `h`):
+
+\\[ \begin{split} c_l&=
+\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
+\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
+\\\\
+\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
+\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
+\right)
+\\\\
+\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
+\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
+\\\\\
+\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
+\end{split} \\]
+
+The inputs `mean` and `variance` represent moments value
+across batch and spatial dimensions.
+
+The output type is a tuple of three handles:
+
+| Outputs | Type | Semantics |
+| ------------- | ----------------------- | --------------------------------- |
+| `grad_operand` | `XlaOp` | gradient with respect to input |
+: : : `operand` (\\( \nabla x\\)) :
+| `grad_scale` | `XlaOp` | gradient with respect to input |
+: : : `scale` (\\( \nabla \gamma\\)) :
+| `grad_offset` | `XlaOp` | gradient with respect to input |
+: : : `offset`(\\( \nabla \beta\\)) :
+
+## BatchNormInference
+
+See also
+[`XlaBuilder::BatchNormInference`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+and [the original batch normalization paper](https://arxiv.org/abs/1502.03167)
+for a detailed description of the algorithm.
+
+Normalizes an array across batch and spatial dimensions.
+
+<b> `BatchNormInference(operand, scale, offset, mean, variance, epsilon, feature_index)` </b>
+
+Arguments | Type | Semantics
+--------------- | ------- | ---------------------------------------
+`operand` | `XlaOp` | n dimensional array to be normalized
+`scale` | `XlaOp` | 1 dimensional array
+`offset` | `XlaOp` | 1 dimensional array
+`mean` | `XlaOp` | 1 dimensional array
+`variance` | `XlaOp` | 1 dimensional array
+`epsilon` | `float` | Epsilon value
+`feature_index` | `int64` | Index to feature dimension in `operand`
+
+For each feature in the feature dimension (`feature_index` is the index for the
+feature dimension in `operand`), the operation calculates the mean and variance
+across all the other dimensions and uses the mean and variance to normalize each
+element in `operand`. The `feature_index` must be a valid index for the feature
+dimension in `operand`.
+
+`BatchNormInference` is equivalent to calling `BatchNormTraining` without
+computing `mean` and `variance` for each batch. It uses the input `mean` and
+`variance` instead as estimated values. The purpose of this op is to reduce
+latency in inference, hence the name `BatchNormInference`.
+
+The output is an n-dimensional, normalized array with the same shape as input
+`operand`.
+
+## BatchNormTraining
+
+See also
+[`XlaBuilder::BatchNormTraining`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+and [`the original batch normalization paper`](https://arxiv.org/abs/1502.03167)
+for a detailed description of the algorithm.
+
+Normalizes an array across batch and spatial dimensions.
+
+<b> `BatchNormTraining(operand, scale, offset, epsilon, feature_index)` </b>
+
+Arguments | Type | Semantics
+--------------- | ------- | ----------------------------------------
+`operand` | `XlaOp` | n dimensional array to be normalized (x)
+`scale` | `XlaOp` | 1 dimensional array (\\(\gamma\\))
+`offset` | `XlaOp` | 1 dimensional array (\\(\beta\\))
+`epsilon` | `float` | Epsilon value (\\(\epsilon\\))
+`feature_index` | `int64` | Index to feature dimension in `operand`
+
+For each feature in the feature dimension (`feature_index` is the index for the
+feature dimension in `operand`), the operation calculates the mean and variance
+across all the other dimensions and uses the mean and variance to normalize each
+element in `operand`. The `feature_index` must be a valid index for the feature
+dimension in `operand`.
+
+The algorithm goes as follows for each batch in `operand` \\(x\\) that
+contains `m` elements with `w` and `h` as the size of spatial dimensions
+(assuming `operand` is an 4 dimensional array):
+
+- Calculates batch mean \\(\mu_l\\) for each feature `l` in feature dimension:
+\\(\mu_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h x_{ijkl}\\)
+
+- Calculates batch variance \\(\sigma^2_l\\):
+\\(\sigma^2_l=\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (x_{ijkl} - \mu_l)^2\\)
+
+- Normalizes, scales and shifts:
+\\(y_{ijkl}=\frac{\gamma_l(x_{ijkl}-\mu_l)}{\sqrt[2]{\sigma^2_l+\epsilon}}+\beta_l\\)
+
+The epsilon value, usually a small number, is added to avoid divide-by-zero errors.
+
+The output type is a tuple of three `XlaOp`s:
+
+| Outputs | Type | Semantics |
+| ------------ | ----------------------- | -------------------------------------|
+| `output` | `XlaOp` | n dimensional array with the same |
+: : : shape as input `operand` (y) :
+| `batch_mean` | `XlaOp` | 1 dimensional array (\\(\mu\\)) |
+| `batch_var` | `XlaOp` | 1 dimensional array (\\(\sigma^2\\)) |
+
+The `batch_mean` and `batch_var` are moments calculated across the batch and
+spatial dimensions using the formulas above.
+
+## BitcastConvertType
+
+See also
+[`XlaBuilder::BitcastConvertType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Similar to a `tf.bitcast` in TensorFlow, performs an element-wise bitcast
+operation from a data shape to a target shape. The dimensions must match, and
+the conversion is an element-wise one; e.g. `s32` elements become `f32` elements
+via bitcast routine. Bitcast is implemented as a low-level cast, so machines
+with different floating-point representations will give different results.
+
+<b> `BitcastConvertType(operand, new_element_type)` </b>
+
+Arguments | Type | Semantics
+------------------ | --------------- | ---------------------------
+`operand` | `XlaOp` | array of type T with dims D
+`new_element_type` | `PrimitiveType` | type U
+
+The dimensions of the operand and the target shape must match. The bit-width of
+the source and destination element types must be equal. The source
+and destination element types must not be tuples.
+
+## Broadcast
+
+See also
+[`XlaBuilder::Broadcast`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Adds dimensions to an array by duplicating the data in the array.
+
+<b> `Broadcast(operand, broadcast_sizes)` </b>
+
+Arguments | Type | Semantics
+----------------- | ------------------- | -------------------------------
+`operand` | `XlaOp` | The array to duplicate
+`broadcast_sizes` | `ArraySlice<int64>` | The sizes of the new dimensions
+
+The new dimensions are inserted on the left, i.e. if `broadcast_sizes` has
+values `{a0, ..., aN}` and the operand shape has dimensions `{b0, ..., bM}` then
+the shape of the output has dimensions `{a0, ..., aN, b0, ..., bM}`.
+
+The new dimensions index into copies of the operand, i.e.
+
+```
+output[i0, ..., iN, j0, ..., jM] = operand[j0, ..., jM]
+```
+
+For example, if `operand` is a scalar `f32` with value `2.0f`, and
+`broadcast_sizes` is `{2, 3}`, then the result will be an array with shape
+`f32[2, 3]` and all the values in the result will be `2.0f`.
+
+## Call
+
+See also
+[`XlaBuilder::Call`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Invokes a computation with the given arguments.
+
+<b> `Call(computation, args...)` </b>
+
+| Arguments | Type | Semantics |
+| ------------- | ---------------------- | ----------------------------------- |
+| `computation` | `XlaComputation` | computation of type `T_0, T_1, ..., |
+: : : T_N -> S` with N parameters of :
+: : : arbitrary type :
+| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type |
+
+The arity and types of the `args` must match the parameters of the
+`computation`. It is allowed to have no `args`.
+
+## Clamp
+
+See also
+[`XlaBuilder::Clamp`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Clamps an operand to within the range between a minimum and maximum value.
+
+<b> `Clamp(min, operand, max)` </b>
+
+Arguments | Type | Semantics
+--------- | ------- | ---------------
+`min` | `XlaOp` | array of type T
+`operand` | `XlaOp` | array of type T
+`max` | `XlaOp` | array of type T
+
+Given an operand and minimum and maximum values, returns the operand if it is in
+the range between the minimum and maximum, else returns the minimum value if the
+operand is below this range or the maximum value if the operand is above this
+range. That is, `clamp(a, x, b) = min(max(a, x), b)`.
+
+All three arrays must be the same shape. Alternatively, as a restricted form of
+[broadcasting](broadcasting.md), `min` and/or `max` can be a scalar of type `T`.
+
+Example with scalar `min` and `max`:
+
+```
+let operand: s32[3] = {-1, 5, 9};
+let min: s32 = 0;
+let max: s32 = 6;
+==>
+Clamp(min, operand, max) = s32[3]{0, 5, 6};
+```
+
+## Collapse
+
+See also
+[`XlaBuilder::Collapse`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+and the `tf.reshape` operation.
+
+Collapses dimensions of an array into one dimension.
+
+<b> `Collapse(operand, dimensions)` </b>
+
+Arguments | Type | Semantics
+------------ | -------------- | -----------------------------------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `int64` vector | in-order, consecutive subset of T's dimensions.
+
+Collapse replaces the given subset of the operand's dimensions by a single
+dimension. The input arguments are an arbitrary array of type T and a
+compile-time-constant vector of dimension indices. The dimension indices must be
+an in-order (low to high dimension numbers), consecutive subset of T's
+dimensions. Thus, {0, 1, 2}, {0, 1}, or {1, 2} are all valid dimension sets, but
+{1, 0} or {0, 2} are not. They are replaced by a single new dimension, in the
+same position in the dimension sequence as those they replace, with the new
+dimension size equal to the product of original dimension sizes. The lowest
+dimension number in `dimensions` is the slowest varying dimension (most major)
+in the loop nest which collapses these dimension, and the highest dimension
+number is fastest varying (most minor). See the `tf.reshape` operator
+if more general collapse ordering is needed.
+
+For example, let v be an array of 24 elements:
+
+```
+let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
+ {{20, 21, 22}, {25, 26, 27}},
+ {{30, 31, 32}, {35, 36, 37}},
+ {{40, 41, 42}, {45, 46, 47}}};
+
+// Collapse to a single dimension, leaving one dimension.
+let v012 = Collapse(v, {0,1,2});
+then v012 == f32[24] {10, 11, 12, 15, 16, 17,
+ 20, 21, 22, 25, 26, 27,
+ 30, 31, 32, 35, 36, 37,
+ 40, 41, 42, 45, 46, 47};
+
+// Collapse the two lower dimensions, leaving two dimensions.
+let v01 = Collapse(v, {0,1});
+then v01 == f32[4x6] {{10, 11, 12, 15, 16, 17},
+ {20, 21, 22, 25, 26, 27},
+ {30, 31, 32, 35, 36, 37},
+ {40, 41, 42, 45, 46, 47}};
+
+// Collapse the two higher dimensions, leaving two dimensions.
+let v12 = Collapse(v, {1,2});
+then v12 == f32[8x3] {{10, 11, 12},
+ {15, 16, 17},
+ {20, 21, 22},
+ {25, 26, 27},
+ {30, 31, 32},
+ {35, 36, 37},
+ {40, 41, 42},
+ {45, 46, 47}};
+
+```
+
+## Concatenate
+
+See also
+[`XlaBuilder::ConcatInDim`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Concatenate composes an array from multiple array operands. The array is of the
+same rank as each of the input array operands (which must be of the same rank as
+each other) and contains the arguments in the order that they were specified.
+
+<b> `Concatenate(operands..., dimension)` </b>
+
+| Arguments | Type | Semantics |
+| ----------- | --------------------- | -------------------------------------- |
+| `operands` | sequence of N `XlaOp` | N arrays of type T with dimensions |
+: : : [L0, L1, ...]. Requires N >= 1. :
+| `dimension` | `int64` | A value in the interval `[0, N)` that |
+: : : names the dimension to be concatenated :
+: : : between the `operands`. :
+
+With the exception of `dimension` all dimensions must be the same. This is
+because XLA does not support "ragged" arrays. Also note that rank-0 values
+cannot be concatenated (as it's impossible to name the dimension along which the
+concatenation occurs).
+
+1-dimensional example:
+
+```
+Concat({{2, 3}, {4, 5}, {6, 7}}, 0)
+>>> {2, 3, 4, 5, 6, 7}
+```
+
+2-dimensional example:
+
+```
+let a = {
+ {1, 2},
+ {3, 4},
+ {5, 6},
+};
+let b = {
+ {7, 8},
+};
+Concat({a, b}, 0)
+>>> {
+ {1, 2},
+ {3, 4},
+ {5, 6},
+ {7, 8},
+}
+```
+
+Diagram:
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="https://www.tensorflow.org/images/ops_concatenate.png">
+</div>
+
+## Conditional
+
+See also
+[`XlaBuilder::Conditional`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Conditional(pred, true_operand, true_computation, false_operand,
+false_computation)` </b>
+
+Arguments | Type | Semantics
+------------------- | ---------------- | ---------------------------------
+`pred` | `XlaOp` | Scalar of type `PRED`
+`true_operand` | `XlaOp` | Argument of type `T_0`
+`true_computation` | `XlaComputation` | XlaComputation of type `T_0 -> S`
+`false_operand` | `XlaOp` | Argument of type `T_1`
+`false_computation` | `XlaComputation` | XlaComputation of type `T_1 -> S`
+
+Executes `true_computation` if `pred` is `true`, `false_computation` if `pred`
+is `false`, and returns the result.
+
+The `true_computation` must take in a single argument of type `T_0` and will be
+invoked with `true_operand` which must be of the same type. The
+`false_computation` must take in a single argument of type `T_1` and will be
+invoked with `false_operand` which must be of the same type. The type of the
+returned value of `true_computation` and `false_computation` must be the same.
+
+Note that only one of `true_computation` and `false_computation` will be
+executed depending on the value of `pred`.
+
+## Conv (convolution)
+
+See also
+[`XlaBuilder::Conv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+As ConvWithGeneralPadding, but the padding is specified in a short-hand way as
+either SAME or VALID. SAME padding pads the input (`lhs`) with zeroes so that
+the output has the same shape as the input when not taking striding into
+account. VALID padding simply means no padding.
+
+## ConvWithGeneralPadding (convolution)
+
+See also
+[`XlaBuilder::ConvWithGeneralPadding`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Computes a convolution of the kind used in neural networks. Here, a convolution
+can be thought of as a n-dimensional window moving across a n-dimensional base
+area and a computation is performed for each possible position of the window.
+
+| Arguments | Type | Semantics |
+| --------------------- | -------------------- | ----------------------------- |
+| `lhs` | `XlaOp` | rank n+2 array of inputs |
+| `rhs` | `XlaOp` | rank n+2 array of kernel |
+: : : weights :
+| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
+| `padding` | `ArraySlice< | n-d array of (low, high) |
+: : pair<int64, int64>>` : padding :
+| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
+| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
+| `feature_group_count` | int64 | the number of feature groups |
+
+Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
+array describing the base area. This is called the input, even though of course
+the rhs is also an input. In a neural network, these are the input activations.
+The n+2 dimensions are, in this order:
+
+* `batch`: Each coordinate in this dimension represents an independent input
+ for which convolution is carried out.
+* `z/depth/features`: Each (y,x) position in the base area has a vector
+ associated to it, which goes into this dimension.
+* `spatial_dims`: Describes the `n` spatial dimensions that define the base
+ area that the window moves across.
+
+The `rhs` argument is a rank n+2 array describing the convolutional
+filter/kernel/window. The dimensions are, in this order:
+
+* `output-z`: The `z` dimension of the output.
+* `input-z`: The size of this dimension times `feature_group_count` should
+ equal the size of the `z` dimension in lhs.
+* `spatial_dims`: Describes the `n` spatial dimensions that define the n-d
+ window that moves across the base area.
+
+The `window_strides` argument specifies the stride of the convolutional window
+in the spatial dimensions. For example, if the stride in the first spatial
+dimension is 3, then the window can only be placed at coordinates where the
+first spatial index is divisible by 3.
+
+The `padding` argument specifies the amount of zero padding to be applied to the
+base area. The amount of padding can be negative -- the absolute value of
+negative padding indicates the number of elements to remove from the specified
+dimension before doing the convolution. `padding[0]` specifies the padding for
+dimension `y` and `padding[1]` specifies the padding for dimension `x`. Each
+pair has the low padding as the first element and the high padding as the second
+element. The low padding is applied in the direction of lower indices while the
+high padding is applied in the direction of higher indices. For example, if
+`padding[1]` is `(2,3)` then there will be a padding by 2 zeroes on the left and
+by 3 zeroes on the right in the second spatial dimension. Using padding is
+equivalent to inserting those same zero values into the input (`lhs`) before
+doing the convolution.
+
+The `lhs_dilation` and `rhs_dilation` arguments specify the dilation factor to
+be applied to the lhs and rhs, respectively, in each spatial dimension. If the
+dilation factor in a spatial dimension is d, then d-1 holes are implicitly
+placed between each of the entries in that dimension, increasing the size of the
+array. The holes are filled with a no-op value, which for convolution means
+zeroes.
+
+Dilation of the rhs is also called atrous convolution. For more details, see
+`tf.nn.atrous_conv2d`. Dilation of the lhs is also called transposed
+convolution. For more details, see `tf.nn.conv2d_transpose`.
+
+The `feature_group_count` argument (default value 1) can be used for grouped
+convolutions. `feature_group_count` needs to be a divisor of both the input and
+the output feature dimension. If `feature_group_count` is greater than 1, it
+means that conceptually the input and output feature dimension and the `rhs`
+output feature dimension are split evenly into `feature_group_count` many
+groups, each group consisting of a consecutive subsequence of features. The
+input feature dimension of `rhs` needs to be equal to the `lhs` input feature
+dimension divided by `feature_group_count` (so it already has the size of a
+group of input features). The i-th groups are used together to compute
+`feature_group_count` many separate convolutions. The results of these
+convolutions are concatenated together in the output feature dimension.
+
+For depthwise convolution the `feature_group_count` argument would be set to the
+input feature dimension, and the filter would be reshaped from
+`[filter_height, filter_width, in_channels, channel_multiplier]` to
+`[filter_height, filter_width, 1, in_channels * channel_multiplier]`. For more
+details, see `tf.nn.depthwise_conv2d`.
+
+The output shape has these dimensions, in this order:
+
+* `batch`: Same size as `batch` on the input (`lhs`).
+* `z`: Same size as `output-z` on the kernel (`rhs`).
+* `spatial_dims`: One value for each valid placement of the convolutional
+ window.
+
+The valid placements of the convolutional window are determined by the strides
+and the size of the base area after padding.
+
+To describe what a convolution does, consider a 2d convolution, and pick some
+fixed `batch`, `z`, `y`, `x` coordinates in the output. Then `(y,x)` is a
+position of a corner of the window within the base area (e.g. the upper left
+corner, depending on how you interpret the spatial dimensions). We now have a 2d
+window, taken from the base area, where each 2d point is associated to a 1d
+vector, so we get a 3d box. From the convolutional kernel, since we fixed the
+output coordinate `z`, we also have a 3d box. The two boxes have the same
+dimensions, so we can take the sum of the element-wise products between the two
+boxes (similar to a dot product). That is the output value.
+
+Note that if `output-z` is e.g., 5, then each position of the window produces 5
+values in the output into the `z` dimension of the output. These values differ
+in what part of the convolutional kernel is used - there is a separate 3d box of
+values used for each `output-z` coordinate. So you could think of it as 5
+separate convolutions with a different filter for each of them.
+
+Here is pseudo-code for a 2d convolution with padding and striding:
+
+```
+for (b, oz, oy, ox) { // output coordinates
+ value = 0;
+ for (iz, ky, kx) { // kernel coordinates and input z
+ iy = oy*stride_y + ky - pad_low_y;
+ ix = ox*stride_x + kx - pad_low_x;
+ if ((iy, ix) inside the base area considered without padding) {
+ value += input(b, iz, iy, ix) * kernel(oz, iz, ky, kx);
+ }
+ }
+ output(b, oz, oy, ox) = value;
+}
+```
+
+## ConvertElementType
+
+See also
+[`XlaBuilder::ConvertElementType`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Similar to an element-wise `static_cast` in C++, performs an element-wise
+conversion operation from a data shape to a target shape. The dimensions must
+match, and the conversion is an element-wise one; e.g. `s32` elements become
+`f32` elements via an `s32`-to-`f32` conversion routine.
+
+<b> `ConvertElementType(operand, new_element_type)` </b>
+
+Arguments | Type | Semantics
+------------------ | --------------- | ---------------------------
+`operand` | `XlaOp` | array of type T with dims D
+`new_element_type` | `PrimitiveType` | type U
+
+The dimensions of the operand and the target shape must match. The source and
+destination element types must not be tuples.
+
+A conversion such as `T=s32` to `U=f32` will perform a normalizing int-to-float
+conversion routine such as round-to-nearest-even.
+
+> Note: The precise float-to-int and visa-versa conversions are currently
+> unspecified, but may become additional arguments to the convert operation in
+> the future. Not all possible conversions have been implemented for all
+>targets.
+
+```
+let a: s32[3] = {0, 1, 2};
+let b: f32[3] = convert(a, f32);
+then b == f32[3]{0.0, 1.0, 2.0}
+```
+
+## CrossReplicaSum
+
+See also
+[`XlaBuilder::CrossReplicaSum`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Computes a sum across replicas.
+
+<b> `CrossReplicaSum(operand)` </b>
+
+Arguments | Type | Semantics
+--------- | ------- | -----------------------------
+`operand` | `XlaOp` | Array to sum across replicas.
+| `replica_group_ids` | `int64` vector | Group ID for each replica. |
+
+The output shape is the same as the input shape. For example, if there are two
+replicas and the operand has the value `(1.0, 2.5)` and `(3.0, 5.25)`
+respectively on the two replicas, then the output value from this op will be
+`(4.0, 7.75)` on both replicas.
+
+`replica_group_ids` identifies the group ID of each replica. The group ID must
+either be empty (all replicas belong to a single group), or contain the same
+number of elements as the number of replicas. For example, if
+`replica_group_ids` = {0, 1, 2, 3, 0, 1, 2, 3} has eight replicas, there are
+four subgroups of replica IDs: {0, 4}, {1, 5}, {2, 6}, and {3, 7}. The size of
+each subgroup *must* be identical, so, for example, using:
+`replica_group_ids` = {0, 1, 2, 0} for four replicas is invalid.
+
+Computing the result of CrossReplicaSum requires having one input from each
+replica, so if one replica executes a CrossReplicaSum node more times than
+another, then the former replica will wait forever. Since the replicas are all
+running the same program, there are not a lot of ways for that to happen, but it
+is possible when a while loop's condition depends on data from infeed and the
+data that is infed causes the while loop to iterate more times on one replica
+than another.
+
+## CustomCall
+
+See also
+[`XlaBuilder::CustomCall`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Call a user-provided function within a computation.
+
+<b> `CustomCall(target_name, args..., shape)` </b>
+
+| Arguments | Type | Semantics |
+| ------------- | ---------------------- | --------------------------------- |
+| `target_name` | `string` | Name of the function. A call |
+: : : instruction will be emitted which :
+: : : targets this symbol name. :
+| `args` | sequence of N `XlaOp`s | N arguments of arbitrary type, |
+: : : which will be passed to the :
+: : : function. :
+| `shape` | `Shape` | Output shape of the function |
+
+The function signature is the same, regardless of the arity or type of args:
+
+```
+extern "C" void target_name(void* out, void** in);
+```
+
+For example, if CustomCall is used as follows:
+
+```
+let x = f32[2] {1,2};
+let y = f32[2x3] {{10, 20, 30}, {40, 50, 60}};
+
+CustomCall("myfunc", {x, y}, f32[3x3])
+```
+
+Here is an example of an implementation of `myfunc`:
+
+```
+extern "C" void myfunc(void* out, void** in) {
+ float (&x)[2] = *static_cast<float(*)[2]>(in[0]);
+ float (&y)[2][3] = *static_cast<float(*)[2][3]>(in[1]);
+ EXPECT_EQ(1, x[0]);
+ EXPECT_EQ(2, x[1]);
+ EXPECT_EQ(10, y[0][0]);
+ EXPECT_EQ(20, y[0][1]);
+ EXPECT_EQ(30, y[0][2]);
+ EXPECT_EQ(40, y[1][0]);
+ EXPECT_EQ(50, y[1][1]);
+ EXPECT_EQ(60, y[1][2]);
+ float (&z)[3][3] = *static_cast<float(*)[3][3]>(out);
+ z[0][0] = x[1] + y[1][0];
+ // ...
+}
+```
+
+The user-provided function must not have side-effects and its execution must be
+idempotent.
+
+> Note: The opaque nature of the user-provided function restricts optimization
+> opportunities for the compiler. Try to express your computation in terms of
+> native XLA ops whenever possible; only use CustomCall as a last resort.
+
+## Dot
+
+See also
+[`XlaBuilder::Dot`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Dot(lhs, rhs)` </b>
+
+Arguments | Type | Semantics
+--------- | ------- | ---------------
+`lhs` | `XlaOp` | array of type T
+`rhs` | `XlaOp` | array of type T
+
+The exact semantics of this operation depend on the ranks of the operands:
+
+| Input | Output | Semantics |
+| ----------------------- | --------------------- | ----------------------- |
+| vector [n] `dot` vector | scalar | vector dot product |
+: [n] : : :
+| matrix [m x k] `dot` | vector [m] | matrix-vector |
+: vector [k] : : multiplication :
+| matrix [m x k] `dot` | matrix [m x n] | matrix-matrix |
+: matrix [k x n] : : multiplication :
+
+The operation performs sum of products over the last dimension of `lhs` and the
+one-before-last dimension of `rhs`. These are the "contracted" dimensions. The
+contracted dimensions of `lhs` and `rhs` must be of the same size. In practice,
+it can be used to perform dot products between vectors, vector/matrix
+multiplications or matrix/matrix multiplications.
+
+## DotGeneral
+
+See also
+[`XlaBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
+
+Arguments | Type | Semantics
+------------------- | --------------------- | ---------------
+`lhs` | `XlaOp` | array of type T
+`rhs` | `XlaOp` | array of type T
+`dimension_numbers` | `DotDimensionNumbers` | array of type T
+
+As Dot, but allows contracting and batch dimension numbers to be specified for
+both the 'lhs' and 'rhs'.
+
+| DotDimensionNumbers Fields | Type | Semantics
+| --------- | ----------------------- | ---------------
+| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers |
+| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers |
+| 'lhs_batch_dimensions' | repeated int64 | 'lhs' batch dimension numbers |
+| 'rhs_batch_dimensions' | repeated int64 | 'rhs' batch dimension numbers |
+
+DotGeneral performs the sum of products over contracting dimensions specified
+in 'dimension_numbers'.
+
+Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
+to be the same, but must be listed in the same order in both
+'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
+There must be exactly one contracting dimension on both 'lhs' and 'rhs'.
+
+Example with contracting dimension numbers:
+
+```
+lhs = { {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0} }
+
+rhs = { {1.0, 1.0, 1.0},
+ {2.0, 2.0, 2.0} }
+
+DotDimensionNumbers dnums;
+dnums.add_lhs_contracting_dimensions(1);
+dnums.add_rhs_contracting_dimensions(1);
+
+DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
+ {15.0, 30.0} }
+```
+
+Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
+dimension number, must be listed in the same order in both arrays, must
+have the same dimension sizes, and must be ordered before contracting and
+non-contracting/non-batch dimension numbers.
+
+Example with batch dimension numbers (batch size 2, 2x2 matrices):
+
+```
+lhs = { { {1.0, 2.0},
+ {3.0, 4.0} },
+ { {5.0, 6.0},
+ {7.0, 8.0} } }
+
+rhs = { { {1.0, 0.0},
+ {0.0, 1.0} },
+ { {1.0, 0.0},
+ {0.0, 1.0} } }
+
+DotDimensionNumbers dnums;
+dnums.add_lhs_contracting_dimensions(2);
+dnums.add_rhs_contracting_dimensions(1);
+dnums.add_lhs_batch_dimensions(0);
+dnums.add_rhs_batch_dimensions(0);
+
+DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
+ {3.0, 4.0} },
+ { {5.0, 6.0},
+ {7.0, 8.0} } }
+```
+
+| Input | Output | Semantics |
+| ----------------------------------- | ----------------- | ---------------- |
+| [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul |
+| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |
+
+It follows that the resulting dimension number starts with the batch dimension,
+then the 'lhs' non-contracting/non-batch dimension, and finally the 'rhs'
+non-contracting/non-batch dimension.
+
+## DynamicSlice
+
+See also
+[`XlaBuilder::DynamicSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+DynamicSlice extracts a sub-array from the input array at dynamic
+`start_indices`. The size of the slice in each dimension is passed in
+`size_indices`, which specify the end point of exclusive slice intervals in each
+dimension: [start, start + size). The shape of `start_indices` must be rank ==
+1, with dimension size equal to the rank of `operand`.
+
+<b> `DynamicSlice(operand, start_indices, size_indices)` </b>
+
+| Arguments | Type | Semantics |
+| --------------- | ------------------- | ----------------------------------- |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `start_indices` | `XlaOp` | Rank 1 array of N integers |
+: : : containing the starting indices of :
+: : : the slice for each dimension. Value :
+: : : must be greater than or equal to :
+: : : zero. :
+| `size_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : slice size for each dimension. Each :
+: : : value must be strictly greater than :
+: : : zero, and start + size must be less :
+: : : than or equal to the size of the :
+: : : dimension to avoid wrapping modulo :
+: : : dimension size. :
+
+The effective slice indices are computed by applying the following
+transformation for each index `i` in `[1, N)` before performing the slice:
+
+```
+start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - size_indices[i])
+```
+
+This ensures that the extracted slice is always in-bounds with respect to the
+operand array. If the slice is in-bounds before the transformation is applied,
+the transformation has no effect.
+
+1-dimensional example:
+
+```
+let a = {0.0, 1.0, 2.0, 3.0, 4.0}
+let s = {2}
+
+DynamicSlice(a, s, {2}) produces:
+ {2.0, 3.0}
+```
+
+2-dimensional example:
+
+```
+let b =
+ { {0.0, 1.0, 2.0},
+ {3.0, 4.0, 5.0},
+ {6.0, 7.0, 8.0},
+ {9.0, 10.0, 11.0} }
+let s = {2, 1}
+
+DynamicSlice(b, s, {2, 2}) produces:
+ { { 7.0, 8.0},
+ {10.0, 11.0} }
+```
+## DynamicUpdateSlice
+
+See also
+[`XlaBuilder::DynamicUpdateSlice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+DynamicUpdateSlice generates a result which is the value of the input array
+`operand`, with a slice `update` overwritten at `start_indices`.
+The shape of `update` determines the shape of the sub-array of the result which
+is updated.
+The shape of `start_indices` must be rank == 1, with dimension size equal to
+the rank of `operand`.
+
+<b> `DynamicUpdateSlice(operand, update, start_indices)` </b>
+
+| Arguments | Type | Semantics |
+| --------------- | ------- | ------------------------------------------------ |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `update` | `XlaOp` | N dimensional array of type T containing the |
+: : : slice update. Each dimension of update shape :
+: : : must be strictly greater than zero, and start + :
+: : : update must be less than or equal to the operand :
+: : : size for each dimension to avoid generating :
+: : : out-of-bounds update indices. :
+| `start_indices` | `XlaOp` | Rank 1 array of N integers containing the |
+: : : starting indices of the slice for each :
+: : : dimension. Value must be greater than or equal :
+: : : to zero. :
+
+The effective slice indices are computed by applying the following
+transformation for each index `i` in `[1, N)` before performing the slice:
+
+```
+start_indices[i] = clamp(start_indices[i], 0, operand.dimension_size[i] - update.dimension_size[i])
+```
+
+This ensures that the updated slice is always in-bounds with respect to the
+operand array. If the slice is in-bounds before the transformation is applied,
+the transformation has no effect.
+
+1-dimensional example:
+
+```
+let a = {0.0, 1.0, 2.0, 3.0, 4.0}
+let u = {5.0, 6.0}
+let s = {2}
+
+DynamicUpdateSlice(a, u, s) produces:
+ {0.0, 1.0, 5.0, 6.0, 4.0}
+```
+
+2-dimensional example:
+
+```
+let b =
+ { {0.0, 1.0, 2.0},
+ {3.0, 4.0, 5.0},
+ {6.0, 7.0, 8.0},
+ {9.0, 10.0, 11.0} }
+let u =
+ { {12.0, 13.0},
+ {14.0, 15.0},
+ {16.0, 17.0} }
+
+let s = {1, 1}
+
+DynamicUpdateSlice(b, u, s) produces:
+ { {0.0, 1.0, 2.0},
+ {3.0, 12.0, 13.0},
+ {6.0, 14.0, 15.0},
+ {9.0, 16.0, 17.0} }
+```
+
+## Element-wise binary arithmetic operations
+
+See also
+[`XlaBuilder::Add`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+A set of element-wise binary arithmetic operations is supported.
+
+<b> `Op(lhs, rhs)` </b>
+
+Where `Op` is one of `Add` (addition), `Sub` (subtraction), `Mul`
+(multiplication), `Div` (division), `Rem` (remainder), `Max` (maximum), `Min`
+(minimum), `LogicalAnd` (logical AND), or `LogicalOr` (logical OR).
+
+Arguments | Type | Semantics
+--------- | ------- | ----------------------------------------
+`lhs` | `XlaOp` | left-hand-side operand: array of type T
+`rhs` | `XlaOp` | right-hand-side operand: array of type T
+
+The arguments' shapes have to be either similar or compatible. See the
+[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
+be compatible. The result of an operation has a shape which is the result of
+broadcasting the two input arrays. In this variant, operations between arrays of
+different ranks are *not* supported, unless one of the operands is a scalar.
+
+When `Op` is `Rem`, the sign of the result is taken from the dividend, and the
+absolute value of the result is always less than the divisor's absolute value.
+
+Integer division overflow (signed/unsigned division/remainder by zero or signed
+divison/remainder of `INT_SMIN` with `-1`) produces an implementation defined
+value.
+
+An alternative variant with different-rank broadcasting support exists for these
+operations:
+
+<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
+
+Where `Op` is the same as above. This variant of the operation should be used
+for arithmetic operations between arrays of different ranks (such as adding a
+matrix to a vector).
+
+The additional `broadcast_dimensions` operand is a slice of integers used to
+expand the rank of the lower-rank operand up to the rank of the higher-rank
+operand. `broadcast_dimensions` maps the dimensions of the lower-rank shape to
+the dimensions of the higher-rank shape. The unmapped dimensions of the expanded
+shape are filled with dimensions of size one. Degenerate-dimension broadcasting
+then broadcasts the shapes along these degenerate dimensions to equalize the
+shapes of both operands. The semantics are described in detail on the
+[broadcasting page](../../performance/xla/broadcasting.md).
+
+## Element-wise comparison operations
+
+See also
+[`XlaBuilder::Eq`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+A set of standard element-wise binary comparison operations is supported. Note
+that standard IEEE 754 floating-point comparison semantics apply when comparing
+floating-point types.
+
+<b> `Op(lhs, rhs)` </b>
+
+Where `Op` is one of `Eq` (equal-to), `Ne` (not equal-to), `Ge`
+(greater-or-equal-than), `Gt` (greater-than), `Le` (less-or-equal-than), `Lt`
+(less-than).
+
+Arguments | Type | Semantics
+--------- | ------- | ----------------------------------------
+`lhs` | `XlaOp` | left-hand-side operand: array of type T
+`rhs` | `XlaOp` | right-hand-side operand: array of type T
+
+The arguments' shapes have to be either similar or compatible. See the
+[broadcasting](../../performance/xla/broadcasting.md) documentation about what it means for shapes to
+be compatible. The result of an operation has a shape which is the result of
+broadcasting the two input arrays with the element type `PRED`. In this variant,
+operations between arrays of different ranks are *not* supported, unless one of
+the operands is a scalar.
+
+An alternative variant with different-rank broadcasting support exists for these
+operations:
+
+<b> `Op(lhs, rhs, broadcast_dimensions)` </b>
+
+Where `Op` is the same as above. This variant of the operation should be used
+for comparison operations between arrays of different ranks (such as adding a
+matrix to a vector).
+
+The additional `broadcast_dimensions` operand is a slice of integers specifying
+the dimensions to use for broadcasting the operands. The semantics are described
+in detail on the [broadcasting page](../../performance/xla/broadcasting.md).
+
+## Element-wise unary functions
+
+XlaBuilder supports these element-wise unary functions:
+
+<b>`Abs(operand)`</b> Element-wise abs `x -> |x|`.
+
+<b>`Ceil(operand)`</b> Element-wise ceil `x -> ⌈x⌉`.
+
+<b>`Cos(operand)`</b> Element-wise cosine `x -> cos(x)`.
+
+<b>`Exp(operand)`</b> Element-wise natural exponential `x -> e^x`.
+
+<b>`Floor(operand)`</b> Element-wise floor `x -> ⌊x⌋`.
+
+<b>`IsFinite(operand)`</b> Tests whether each element of `operand` is finite,
+i.e., is not positive or negative infinity, and is not `NaN`. Returns an array
+of `PRED` values with the same shape as the input, where each element is `true`
+if and only if the corresponding input element is finite.
+
+<b>`Log(operand)`</b> Element-wise natural logarithm `x -> ln(x)`.
+
+<b>`LogicalNot(operand)`</b> Element-wise logical not `x -> !(x)`.
+
+<b>`Neg(operand)`</b> Element-wise negation `x -> -x`.
+
+<b>`Sign(operand)`</b> Element-wise sign operation `x -> sgn(x)` where
+
+$$\text{sgn}(x) = \begin{cases} -1 & x < 0\\ 0 & x = 0\\ 1 & x > 0 \end{cases}$$
+
+using the comparison operator of the element type of `operand`.
+
+<b>`Tanh(operand)`</b> Element-wise hyperbolic tangent `x -> tanh(x)`.
+
+
+Arguments | Type | Semantics
+--------- | ------- | ---------------------------
+`operand` | `XlaOp` | The operand to the function
+
+The function is applied to each element in the `operand` array, resulting in an
+array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
+
+## Gather
+
+The XLA gather operation stitches together several slices (each slice at a
+potentially different runtime offset) of an input array.
+
+### General Semantics
+
+See also
+[`XlaBuilder::Gather`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+For a more intuitive description, see the "Informal Description" section below.
+
+<b> `gather(operand, start_indices, offset_dims, collapsed_slice_dims, slice_sizes, start_index_map)` </b>
+
+|Arguments | Type | Semantics |
+|----------------- | ----------------------- | --------------------------------|
+|`operand` | `XlaOp` | The array we’re gathering |
+: : : from. :
+|`start_indices` | `XlaOp` | Array containing the starting |
+: : : indices of the slices we gather.:
+|`index_vector_dim` | `int64` | The dimension in |
+: : : `start_indices` that "contains" :
+: : : the starting indices. See :
+: : : below for a detailed :
+: : : description. :
+|`offset_dims` | `ArraySlice<int64>` | The set of dimensions in the :
+: : : output shape that offset into a :
+: : : array sliced from operand. :
+|`slice_sizes` | `ArraySlice<int64>` | `slice_sizes[i]` is the bounds |
+: : : for the slice on dimension `i`.:
+|`collapsed_slice_dims` | `ArraySlice<int64>` | The set of dimensions in each :
+| : | slice that are collapsed away. :
+| : | These dimensions must have size:
+| : | 1. |
+|`start_index_map` | `ArraySlice<int64>` | A map that describes how to map|
+: : : indices in `start_indices` to :
+: : : to legal indices into operand. :
+
+For convenience, we label dimensions in the output array not in `offset_dims`
+as `batch_dims`.
+
+The output is an array of rank `batch_dims.size` + `operand.rank` -
+`collapsed_slice_dims`.size.
+
+If `index_vector_dim` is equal to `start_indices.rank` we implicitly consider
+`start_indices` to have a trailing `1` dimension (i.e. if `start_indices` was of
+shape `[6,7]` and `index_vector_dim` is `2` then we implicitly consider the
+shape of `start_indices` to be `[6,7,1]`).
+
+The bounds for the output array along dimension `i` is computed as follows:
+
+ 1. If `i` is present in `batch_dims` (i.e. is equal to `batch_dims[k]` for
+ some `k`) then we pick the corresponding dimension bounds out of
+ `start_indices.shape`, skipping `index_vector_dim` (i.e. pick
+ `start_indices.shape.dims`[`k`] if `k` < `index_vector_dim` and
+ `start_indices.shape.dims`[`k`+`1`] otherwise).
+
+ 2. If `i` is present in `offset_dims` (i.e. equal to `offset_dims`[`k`] for
+ some `k`) then we pick the corresponding bound out of `slice_sizes` after
+ accounting for `collapsed_slice_dims` (i.e. we pick
+ `adjusted_slice_sizes`[`k`] where `adjusted_slice_sizes` is `slice_sizes`
+ with the bounds at indices `collapsed_slice_dims` removed).
+
+Formally, the operand index `In` corresponding to an output index `Out` is
+computed as follows:
+
+ 1. Let `G` = { `Out`[`k`] for `k` in `batch_dims` }. Use `G` to slice out
+ vector `S` such that `S`[`i`] = `start_indices`[Combine(`G`, `i`)] where
+ Combine(A, b) inserts b at position `index_vector_dim` into A. Note that
+ this is well defined even if `G` is empty -- if `G` is empty then `S` =
+ `start_indices`.
+
+ 2. Create a starting index, `S`<sub>`in`</sub>, into `operand` using `S` by
+ scattering `S` using `start_index_map`. More precisely:
+ 1. `S`<sub>`in`</sub>[`start_index_map`[`k`]] = `S`[`k`] if `k` <
+ `start_index_map.size`.
+ 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+
+ 3. Create an index `O`<sub>`in`</sub> into `operand` by scattering the indices
+ at the offset dimensions in `Out` according to the `collapsed_slice_dims`
+ set. More precisely:
+ 1. `O`<sub>`in`</sub>[`expand_offset_dims`(`k`)] =
+ `Out`[`offset_dims`[`k`]] if `k` < `offset_dims.size`
+ (`expand_offset_dims` is defined below).
+ 2. `O`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `In` is `O`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+ addition.
+
+`expand_offset_dims` is the monotonic function with domain [`0`, `offset.size`)
+and range [`0`, `operand.rank`) \ `collapsed_slice_dims`. So if, e.g.,
+`offset.size` is `4`, `operand.rank` is `6` and `collapsed_slice_dims` is {`0`,
+`2`} then `expand_offset_dims` is {`0`→`1`, `1`→`3`, `2`→`4`, `3`→`5`}.
+
+### Informal Description and Examples
+
+Informally, every index `Out` in the output array corresponds to an element `E`
+in the operand array, computed as follows:
+
+ - We use the batch dimensions in `Out` to look up a starting index from
+ `start_indices`.
+
+ - We use `start_index_map` to map the starting index (which may have size less
+ than operand.rank) to a "full" starting index into operand.
+
+ - We dynamic-slice out a slice with size `slice_sizes` using the full starting
+ index.
+
+ - We reshape the slice by collapsing the `collapsed_slice_dims` dimensions.
+ Since all collapsed slice dimensions have to have bound 1 this reshape is
+ always legal.
+
+ - We use the offset dimensions in `Out` to index into this slice to get the
+ input element, `E`, corresponding to output index `Out`.
+
+`index_vector_dim` is set to `start_indices.rank` - `1` in all of the
+examples that follow. More interesting values for `index_vector_dim` does not
+change the operation fundamentally, but makes the visual representation more
+cumbersome.
+
+To get an intuition on how all of the above fits together, let's look at an
+example that gathers 5 slices of shape `[8,6]` from a `[16,11]` array. The
+position of a slice into the `[16,11]` array can be represented as an index
+vector of shape `S64[2]`, so the set of 5 positions can be represented as a
+`S64[5,2]` array.
+
+The behavior of the gather operation can then be depicted as an index
+transformation that takes [`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>], an index in
+the output shape, and maps it to an element in the input array in the following
+way:
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/ops_xla_gather_0.svg">
+</div>
+
+We first select an (`X`,`Y`) vector from the gather indices array using `G`.
+The element in the output array at index
+[`G`,`O`<sub>`0`</sub>,`O`<sub>`1`</sub>] is then the element in the input
+array at index [`X`+`O`<sub>`0`</sub>,`Y`+`O`<sub>`1`</sub>].
+
+`slice_sizes` is `[8,6]`, which decides the range of W<sub>`0`</sub> and
+W<sub>`1`</sub>, and this in turn decides the bounds of the slice.
+
+This gather operation acts as a batch dynamic slice with `G` as the batch
+dimension.
+
+The gather indices may be multidimensional. For instance, a more general
+version of the example above using a "gather indices" array of shape `[4,5,2]`
+would translate indices like this:
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/ops_xla_gather_1.svg">
+</div>
+
+Again, this acts as a batch dynamic slice `G`<sub>`0`</sub> and
+`G`<sub>`1`</sub> as the batch dimensions. The slice size is still `[8,6]`.
+
+The gather operation in XLA generalizes the informal semantics outlined above in
+the following ways:
+
+ 1. We can configure which dimensions in the output shape are the offset
+ dimensions (dimensions containing `O`<sub>`0`</sub>, `O`<sub>`1`</sub> in
+ the last example). The output batch dimensions (dimensions containing
+ `G`<sub>`0`</sub>, `G`<sub>`1`</sub> in the last example) are defined to be
+ the output dimensions that are not offset dimensions.
+
+ 2. The number of output offset dimensions explicitly present in the output
+ shape may be smaller than the input rank. These "missing" dimensions, which
+ are listed explicitly as `collapsed_slice_dims`, must have a slice size of
+ `1`. Since they have a slice size of `1` the only valid index for them is
+ `0` and eliding them does not introduce ambiguity.
+
+ 3. The slice extracted from the "Gather Indices" array ((`X`, `Y`) in the last
+ example) may have fewer elements than the input array rank, and an explicit
+ mapping dictates how the index should be expanded to have the same rank as
+ the input.
+
+As a final example, we use (2) and (3) to implement `tf.gather_nd`:
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/ops_xla_gather_2.svg">
+</div>
+
+`G`<sub>`0`</sub> and `G`<sub>`1`</sub> are used to slice out a starting index
+from the gather indices array as usual, except the starting index has only one
+element, `X`. Similarly, there is only one output offset index with the value
+`O`<sub>`0`</sub>. However, before being used as indices into the input array,
+these are expanded in accordance to "Gather Index Mapping" (`start_index_map` in
+the formal description) and "Offset Mapping" (`expand_offset_dims` in the formal
+description) into [`0`,`O`<sub>`0`</sub>] and [`X`,`0`] respectively, adding up
+to [`X`,`O`<sub>`0`</sub>]. In other words, the output index
+[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`O`<sub>`0`</sub>] maps to the input index
+[`GatherIndices`[`G`<sub>`0`</sub>,`G`<sub>`1`</sub>,`0`],`X`] which gives us
+the semantics for `tf.gather_nd`.
+
+`slice_sizes` for this case is `[1,11]`. Intuitively this means that every
+index `X` in the gather indices array picks an entire row and the result is the
+concatenation of all these rows.
+
+## GetTupleElement
+
+See also
+[`XlaBuilder::GetTupleElement`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Indexes into a tuple with a compile-time-constant value.
+
+The value must be a compile-time-constant so that shape inference can determine
+the type of the resulting value.
+
+This is analogous to `std::get<int N>(t)` in C++. Conceptually:
+
+```
+let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+let s: s32 = 5;
+let t: (f32[10], s32) = tuple(v, s);
+let element_1: s32 = gettupleelement(t, 1); // Inferred shape matches s32.
+```
+
+See also `tf.tuple`.
+
+## Infeed
+
+See also
+[`XlaBuilder::Infeed`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Infeed(shape)` </b>
+
+| Argument | Type | Semantics |
+| -------- | ------- | ----------------------------------------------------- |
+| `shape` | `Shape` | Shape of the data read from the Infeed interface. The |
+: : : layout field of the shape must be set to match the :
+: : : layout of the data sent to the device; otherwise its :
+: : : behavior is undefined. :
+
+Reads a single data item from the implicit Infeed streaming interface of the
+device, interpreting the data as the given shape and its layout, and returns a
+`XlaOp` of the data. Multiple Infeed operations are allowed in a
+computation, but there must be a total order among the Infeed operations. For
+example, two Infeeds in the code below have a total order since there is a
+dependency between the while loops.
+
+```
+result1 = while (condition, init = init_value) {
+ Infeed(shape)
+}
+
+result2 = while (condition, init = result1) {
+ Infeed(shape)
+}
+```
+
+Nested tuple shapes are not supported. For an empty tuple shape, the Infeed
+operation is effectively a no-op and proceeds without reading any data from the
+Infeed of the device.
+
+> Note: We plan to allow multiple Infeed operations without a total order, in
+> which case the compiler will provide information about how the Infeed
+> operations are serialized in the compiled program.
+
+## Iota
+
+<b> `Iota()` </b>
+
+Builds a constant literal on device rather than a potentially large host
+transfer. Creates a rank 1 tensor of values starting at zero and incrementing
+by one.
+
+Arguments | Type | Semantics
+------------------ | --------------- | ---------------------------
+`type` | `PrimitiveType` | type U
+`size` | `int64` | The number of elements in the tensor.
+
+## Map
+
+See also
+[`XlaBuilder::Map`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Map(operands..., computation)` </b>
+
+| Arguments | Type | Semantics |
+| ----------------- | ---------------------- | ------------------------------ |
+| `operands` | sequence of N `XlaOp`s | N arrays of types T_0..T_{N-1} |
+| `computation` | `XlaComputation` | computation of type `T_0, T_1, |
+: : : ..., T_{N + M -1} -> S` with N :
+: : : parameters of type T and M of :
+: : : arbitrary type :
+| `dimensions` | `int64` array | array of map dimensions |
+
+Applies a scalar function over the given `operands` arrays, producing an array
+of the same dimensions where each element is the result of the mapped function
+applied to the corresponding elements in the input arrays.
+
+The mapped function is an arbitrary computation with the restriction that it has
+N inputs of scalar type `T` and a single output with type `S`. The output has
+the same dimensions as the operands except that the element type T is replaced
+with S.
+
+For example: `Map(op1, op2, op3, computation, par1)` maps `elem_out <-
+computation(elem1, elem2, elem3, par1)` at each (multi-dimensional) index in the
+input arrays to produce the output array.
+
+## Pad
+
+See also
+[`XlaBuilder::Pad`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Pad(operand, padding_value, padding_config)` </b>
+
+| Arguments | Type | Semantics |
+| ---------------- | --------------- | --------------------------------------- |
+| `operand` | `XlaOp` | array of type `T` |
+| `padding_value` | `XlaOp` | scalar of type `T` to fill in the added |
+: : : padding :
+| `padding_config` | `PaddingConfig` | padding amount on both edges (low, |
+: : : high) and between the elements of each :
+: : : dimension :
+
+Expands the given `operand` array by padding around the array as well as between
+the elements of the array with the given `padding_value`. `padding_config`
+specifies the amount of edge padding and the interior padding for each
+dimension.
+
+`PaddingConfig` is a repeated field of `PaddingConfigDimension`, which contains
+three fields for each dimension: `edge_padding_low`, `edge_padding_high`, and
+`interior_padding`. `edge_padding_low` and `edge_padding_high` specify the
+amount of padding added at the low-end (next to index 0) and the high-end (next
+to the highest index) of each dimension respectively. The amount of edge padding
+can be negative -- the absolute value of negative padding indicates the number
+of elements to remove from the specified dimension. `interior_padding` specifies
+the amount of padding added between any two elements in each dimension. Interior
+padding occurs logically before edge padding, so in the case of negative edge
+padding elements are removed from the interior-padded operand. This operation is
+a no-op if the edge padding pairs are all (0, 0) and the interior padding values
+are all 0. The figure below shows examples of different `edge_padding` and
+`interior_padding` values for a two-dimensional array.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="https://www.tensorflow.org/images/ops_pad.png">
+</div>
+
+## Recv
+
+See also
+[`XlaBuilder::Recv`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Recv(shape, channel_handle)` </b>
+
+| Arguments | Type | Semantics |
+| ---------------- | --------------- | ------------------------------------ |
+| `shape` | `Shape` | shape of the data to receive |
+| `channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair |
+
+Receives data of the given shape from a `Send` instruction in another
+computation that shares the same channel handle. Returns a
+XlaOp for the received data.
+
+The client API of `Recv` operation represents synchronous communication.
+However, the instruction is internally decomposed into 2 HLO instructions
+(`Recv` and `RecvDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateRecv` and `HloInstruction::CreateRecvDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Recv(const Shape& shape, int64 channel_id)`</b>
+
+Allocates resources required to receive data from a `Send` instruction with the
+same channel_id. Returns a context for the allocated resources, which is used
+by a following `RecvDone` instruction to wait for the completion of the data
+transfer. The context is a tuple of {receive buffer (shape), request identifier
+(U32)} and it can only be used by a `RecvDone` instruction.
+
+<b> `RecvDone(HloInstruction context)` </b>
+
+Given a context created by a `Recv` instruction, waits for the data transfer to
+complete and returns the received data.
+
+## Reduce
+
+See also
+[`XlaBuilder::Reduce`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Applies a reduction function to one or more arrays in parallel.
+
+<b> `Reduce(operands..., init_values..., computation, dimensions)` </b>
+
+Arguments | Type | Semantics
+------------- | --------------------- | ---------------------------------------
+`operands` | Sequence of N `XlaOp` | N arrays of types `T_0, ..., T_N`.
+`init_values` | Sequence of N `XlaOp` | N scalars of types `T_0, ..., T_N`.
+`computation` | `XlaComputation` | computation of type
+ : : `T_0, ..., T_N, T_0, ..., T_N -> Collate(T_0, ..., T_N)`
+`dimensions` | `int64` array | unordered array of dimensions to reduce
+
+Where:
+* N is required to be greater or equal to 1.
+* All input arrays must have the same dimensions.
+* If `N = 1`, `Collate(T)` is `T`.
+* If `N > 1`, `Collate(T_0, ..., T_N)` is a tuple of `N` elements of type `T`.
+
+The output of the op is `Collate(Q_0, ..., Q_N)` where `Q_i` is an array of type
+`T_i`, the dimensions of which are described below.
+
+This operation reduces one or more dimensions of each input array into scalars.
+The rank of each returned array is `rank(operand) - len(dimensions)`.
+`init_value` is the initial value used for every reduction and may be inserted
+anywhere during computation by the back-end. In most cases, `init_value` is an
+identity of the reduction function (for example, 0 for addition). The applied
+`computation` is always passed the `init_value` on the left-hand side.
+
+The evaluation order of the reduction function is arbitrary and may be
+non-deterministic. Therefore, the reduction function should not be overly
+sensitive to reassociation.
+
+Some reduction functions like addition are not strictly associative for floats.
+However, if the range of the data is limited, floating-point addition is close
+enough to being associative for most practical uses. It is possible to conceive
+of some completely non-associative reductions, however, and these will produce
+incorrect or unpredictable results in XLA reductions.
+
+As an example, when reducing across one dimension in a single 1D array with
+values [10, 11, 12, 13], with reduction function `f` (this is `computation`)
+then that could be computed as
+
+`f(10, f(11, f(12, f(init_value, 13)))`
+
+but there are also many other possibilities, e.g.
+
+`f(init_value, f(f(10, f(init_value, 11)), f(f(init_value, 12), f(init_value, 13))))`
+
+The following is a rough pseudo-code example of how reduction could be
+implemented, using summation as the reduction computation with an initial value
+of 0.
+
+```python
+result_shape <- remove all dims in dimensions from operand_shape
+
+# Iterate over all elements in result_shape. The number of r's here is equal
+# to the rank of the result
+for r0 in range(result_shape[0]), r1 in range(result_shape[1]), ...:
+ # Initialize this result element
+ result[r0, r1...] <- 0
+
+ # Iterate over all the reduction dimensions
+ for d0 in range(dimensions[0]), d1 in range(dimensions[1]), ...:
+ # Increment the result element with the value of the operand's element.
+ # The index of the operand's element is constructed from all ri's and di's
+ # in the right order (by construction ri's and di's together index over the
+ # whole operand shape).
+ result[r0, r1...] += operand[ri... di]
+```
+
+Here's an example of reducing a 2D array (matrix). The shape has rank 2,
+dimension 0 of size 2 and dimension 1 of size 3:
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:35%" src="https://www.tensorflow.org/images/ops_2d_matrix.png">
+</div>
+
+Results of reducing dimensions 0 or 1 with an "add" function:
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_2d_matrix.png">
+</div>
+
+Note that both reduction results are 1D arrays. The diagram shows one as column
+and another as row just for visual convenience.
+
+For a more complex example, here is a 3D array. Its rank is 3, dimension 0 of
+size 4, dimension 1 of size 2 and dimension 2 of size 3. For simplicity, the
+values 1 to 6 are replicated across dimension 0.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_from_3d_matrix.png">
+</div>
+
+Similarly to the 2D example, we can reduce just one dimension. If we reduce
+dimension 0, for example, we get a rank-2 array where all values across
+dimension 0 were folded into a scalar:
+
+```text
+| 4 8 12 |
+| 16 20 24 |
+```
+
+If we reduce dimension 2, we also get a rank-2 array where all values across
+dimension 2 were folded into a scalar:
+
+```text
+| 6 15 |
+| 6 15 |
+| 6 15 |
+| 6 15 |
+```
+
+Note that the relative order between the remaining dimensions in the input is
+preserved in the output, but some dimensions may get assigned new numbers (since
+the rank changes).
+
+We can also reduce multiple dimensions. Add-reducing dimensions 0 and 1 produces
+the 1D array `| 20 28 36 |`.
+
+Reducing the 3D array over all its dimensions produces the scalar `84`.
+
+When `N > 1`, reduce function application is slightly more complex, as it is
+applied simultaneously to all inputs. For example, consider the following
+reduction function, which can be used to compute the max and the argmax of a
+a 1-D tensor in parallel:
+
+```
+f: (Float, Int, Float, Int) -> Float, Int
+f(max, argmax, value, index):
+ if value >= argmax:
+ return (value, index)
+ else:
+ return (max, argmax)
+```
+
+For 1-D Input arrays `V = Float[N], K = Int[N]`, and init values
+`I_V = Float, I_K = Int`, the result `f_(N-1)` of reducing across the only
+input dimension is equivalent to the following recursive application:
+```
+f_0 = f(I_V, I_K, V_0, K_0)
+f_1 = f(f_0.first, f_0.second, V_1, K_1)
+...
+f_(N-1) = f(f_(N-2).first, f_(N-2).second, V_(N-1), K_(N-1))
+```
+
+Applying this reduction to an array of values, and an array of sequential
+indices (i.e. iota), will co-iterate over the arrays, and return a tuple
+containing the maximal value and the matching index.
+
+## ReducePrecision
+
+See also
+[`XlaBuilder::ReducePrecision`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Models the effect of converting floating-point values to a lower-precision
+format (such as IEEE-FP16) and back to the original format. The number of
+exponent and mantissa bits in the lower-precision format can be specified
+arbitrarily, although all bit sizes may not be supported on all hardware
+implementations.
+
+<b> `ReducePrecision(operand, mantissa_bits, exponent_bits)` </b>
+
+Arguments | Type | Semantics
+--------------- | ------- | -------------------------------------------------
+`operand` | `XlaOp` | array of floating-point type `T`.
+`exponent_bits` | `int32` | number of exponent bits in lower-precision format
+`mantissa_bits` | `int32` | number of mantissa bits in lower-precision format
+
+The result is an array of type `T`. The input values are rounded to the nearest
+value representable with the given number of mantissa bits (using "ties to even"
+semantics), and any values that exceed the range specified by the number of
+exponent bits are clamped to positive or negative infinity. `NaN` values are
+retained, although they may be converted to canonical `NaN` values.
+
+The lower-precision format must have at least one exponent bit (in order to
+distinguish a zero value from an infinity, since both have a zero mantissa), and
+must have a non-negative number of mantissa bits. The number of exponent or
+mantissa bits may exceed the corresponding value for type `T`; the corresponding
+portion of the conversion is then simply a no-op.
+
+## ReduceWindow
+
+See also
+[`XlaBuilder::ReduceWindow`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Applies a reduction function to all elements in each window of the input
+multi-dimensional array, producing an output multi-dimensional array with the
+same number of elements as the number of valid positions of the window. A
+pooling layer can be expressed as a `ReduceWindow`. Similar to
+[`Reduce`](#reduce), the applied `computation` is always passed the `init_value`
+on the left-hand side.
+
+<b> `ReduceWindow(operand, init_value, computation, window_dimensions,
+window_strides, padding)` </b>
+
+| Arguments | Type | Semantics |
+| ------------------- | ------------------- | -------------------------------- |
+| `operand` | `XlaOp` | N dimensional array containing |
+: : : elements of type T. This is the :
+: : : base area on which the window is :
+: : : placed. :
+| `init_value` | `XlaOp` | Starting value for the |
+: : : reduction. See [Reduce](#reduce) :
+: : : for details. :
+| `computation` | `XlaComputation` | Reduction function of type `T, T |
+: : : -> T`, to apply to all elements :
+: : : in each window :
+| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
+: : : dimension values :
+| `window_strides` | `ArraySlice<int64>` | array of integers for window |
+: : : stride values :
+| `padding` | `Padding` | padding type for window |
+: : : (Padding\:\:kSame or :
+: : : Padding\:\:kValid) :
+
+Below code and figure shows an example of using `ReduceWindow`. Input is a
+matrix of size [4x6] and both window_dimensions and window_stride_dimensions are
+[2x3].
+
+```
+// Create a computation for the reduction (maximum).
+XlaComputation max;
+{
+ XlaBuilder builder(client_, "max");
+ auto y = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y");
+ auto x = builder.Parameter(1, ShapeUtil::MakeShape(F32, {}), "x");
+ builder.Max(y, x);
+ max = builder.Build().ConsumeValueOrDie();
+}
+
+// Create a ReduceWindow computation with the max reduction computation.
+XlaBuilder builder(client_, "reduce_window_2x3");
+auto shape = ShapeUtil::MakeShape(F32, {4, 6});
+auto input = builder.Parameter(0, shape, "input");
+builder.ReduceWindow(
+ input, *max,
+ /*init_val=*/builder.ConstantLiteral(LiteralUtil::MinValue(F32)),
+ /*window_dimensions=*/{2, 3},
+ /*window_stride_dimensions=*/{2, 3},
+ Padding::kValid);
+```
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:35%" src="https://www.tensorflow.org/images/ops_reduce_window.png">
+</div>
+
+Stride of 1 in a dimension specifies that the position of a window in the
+dimension is 1 element away from its adjacent window. In order to specify that
+no windows overlap with each other, window_stride_dimensions should be equal to
+window_dimensions. The figure below illustrates the use of two different stride
+values. Padding is applied to each dimension of the input and the calculations
+are the same as though the input came in with the dimensions it has after
+padding.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:75%" src="https://www.tensorflow.org/images/ops_reduce_window_stride.png">
+</div>
+
+The evaluation order of the reduction function is arbitrary and may be
+non-deterministic. Therefore, the reduction function should not be overly
+sensitive to reassociation. See the discussion about associativity in the
+context of [`Reduce`](#reduce) for more details.
+
+## Reshape
+
+See also
+[`XlaBuilder::Reshape`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h)
+and the [`Collapse`](#collapse) operation.
+
+Reshapes the dimensions of an array into a new configuration.
+
+<b> `Reshape(operand, new_sizes)` </b>
+<b> `Reshape(operand, dimensions, new_sizes)` </b>
+
+Arguments | Type | Semantics
+------------ | -------------- | ---------------------------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `int64` vector | order in which dimensions are collapsed
+`new_sizes` | `int64` vector | vector of sizes of new dimensions
+
+Conceptually, reshape first flattens an array into a one-dimensional vector of
+data values, and then refines this vector into a new shape. The input arguments
+are an arbitrary array of type T, a compile-time-constant vector of dimension
+indices, and a compile-time-constant vector of dimension sizes for the result.
+The values in the `dimension` vector, if given, must be a permutation of all of
+T's dimensions; the default if not given is `{0, ..., rank - 1}`. The order of
+the dimensions in `dimensions` is from slowest-varying dimension (most major) to
+fastest-varying dimension (most minor) in the loop nest which collapses the
+input array into a single dimension. The `new_sizes` vector determines the size
+of the output array. The value at index 0 in `new_sizes` is the size of
+dimension 0, the value at index 1 is the size of dimension 1, and so on. The
+product of the `new_size` dimensions must equal the product of the operand's
+dimension sizes. When refining the collapsed array into the multidimensional
+array defined by `new_sizes`, the dimensions in `new_sizes` are ordered from
+slowest varying (most major) and to fastest varying (most minor).
+
+For example, let v be an array of 24 elements:
+
+```
+let v = f32[4x2x3] {{{10, 11, 12}, {15, 16, 17}},
+ {{20, 21, 22}, {25, 26, 27}},
+ {{30, 31, 32}, {35, 36, 37}},
+ {{40, 41, 42}, {45, 46, 47}}};
+
+In-order collapse:
+let v012_24 = Reshape(v, {0,1,2}, {24});
+then v012_24 == f32[24] {10, 11, 12, 15, 16, 17, 20, 21, 22, 25, 26, 27,
+ 30, 31, 32, 35, 36, 37, 40, 41, 42, 45, 46, 47};
+
+let v012_83 = Reshape(v, {0,1,2}, {8,3});
+then v012_83 == f32[8x3] {{10, 11, 12}, {15, 16, 17},
+ {20, 21, 22}, {25, 26, 27},
+ {30, 31, 32}, {35, 36, 37},
+ {40, 41, 42}, {45, 46, 47}};
+
+Out-of-order collapse:
+let v021_24 = Reshape(v, {1,2,0}, {24});
+then v012_24 == f32[24] {10, 20, 30, 40, 11, 21, 31, 41, 12, 22, 32, 42,
+ 15, 25, 35, 45, 16, 26, 36, 46, 17, 27, 37, 47};
+
+let v021_83 = Reshape(v, {1,2,0}, {8,3});
+then v021_83 == f32[8x3] {{10, 20, 30}, {40, 11, 21},
+ {31, 41, 12}, {22, 32, 42},
+ {15, 25, 35}, {45, 16, 26},
+ {36, 46, 17}, {27, 37, 47}};
+
+
+let v021_262 = Reshape(v, {1,2,0}, {2,6,2});
+then v021_262 == f32[2x6x2] {{{10, 20}, {30, 40},
+ {11, 21}, {31, 41},
+ {12, 22}, {32, 42}},
+ {{15, 25}, {35, 45},
+ {16, 26}, {36, 46},
+ {17, 27}, {37, 47}}};
+```
+
+As a special case, reshape can transform a single-element array to a scalar and
+vice versa. For example,
+
+```
+Reshape(f32[1x1] {{5}}, {0,1}, {}) == 5;
+Reshape(5, {}, {1,1}) == f32[1x1] {{5}};
+```
+
+## Rev (reverse)
+
+See also
+[`XlaBuilder::Rev`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b>`Rev(operand, dimensions)`</b>
+
+Arguments | Type | Semantics
+------------ | ------------------- | ---------------------
+`operand` | `XlaOp` | array of type T
+`dimensions` | `ArraySlice<int64>` | dimensions to reverse
+
+Reverses the order of elements in the `operand` array along the specified
+`dimensions`, generating an output array of the same shape. Each element of the
+operand array at a multidimensional index is stored into the output array at a
+transformed index. The multidimensional index is transformed by reversing the
+index in each dimension to be reversed (i.e., if a dimension of size N is one of
+the reversing dimensions, its index i is transformed into N - 1 - i).
+
+One use for the `Rev` operation is to reverse the convolution weight array along
+the two window dimensions during the gradient computation in neural networks.
+
+## RngNormal
+
+See also
+[`XlaBuilder::RngNormal`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Constructs an output of a given shape with random numbers generated following
+the $$N(\mu, \sigma)$$ normal distribution. The parameters $$\mu$$ and
+$$\sigma$$, and output shape have to have a floating point elemental type. The
+parameters furthermore have to be scalar valued.
+
+<b>`RngNormal(mu, sigma, shape)`</b>
+
+| Arguments | Type | Semantics |
+| --------- | ------- | --------------------------------------------------- |
+| `mu` | `XlaOp` | Scalar of type T specifying mean of generated |
+: : : numbers :
+| `sigma` | `XlaOp` | Scalar of type T specifying standard deviation of |
+: : : generated numbers :
+| `shape` | `Shape` | Output shape of type T |
+
+## RngUniform
+
+See also
+[`XlaBuilder::RngUniform`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Constructs an output of a given shape with random numbers generated following
+the uniform distribution over the interval $$[a,b)$$. The parameters and output
+element type have to be a boolean type, an integral type or a floating point
+types, and the types have to be consistent. The CPU and GPU backends currently
+only support F64, F32, F16, BF16, S64, U64, S32 and U32. Furthermore, the
+parameters need to be scalar valued. If $$b <= a$$ the result is
+implementation-defined.
+
+<b>`RngUniform(a, b, shape)`</b>
+
+| Arguments | Type | Semantics |
+| --------- | ----------------------- | --------------------------------- |
+| `a` | `XlaOp` | Scalar of type T specifying lower |
+: : : limit of interval :
+| `b` | `XlaOp` | Scalar of type T specifying upper |
+: : : limit of interval :
+| `shape` | `Shape` | Output shape of type T |
+
+## Scatter
+
+The XLA scatter operation generates a result which is the value of the input
+tensor `operand`, with several slices (at indices specified by
+`scatter_indices`) updated with the values in `updates` using
+`update_computation`.
+
+See also
+[`XlaBuilder::Scatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `scatter(operand, scatter_indices, updates, update_computation, index_vector_dim, update_window_dims, inserted_window_dims, scatter_dims_to_operand_dims)` </b>
+
+|Arguments | Type | Semantics |
+|------------------|------------------------|----------------------------------|
+|`operand` | `XlaOp` | Tensor to be scattered into. |
+|`scatter_indices` | `XlaOp` | Tensor containing the starting |
+: : : indices of the slices that must :
+: : : be scattered to. :
+|`updates` | `XlaOp` | Tensor containing the values that|
+: : : must be used for scattering. :
+|`update_computation`| `XlaComputation` | Computation to be used for |
+: : : combining the existing values in :
+: : : the input tensor and the updates :
+: : : during scatter. This computation :
+: : : should be of type `T, T -> T`. :
+|`index_vector_dim`| `int64` | The dimension in |
+: : : `scatter_indices` that contains :
+: : : the starting indices. :
+|`update_window_dims`| `ArraySlice<int64>` | The set of dimensions in |
+: : : `updates` shape that are _window :
+: : : dimensions_. :
+|`inserted_window_dims`| `ArraySlice<int64>`| The set of _window dimensions_ |
+: : : that must be inserted into :
+: : : `updates` shape. :
+|`scatter_dims_to_operand_dims`| `ArraySlice<int64>` | A dimensions map from |
+: : : the scatter indices to the :
+: : : operand index space. This array :
+: : : is interpreted as mapping `i` to :
+: : : `scatter_dims_to_operand_dims[i]`:
+: : : . It has to be one-to-one and :
+: : : total. :
+
+If `index_vector_dim` is equal to `scatter_indices.rank` we implicitly consider
+`scatter_indices` to have a trailing `1` dimension.
+
+We define `update_scatter_dims` of type `ArraySlice<int64>` as the set of
+dimensions in `updates` shape that are not in `update_window_dims`, in ascending
+order.
+
+The arguments of scatter should follow these constraints:
+
+ - `updates` tensor must be of rank `update_window_dims.size +
+ scatter_indices.rank - 1`.
+
+ - Bounds of dimension `i` in `updates` must conform to the following:
+ - If `i` is present in `update_window_dims` (i.e. equal to
+ `update_window_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must not exceed the corresponding bound of `operand`
+ after accounting for the `inserted_window_dims` (i.e.
+ `adjusted_window_bounds`[`k`], where `adjusted_window_bounds` contains
+ the bounds of `operand` with the bounds at indices
+ `inserted_window_dims` removed).
+ - If `i` is present in `update_scatter_dims` (i.e. equal to
+ `update_scatter_dims`[`k`] for some `k`), then the bound of dimension
+ `i` in `updates` must be equal to the corresponding bound of
+ `scatter_indices`, skipping `index_vector_dim` (i.e.
+ `scatter_indices.shape.dims`[`k`], if `k` < `index_vector_dim` and
+ `scatter_indices.shape.dims`[`k+1`] otherwise).
+
+ - `update_window_dims` must be in ascending order, not have any repeating
+ dimension numbers, and be in the range `[0, updates.rank)`.
+
+ - `inserted_window_dims` must be in ascending order, not have any
+ repeating dimension numbers, and be in the range `[0, operand.rank)`.
+
+ - `scatter_dims_to_operand_dims.size` must be equal to
+ `scatter_indices`[`index_vector_dim`], and its values must be in the range
+ `[0, operand.rank)`.
+
+For a given index `U` in the `updates` tensor, the corresponding index `I` in
+the `operand` tensor into which this update has to be applied is computed as
+follows:
+
+ 1. Let `G` = { `U`[`k`] for `k` in `update_scatter_dims` }. Use `G` to look up
+ an index vector `S` in the `scatter_indices` tensor such that `S`[`i`] =
+ `scatter_indices`[Combine(`G`, `i`)] where Combine(A, b) inserts b at
+ positions `index_vector_dim` into A.
+ 2. Create an index `S`<sub>`in`</sub> into `operand` using `S` by scattering
+ `S` using the `scatter_dims_to_operand_dims` map. More formally:
+ 1. `S`<sub>`in`</sub>[`scatter_dims_to_operand_dims`[`k`]] = `S`[`k`] if
+ `k` < `scatter_dims_to_operand_dims.size`.
+ 2. `S`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 3. Create an index `W`<sub>`in`</sub> into `operand` by scattering the indices
+ at `update_window_dims` in `U` according to `inserted_window_dims`.
+ More formally:
+ 1. `W`<sub>`in`</sub>[`window_dims_to_operand_dims`(`k`)] = `U`[`k`] if
+ `k` < `update_window_dims.size`, where `window_dims_to_operand_dims`
+ is the monotonic function with domain [`0`, `update_window_dims.size`)
+ and range [`0`, `operand.rank`) \\ `inserted_window_dims`. (For
+ example, if `update_window_dims.size` is `4`, `operand.rank` is `6`,
+ and `inserted_window_dims` is {`0`, `2`} then
+ `window_dims_to_operand_dims` is {`0`→`1`, `1`→`3`, `2`→`4`,
+ `3`→`5`}).
+ 2. `W`<sub>`in`</sub>[`_`] = `0` otherwise.
+ 4. `I` is `W`<sub>`in`</sub> + `S`<sub>`in`</sub> where + is element-wise
+ addition.
+
+In summary, the scatter operation can be defined as follows.
+
+ - Initialize `output` with `operand`, i.e. for all indices `O` in the
+ `operand` tensor:\
+ `output`[`O`] = `operand`[`O`]
+ - For every index `U` in the `updates` tensor and the corresponding index `O`
+ in the `operand` tensor:\
+ `output`[`O`] = `update_computation`(`output`[`O`], `updates`[`U`])
+
+The order in which updates are applied is non-deterministic. So, when multiple
+indices in `updates` refer to the same index in `operand`, the corresponding
+value in `output` will be non-deterministic.
+
+Note that the first parameter that is passed into the `update_computation` will
+always be the current value from the `output` tensor and the second parameter
+will always be the value from the `updates` tensor. This is important
+specifically for cases when the `update_computation` is _not commutative_.
+
+Informally, the scatter op can be viewed as an _inverse_ of the gather op, i.e.
+the scatter op updates the elements in the input that are extracted by the
+corresponding gather op.
+
+For a detailed informal description and examples, refer to the
+"Informal Description" section under `Gather`.
+
+## Select
+
+See also
+[`XlaBuilder::Select`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Constructs an output array from elements of two input arrays, based on the
+values of a predicate array.
+
+<b> `Select(pred, on_true, on_false)` </b>
+
+Arguments | Type | Semantics
+---------- | ------- | ------------------
+`pred` | `XlaOp` | array of type PRED
+`on_true` | `XlaOp` | array of type T
+`on_false` | `XlaOp` | array of type T
+
+The arrays `on_true` and `on_false` must have the same shape. This is also the
+shape of the output array. The array `pred` must have the same dimensionality as
+`on_true` and `on_false`, with the `PRED` element type.
+
+For each element `P` of `pred`, the corresponding element of the output array is
+taken from `on_true` if the value of `P` is `true`, and from `on_false` if the
+value of `P` is `false`. As a restricted form of [broadcasting]
+(broadcasting.md), `pred` can be a scalar of type `PRED`. In this case, the
+output array is taken wholly from `on_true` if `pred` is `true`, and from
+`on_false` if `pred` is `false`.
+
+Example with non-scalar `pred`:
+
+```
+let pred: PRED[4] = {true, false, false, true};
+let v1: s32[4] = {1, 2, 3, 4};
+let v2: s32[4] = {100, 200, 300, 400};
+==>
+Select(pred, v1, v2) = s32[4]{1, 200, 300, 4};
+```
+
+Example with scalar `pred`:
+
+```
+let pred: PRED = true;
+let v1: s32[4] = {1, 2, 3, 4};
+let v2: s32[4] = {100, 200, 300, 400};
+==>
+Select(pred, v1, v2) = s32[4]{1, 2, 3, 4};
+```
+
+Selections between tuples are supported. Tuples are considered to be scalar
+types for this purpose. If `on_true` and `on_false` are tuples (which must have
+the same shape!) then `pred` has to be a scalar of type `PRED`.
+
+## SelectAndScatter
+
+See also
+[`XlaBuilder::SelectAndScatter`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+This operation can be considered as a composite operation that first computes
+`ReduceWindow` on the `operand` array to select an element from each window, and
+then scatters the `source` array to the indices of the selected elements to
+construct an output array with the same shape as the operand array. The binary
+`select` function is used to select an element from each window by applying it
+across each window, and it is called with the property that the first
+parameter's index vector is lexicographically less than the second parameter's
+index vector. The `select` function returns `true` if the first parameter is
+selected and returns `false` if the second parameter is selected, and the
+function must hold transitivity (i.e., if `select(a, b)` and `select(b, c)` are
+`true`, then `select(a, c)` is also `true`) so that the selected element does
+not depend on the order of the elements traversed for a given window.
+
+The function `scatter` is applied at each selected index in the output array. It
+takes two scalar parameters:
+
+1. Current value at the selected index in the output array
+2. The scatter value from `source` that applies to the selected index
+
+It combines the two parameters and returns a scalar value that's used to update
+the value at the selected index in the output array. Initially, all indices of
+the output array are set to `init_value`.
+
+The output array has the same shape as the `operand` array and the `source`
+array must have the same shape as the result of applying a `ReduceWindow`
+operation on the `operand` array. `SelectAndScatter` can be used to
+backpropagate the gradient values for a pooling layer in a neural network.
+
+<b>`SelectAndScatter(operand, select, window_dimensions, window_strides,
+padding, source, init_value, scatter)`</b>
+
+| Arguments | Type | Semantics |
+| ------------------- | ------------------- | -------------------------------- |
+| `operand` | `XlaOp` | array of type T over which the |
+: : : windows slide :
+| `select` | `XlaComputation` | binary computation of type `T, T |
+: : : -> PRED`, to apply to all :
+: : : elements in each window; returns :
+: : : `true` if the first parameter is :
+: : : selected and returns `false` if :
+: : : the second parameter is selected :
+| `window_dimensions` | `ArraySlice<int64>` | array of integers for window |
+: : : dimension values :
+| `window_strides` | `ArraySlice<int64>` | array of integers for window |
+: : : stride values :
+| `padding` | `Padding` | padding type for window |
+: : : (Padding\:\:kSame or :
+: : : Padding\:\:kValid) :
+| `source` | `XlaOp` | array of type T with the values |
+: : : to scatter :
+| `init_value` | `XlaOp` | scalar value of type T for the |
+: : : initial value of the output :
+: : : array :
+| `scatter` | `XlaComputation` | binary computation of type `T, T |
+: : : -> T`, to apply each scatter :
+: : : source element with its :
+: : : destination element :
+
+The figure below shows examples of using `SelectAndScatter`, with the `select`
+function computing the maximal value among its parameters. Note that when the
+windows overlap, as in the figure (2) below, an index of the `operand` array may
+be selected multiple times by different windows. In the figure, the element of
+value 9 is selected by both of the top windows (blue and red) and the binary
+addition `scatter` function produces the output element of value 8 (2 + 6).
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%"
+ src="https://www.tensorflow.org/images/ops_scatter_to_selected_window_element.png">
+</div>
+
+The evaluation order of the `scatter` function is arbitrary and may be
+non-deterministic. Therefore, the `scatter` function should not be overly
+sensitive to reassociation. See the discussion about associativity in the
+context of [`Reduce`](#reduce) for more details.
+
+## Send
+
+See also
+[`XlaBuilder::Send`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `Send(operand, channel_handle)` </b>
+
+Arguments | Type | Semantics
+---------------- | --------------- | -----------------------------------------
+`operand` | `XlaOp` | data to send (array of type T)
+`channel_handle` | `ChannelHandle` | unique identifier for each send/recv pair
+
+Sends the given operand data to a `Recv` instruction in another computation
+that shares the same channel handle. Does not return any data.
+
+Similar to the `Recv` operation, the client API of `Send` operation represents
+synchronous communication, and is internally decomposed into 2 HLO instructions
+(`Send` and `SendDone`) to enable asynchronous data transfers. See also
+[`HloInstruction::CreateSend` and `HloInstruction::CreateSendDone`](https://www.tensorflow.org/code/tensorflow/compiler/xla/service/hlo_instruction.h).
+
+<b>`Send(HloInstruction operand, int64 channel_id)`</b>
+
+Initiates an asynchronous transfer of the operand to the resources allocated by
+the `Recv` instruction with the same channel id. Returns a context, which is
+used by a following `SendDone` instruction to wait for the completion of the
+data transfer. The context is a tuple of {operand (shape), request identifier
+(U32)} and it can only be used by a `SendDone` instruction.
+
+<b> `SendDone(HloInstruction context)` </b>
+
+Given a context created by a `Send` instruction, waits for the data transfer to
+complete. The instruction does not return any data.
+
+<b> Scheduling of channel instructions </b>
+
+The execution order of the 4 instructions for each channel (`Recv`, `RecvDone`,
+`Send`, `SendDone`) is as below.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:70%" src="../../images/send_recv_order.png">
+</div>
+
+* `Recv` happens before `Send`
+* `Send` happens before `RecvDone`
+* `Recv` happens before `RecvDone`
+* `Send` happens before `SendDone`
+
+When the backend compilers generate a linear schedule for each computation that
+communicates via channel instructions, there must not be cycles across the
+computations. For example, below schedules lead to deadlocks.
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="../../images/send_recv_schedule.png">
+</div>
+
+## Slice
+
+See also
+[`XlaBuilder::Slice`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+Slicing extracts a sub-array from the input array. The sub-array is of the same
+rank as the input and contains the values inside a bounding box within the input
+array where the dimensions and indices of the bounding box are given as
+arguments to the slice operation.
+
+<b> `Slice(operand, start_indices, limit_indices)` </b>
+
+| Arguments | Type | Semantics |
+| --------------- | ------------------- | ------------------------------------ |
+| `operand` | `XlaOp` | N dimensional array of type T |
+| `start_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : starting indices of the slice for :
+: : : each dimension. Values must be :
+: : : greater than or equal to zero. :
+| `limit_indices` | `ArraySlice<int64>` | List of N integers containing the |
+: : : ending indices (exclusive) for the :
+: : : slice for each dimension. Each value :
+: : : must be greater than or equal to the :
+: : : respective `start_indices` value for :
+: : : the dimension and less than or equal :
+: : : to the size of the dimension. :
+
+1-dimensional example:
+
+```
+let a = {0.0, 1.0, 2.0, 3.0, 4.0}
+Slice(a, {2}, {4}) produces:
+ {2.0, 3.0}
+```
+
+2-dimensional example:
+
+```
+let b =
+ { {0.0, 1.0, 2.0},
+ {3.0, 4.0, 5.0},
+ {6.0, 7.0, 8.0},
+ {9.0, 10.0, 11.0} }
+
+Slice(b, {2, 1}, {4, 3}) produces:
+ { { 7.0, 8.0},
+ {10.0, 11.0} }
+```
+
+## Sort
+
+See also
+[`XlaBuilder::Sort`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+There are two versions of the Sort instruction: a single-operand and a
+two-operand version.
+
+<b>`Sort(operand)`</b>
+
+Arguments | Type | Semantics
+----------- | ------- | --------------------
+`operand` | `XlaOp` | The operand to sort.
+`dimension` | `int64` | The dimension along which to sort.
+
+Sorts the elements in the operand in ascending order along the provided
+dimension. For example, for a rank-2 (matrix) operand, a `dimension` value of 0
+will sort each column independently, and a `dimension` value of 1 will sort each
+row independently. If the operand's elements have floating point type, and the
+operand contains NaN elements, the order of elements in the output is
+implementation-defined.
+
+<b>`Sort(key, value)`</b>
+
+Sorts both the key and the value operands. The keys are sorted as in the
+single-operand version. The values are sorted according to the order of their
+corresponding keys. For example, if the inputs are `keys = [3, 1]` and
+`values = [42, 50]`, then the output of the sort is the tuple
+`{[1, 3], [50, 42]}`.
+
+The sort is not guaranteed to be stable, that is, if the keys array contains
+duplicates, the order of their corresponding values may not be preserved.
+
+Arguments | Type | Semantics
+----------- | ------- | -------------------
+`keys` | `XlaOp` | The sort keys.
+`values` | `XlaOp` | The values to sort.
+`dimension` | `int64` | The dimension along which to sort.
+
+The `keys` and `values` must have the same dimensions, but may have different
+element types.
+
+## Transpose
+
+See also the `tf.reshape` operation.
+
+<b>`Transpose(operand)`</b>
+
+Arguments | Type | Semantics
+------------- | ------------------- | ------------------------------
+`operand` | `XlaOp` | The operand to transpose.
+`permutation` | `ArraySlice<int64>` | How to permute the dimensions.
+
+
+Permutes the operand dimensions with the given permutation, so
+`∀ i . 0 ≤ i < rank ⇒ input_dimensions[permutation[i]] = output_dimensions[i]`.
+
+This is the same as Reshape(operand, permutation,
+ Permute(permutation, operand.shape.dimensions)).
+
+## Tuple
+
+See also
+[`XlaBuilder::Tuple`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+A tuple containing a variable number of data handles, each of which has its own
+shape.
+
+This is analogous to `std::tuple` in C++. Conceptually:
+
+```
+let v: f32[10] = f32[10]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
+let s: s32 = 5;
+let t: (f32[10], s32) = tuple(v, s);
+```
+
+Tuples can be deconstructed (accessed) via the [`GetTupleElement`]
+(#gettupleelement) operation.
+
+## While
+
+See also
+[`XlaBuilder::While`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).
+
+<b> `While(condition, body, init)` </b>
+
+| Arguments | Type | Semantics |
+| ----------- | ---------------- | ---------------------------------------- |
+| `condition` | `XlaComputation` | XlaComputation of type `T -> PRED` which |
+: : : defines the termination condition of the :
+: : : loop. :
+| `body` | `XlaComputation` | XlaComputation of type `T -> T` which |
+: : : defines the body of the loop. :
+| `init` | `T` | Initial value for the parameter of |
+: : : `condition` and `body`. :
+
+Sequentially executes the `body` until the `condition` fails. This is similar to
+a typical while loop in many other languages except for the differences and
+restrictions listed below.
+
+* A `While` node returns a value of type `T`, which is the result from the
+ last execution of the `body`.
+* The shape of the type `T` is statically determined and must be the same
+ across all iterations.
+
+The T parameters of the computations are initialized with the `init` value in
+the first iteration and are automatically updated to the new result from `body`
+in each subsequent iteration.
+
+One main use case of the `While` node is to implement the repeated execution of
+training in neural networks. Simplified pseudocode is shown below with a graph
+that represents the computation. The code can be found in
+[`while_test.cc`](https://www.tensorflow.org/code/tensorflow/compiler/xla/tests/while_test.cc).
+The type `T` in this example is a `Tuple` consisting of an `int32` for the
+iteration count and a `vector[10]` for the accumulator. For 1000 iterations, the
+loop keeps adding a constant vector to the accumulator.
+
+```
+// Pseudocode for the computation.
+init = {0, zero_vector[10]} // Tuple of int32 and float[10].
+result = init;
+while (result(0) < 1000) {
+ iteration = result(0) + 1;
+ new_vector = result(1) + constant_vector[10];
+ result = {iteration, new_vector};
+}
+```
+
+<div style="width:95%; margin:auto; margin-bottom:10px; margin-top:20px;">
+ <img style="width:100%" src="https://www.tensorflow.org/images/ops_while.png">
+</div>
diff --git a/tensorflow/examples/get_started/regression/test.py b/tensorflow/examples/get_started/regression/test.py
index 0b1477ad96..bb4db6700b 100644
--- a/tensorflow/examples/get_started/regression/test.py
+++ b/tensorflow/examples/get_started/regression/test.py
@@ -29,7 +29,7 @@ import tensorflow.examples.get_started.regression.imports85 as imports85
sys.modules["imports85"] = imports85
# pylint: disable=g-bad-import-order,g-import-not-at-top
-import tensorflow.contrib.data as data
+import tensorflow.data as data
import tensorflow.examples.get_started.regression.dnn_regression as dnn_regression
import tensorflow.examples.get_started.regression.linear_regression as linear_regression
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index 2f297d5161..b4d4db3e4d 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -3742,27 +3742,6 @@ func BoostedTreesMakeStatsSummary(scope *Scope, node_ids tf.Output, gradients tf
return op.Output(0)
}
-// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
-//
-// Arguments:
-// tree_ensemble_handle: Handle to the tree ensemble.
-//
-// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
-// layer.
-func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "BoostedTreesGetEnsembleStates",
- Input: []tf.Input{
- tree_ensemble_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
-}
-
// Creates a tree ensemble model and returns a handle to it.
//
// Arguments:
@@ -4059,6 +4038,364 @@ func FixedUnigramCandidateSampler(scope *Scope, true_classes tf.Output, num_true
return op.Output(0), op.Output(1), op.Output(2)
}
+// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
+type LogUniformCandidateSamplerAttr func(optionalAttr)
+
+// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a log-uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "LogUniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
+type UniformCandidateSamplerAttr func(optionalAttr)
+
+// UniformCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: An second seed to avoid seed collision.
+// If not specified, defaults to 0
+func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a uniform distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to randomly sample.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+// range_max: The sampler will sample integers from the interval [0, range_max).
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate.A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
+// candidate representing the number of times the candidate is expected
+// to occur in a batch of sampled candidates. If unique=true, then this is a
+// probability.
+func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "UniformCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
+type GenerateVocabRemappingAttr func(optionalAttr)
+
+// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
+//
+// value: Number of entries in the old vocab file to consider. If -1,
+// use the entire old vocabulary.
+// If not specified, defaults to -1
+//
+// REQUIRES: value >= -1
+func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
+ return func(m optionalAttr) {
+ m["old_vocab_size"] = value
+ }
+}
+
+// Given a path to new and old vocabulary files, returns a remapping Tensor of
+//
+// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
+// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
+// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
+// in the new vocabulary is not in the old vocabulary. The old vocabulary is
+// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
+// default value of -1.
+//
+// `num_vocab_offset` enables
+// use in the partitioned variable case, and should generally be set through
+// examining partitioning info. The format of the files should be a text file,
+// with each line containing a single entity within the vocabulary.
+//
+// For example, with `new_vocab_file` a text file containing each of the following
+// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
+// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
+// `[0, -1, 2]`.
+//
+// The op also returns a count of how many entries in the new vocabulary
+// were present in the old vocabulary, which is used to calculate the number of
+// values to initialize in a weight matrix remapping
+//
+// This functionality can be used to remap both row vocabularies (typically,
+// features) and column vocabularies (typically, classes) from TensorFlow
+// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
+// corresponding to div-partitioned variables. Moreover, the underlying remapping
+// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
+// use the corresponding index_table_from_file() as the FeatureColumn framework
+// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
+//
+// Arguments:
+// new_vocab_file: Path to the new vocab file.
+// old_vocab_file: Path to the old vocab file.
+// new_vocab_offset: How many entries into the new vocab file to start reading.
+// num_new_vocab: Number of entries in the new vocab file to remap.
+//
+// Returns A Tensor of length num_new_vocab where the element at index i
+// is equal to the old ID that maps to the new ID i. This element is -1 for any
+// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
+func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "GenerateVocabRemapping",
+ Input: []tf.Input{
+ new_vocab_file, old_vocab_file,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Broadcasts a tensor value to one or more other devices.
+func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
+ opspec := tf.OpSpec{
+ Type: "CollectiveBcastSend",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Mutually reduces multiple tensors of identical type and shape.
+func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
+ opspec := tf.OpSpec{
+ Type: "CollectiveReduce",
+ Input: []tf.Input{
+ input,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// AbortAttr is an optional argument to Abort.
+type AbortAttr func(optionalAttr)
+
+// AbortErrorMsg sets the optional error_msg attribute to value.
+//
+// value: A string which is the message associated with the exception.
+// If not specified, defaults to ""
+func AbortErrorMsg(value string) AbortAttr {
+ return func(m optionalAttr) {
+ m["error_msg"] = value
+ }
+}
+
+// AbortExitWithoutError sets the optional exit_without_error attribute to value.
+// If not specified, defaults to false
+func AbortExitWithoutError(value bool) AbortAttr {
+ return func(m optionalAttr) {
+ m["exit_without_error"] = value
+ }
+}
+
+// Raise a exception to abort the process when called.
+//
+// If exit_without_error is true, the process will exit normally,
+// otherwise it will exit with a SIGABORT signal.
+//
+// Returns nothing but an exception.
+//
+// Returns the created operation.
+func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Abort",
+
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Forwards the input to the output.
+//
+// This operator represents the loop termination condition used by the
+// "pivot" switches of a loop.
+//
+// Arguments:
+// input: A boolean scalar, representing the branch predicate of the Switch op.
+//
+// Returns The same tensor as `input`.
+func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "LoopCond",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a tensor of zeros with the same shape and type as x.
+//
+// Arguments:
+// x: a tensor of type T.
+//
+// Returns a tensor of the same shape and type as x but filled with zeros.
+func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ZerosLike",
+ Input: []tf.Input{
+ x,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns a copy of the input tensor.
+func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Snapshot",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResourceStridedSliceAssignAttr is an optional argument to ResourceStridedSliceAssign.
type ResourceStridedSliceAssignAttr func(optionalAttr)
@@ -10182,23 +10519,6 @@ func Assert(scope *Scope, condition tf.Output, data []tf.Output, optional ...Ass
return scope.AddOperation(opspec)
}
-// Broadcasts a tensor value to one or more other devices.
-func CollectiveBcastSend(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, shape tf.Shape) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "shape": shape}
- opspec := tf.OpSpec{
- Type: "CollectiveBcastSend",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Split a `SparseTensor` into `num_split` tensors along one dimension.
//
// If the `shape[split_dim]` is not an integer multiple of `num_split`. Slices
@@ -10776,23 +11096,6 @@ func ResourceScatterNdAdd(scope *Scope, ref tf.Output, indices tf.Output, update
return scope.AddOperation(opspec)
}
-// Mutually reduces multiple tensors of identical type and shape.
-func CollectiveReduce(scope *Scope, input tf.Output, group_size int64, group_key int64, instance_key int64, merge_op string, final_op string, subdiv_offsets []int64) (data tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"group_size": group_size, "group_key": group_key, "instance_key": instance_key, "merge_op": merge_op, "final_op": final_op, "subdiv_offsets": subdiv_offsets}
- opspec := tf.OpSpec{
- Type: "CollectiveReduce",
- Input: []tf.Input{
- input,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Updates the tree ensemble by either adding a layer to the last tree being grown
//
// or by starting a new tree.
@@ -11671,6 +11974,49 @@ func ResourceApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.
return scope.AddOperation(opspec)
}
+// Exits the current frame to its parent frame.
+//
+// Exit makes its input `data` available to the parent frame.
+//
+// Arguments:
+// data: The tensor to be made available to the parent frame.
+//
+// Returns The same tensor as `data`.
+func Exit(scope *Scope, data tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Exit",
+ Input: []tf.Input{
+ data,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Produce a string tensor that encodes the state of a Reader.
+//
+// Not all Readers support being serialized, so this can produce an
+// Unimplemented error.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "ReaderSerializeStateV2",
+ Input: []tf.Input{
+ reader_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// DepthwiseConv2dNativeBackpropFilterAttr is an optional argument to DepthwiseConv2dNativeBackpropFilter.
type DepthwiseConv2dNativeBackpropFilterAttr func(optionalAttr)
@@ -11804,68 +12150,6 @@ func StringJoin(scope *Scope, inputs []tf.Output, optional ...StringJoinAttr) (o
return op.Output(0)
}
-// StringSplitV2Attr is an optional argument to StringSplitV2.
-type StringSplitV2Attr func(optionalAttr)
-
-// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
-//
-// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
-// If not specified, defaults to -1
-func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
- return func(m optionalAttr) {
- m["maxsplit"] = value
- }
-}
-
-// Split elements of `source` based on `sep` into a `SparseTensor`.
-//
-// Let N be the size of source (typically N will be the batch size). Split each
-// element of `source` based on `sep` and return a `SparseTensor`
-// containing the split tokens. Empty tokens are ignored.
-//
-// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
-// then the output will be
-// ```
-// st.indices = [0, 0;
-// 0, 1;
-// 1, 0;
-// 1, 1;
-// 1, 2]
-// st.shape = [2, 3]
-// st.values = ['hello', 'world', 'a', 'b', 'c']
-// ```
-//
-// If `sep` is given, consecutive delimiters are not grouped together and are
-// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
-// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
-// string, consecutive whitespace are regarded as a single separator, and the
-// result will contain no empty strings at the startor end if the string has
-// leading or trailing whitespace.
-//
-// Note that the above mentioned behavior matches python's str.split.
-//
-// Arguments:
-// input: `1-D` string `Tensor`, the strings to split.
-// sep: `0-D` string `Tensor`, the delimiter character.
-func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringSplitV2",
- Input: []tf.Input{
- input, sep,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// MaxPoolAttr is an optional argument to MaxPool.
type MaxPoolAttr func(optionalAttr)
@@ -12435,21 +12719,6 @@ func ResizeBilinear(scope *Scope, images tf.Output, size tf.Output, optional ...
return op.Output(0)
}
-// Computes softsign: `features / (abs(features) + 1)`.
-func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Softsign",
- Input: []tf.Input{
- features,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Creates a TensorList which, when stacked, has the value of `tensor`.
//
// Each tensor in the result list corresponds to one row of the input tensor.
@@ -12470,81 +12739,6 @@ func TensorListFromTensor(scope *Scope, tensor tf.Output, element_shape tf.Outpu
return op.Output(0)
}
-// GenerateVocabRemappingAttr is an optional argument to GenerateVocabRemapping.
-type GenerateVocabRemappingAttr func(optionalAttr)
-
-// GenerateVocabRemappingOldVocabSize sets the optional old_vocab_size attribute to value.
-//
-// value: Number of entries in the old vocab file to consider. If -1,
-// use the entire old vocabulary.
-// If not specified, defaults to -1
-//
-// REQUIRES: value >= -1
-func GenerateVocabRemappingOldVocabSize(value int64) GenerateVocabRemappingAttr {
- return func(m optionalAttr) {
- m["old_vocab_size"] = value
- }
-}
-
-// Given a path to new and old vocabulary files, returns a remapping Tensor of
-//
-// length `num_new_vocab`, where `remapping[i]` contains the row number in the old
-// vocabulary that corresponds to row `i` in the new vocabulary (starting at line
-// `new_vocab_offset` and up to `num_new_vocab` entities), or `-1` if entry `i`
-// in the new vocabulary is not in the old vocabulary. The old vocabulary is
-// constrained to the first `old_vocab_size` entries if `old_vocab_size` is not the
-// default value of -1.
-//
-// `num_vocab_offset` enables
-// use in the partitioned variable case, and should generally be set through
-// examining partitioning info. The format of the files should be a text file,
-// with each line containing a single entity within the vocabulary.
-//
-// For example, with `new_vocab_file` a text file containing each of the following
-// elements on a single line: `[f0, f1, f2, f3]`, old_vocab_file = [f1, f0, f3],
-// `num_new_vocab = 3, new_vocab_offset = 1`, the returned remapping would be
-// `[0, -1, 2]`.
-//
-// The op also returns a count of how many entries in the new vocabulary
-// were present in the old vocabulary, which is used to calculate the number of
-// values to initialize in a weight matrix remapping
-//
-// This functionality can be used to remap both row vocabularies (typically,
-// features) and column vocabularies (typically, classes) from TensorFlow
-// checkpoints. Note that the partitioning logic relies on contiguous vocabularies
-// corresponding to div-partitioned variables. Moreover, the underlying remapping
-// uses an IndexTable (as opposed to an inexact CuckooTable), so client code should
-// use the corresponding index_table_from_file() as the FeatureColumn framework
-// does (as opposed to tf.feature_to_id(), which uses a CuckooTable).
-//
-// Arguments:
-// new_vocab_file: Path to the new vocab file.
-// old_vocab_file: Path to the old vocab file.
-// new_vocab_offset: How many entries into the new vocab file to start reading.
-// num_new_vocab: Number of entries in the new vocab file to remap.
-//
-// Returns A Tensor of length num_new_vocab where the element at index i
-// is equal to the old ID that maps to the new ID i. This element is -1 for any
-// new ID that is not found in the old vocabulary.Number of new vocab entries found in old vocab.
-func GenerateVocabRemapping(scope *Scope, new_vocab_file tf.Output, old_vocab_file tf.Output, new_vocab_offset int64, num_new_vocab int64, optional ...GenerateVocabRemappingAttr) (remapping tf.Output, num_present tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"new_vocab_offset": new_vocab_offset, "num_new_vocab": num_new_vocab}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "GenerateVocabRemapping",
- Input: []tf.Input{
- new_vocab_file, old_vocab_file,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
// Assigns sparse updates to the variable referenced by `resource`.
//
// This operation computes
@@ -13547,6 +13741,27 @@ func DataFormatDimMap(scope *Scope, x tf.Output, optional ...DataFormatDimMapAtt
return op.Output(0)
}
+// Retrieves the tree ensemble resource stamp token, number of trees and growing statistics.
+//
+// Arguments:
+// tree_ensemble_handle: Handle to the tree ensemble.
+//
+// Returns Stamp token of the tree ensemble resource.The number of trees in the tree ensemble resource.The number of trees that were finished successfully.The number of layers we attempted to build (but not necessarily succeeded).Rank size 2 tensor that contains start and end ids of the nodes in the latest
+// layer.
+func BoostedTreesGetEnsembleStates(scope *Scope, tree_ensemble_handle tf.Output) (stamp_token tf.Output, num_trees tf.Output, num_finalized_trees tf.Output, num_attempted_layers tf.Output, last_layer_nodes_range tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "BoostedTreesGetEnsembleStates",
+ Input: []tf.Input{
+ tree_ensemble_handle,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4)
+}
+
// ResourceApplyPowerSignAttr is an optional argument to ResourceApplyPowerSign.
type ResourceApplyPowerSignAttr func(optionalAttr)
@@ -16327,79 +16542,6 @@ func RandomPoisson(scope *Scope, shape tf.Output, rate tf.Output, optional ...Ra
return op.Output(0)
}
-// LogUniformCandidateSamplerAttr is an optional argument to LogUniformCandidateSampler.
-type LogUniformCandidateSamplerAttr func(optionalAttr)
-
-// LogUniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// LogUniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func LogUniformCandidateSamplerSeed2(value int64) LogUniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a log-uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func LogUniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...LogUniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "LogUniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// Returns the max of x and y (i.e. x > y ? x : y) element-wise.
//
// *NOTE*: `Maximum` supports broadcasting. More about broadcasting
@@ -19444,31 +19586,6 @@ func SparseDenseCwiseAdd(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// Read an element from the TensorArray into output `value`.
-//
-// Arguments:
-// handle: The handle to a TensorArray.
-//
-// flow_in: A float scalar that enforces proper chaining of operations.
-// dtype: The type of the elem that is returned.
-//
-// Returns The tensor that is read from the TensorArray.
-func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- opspec := tf.OpSpec{
- Type: "TensorArrayReadV3",
- Input: []tf.Input{
- handle, index, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// QuantizeV2Attr is an optional argument to QuantizeV2.
type QuantizeV2Attr func(optionalAttr)
@@ -20866,6 +20983,201 @@ func Sum(scope *Scope, input tf.Output, axis tf.Output, optional ...SumAttr) (ou
return op.Output(0)
}
+// EnterAttr is an optional argument to Enter.
+type EnterAttr func(optionalAttr)
+
+// EnterIsConstant sets the optional is_constant attribute to value.
+//
+// value: If true, the output is constant within the child frame.
+// If not specified, defaults to false
+func EnterIsConstant(value bool) EnterAttr {
+ return func(m optionalAttr) {
+ m["is_constant"] = value
+ }
+}
+
+// EnterParallelIterations sets the optional parallel_iterations attribute to value.
+//
+// value: The number of iterations allowed to run in parallel.
+// If not specified, defaults to 10
+func EnterParallelIterations(value int64) EnterAttr {
+ return func(m optionalAttr) {
+ m["parallel_iterations"] = value
+ }
+}
+
+// Creates or finds a child frame, and makes `data` available to the child frame.
+//
+// This op is used together with `Exit` to create loops in the graph.
+// The unique `frame_name` is used by the `Executor` to identify frames. If
+// `is_constant` is true, `output` is a constant in the child frame; otherwise
+// it may be changed in the child frame. At most `parallel_iterations` iterations
+// are run in parallel in the child frame.
+//
+// Arguments:
+// data: The tensor to be made available to the child frame.
+// frame_name: The name of the child frame.
+//
+// Returns The same tensor as `data`.
+func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"frame_name": frame_name}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Enter",
+ Input: []tf.Input{
+ data,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Add all input tensors element wise.
+//
+// Arguments:
+// inputs: Must all be the same size and shape.
+func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "AddN",
+ Input: []tf.Input{
+ tf.OutputList(inputs),
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// TryRpcAttr is an optional argument to TryRpc.
+type TryRpcAttr func(optionalAttr)
+
+// TryRpcProtocol sets the optional protocol attribute to value.
+//
+// value: RPC protocol to use. Empty string means use the default protocol.
+// Options include 'grpc'.
+// If not specified, defaults to ""
+func TryRpcProtocol(value string) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["protocol"] = value
+ }
+}
+
+// TryRpcFailFast sets the optional fail_fast attribute to value.
+//
+// value: `boolean`. If `true` (default), then failures to connect
+// (i.e., the server does not immediately respond) cause an RPC failure.
+// If not specified, defaults to true
+func TryRpcFailFast(value bool) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["fail_fast"] = value
+ }
+}
+
+// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
+//
+// value: `int`. If `0` (default), then the kernel will run the RPC
+// request and only time out if the RPC deadline passes or the session times out.
+// If this value is greater than `0`, then the op will raise an exception if
+// the RPC takes longer than `timeout_in_ms`.
+// If not specified, defaults to 0
+func TryRpcTimeoutInMs(value int64) TryRpcAttr {
+ return func(m optionalAttr) {
+ m["timeout_in_ms"] = value
+ }
+}
+
+// Perform batches of RPC requests.
+//
+// This op asynchronously performs either a single RPC request, or a batch
+// of requests. RPC requests are defined by three main parameters:
+//
+// - `address` (the host+port or BNS address of the request)
+// - `method` (the method name for the request)
+// - `request` (the serialized proto string, or vector of strings,
+// of the RPC request argument).
+//
+// For example, if you have an RPC service running on port localhost:2345,
+// and its interface is configured with the following proto declaration:
+//
+// ```
+// service MyService {
+// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
+// }
+// };
+// ```
+//
+// then call this op with arguments:
+//
+// ```
+// address = "localhost:2345"
+// method = "MyService/MyMethod"
+// ```
+//
+// The `request` tensor is a string tensor representing serialized `MyRequestProto`
+// strings; and the output string tensor `response` will have the same shape
+// and contain (upon successful completion) corresponding serialized
+// `MyResponseProto` strings.
+//
+// For example, to send a single, empty, `MyRequestProto`, call
+// this op with `request = ""`. To send 5 **parallel** empty requests,
+// call this op with `request = ["", "", "", "", ""]`.
+//
+// More generally, one can create a batch of `MyRequestProto` serialized protos
+// from regular batched tensors using the `encode_proto` op, and convert
+// the response `MyResponseProto` serialized protos to batched tensors
+// using the `decode_proto` op.
+//
+// **NOTE** Working with serialized proto strings is faster than instantiating
+// actual proto objects in memory, so no performance degradation is expected
+// compared to writing custom kernels for this workflow.
+//
+// Unlike the standard `Rpc` op, if the connection fails or the remote worker
+// returns an error status, this op does **not** reraise the exception.
+// Instead, the `status_code` and `status_message` entry for the corresponding RPC
+// call is set with the error returned from the RPC call. The `response` tensor
+// will contain valid response values for those minibatch entries whose RPCs did
+// not fail; the rest of the entries will have empty strings.
+//
+// Arguments:
+// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `method` and `request`.
+// method: `0-D` or `1-D`. The method address on the RPC server.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `request`.
+// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
+// If this tensor has more than 1 element, then multiple parallel rpc requests
+// are sent. This argument broadcasts with `address` and `method`.
+//
+// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages
+// returned from the RPC calls.
+func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TryRpc",
+ Input: []tf.Input{
+ address, method, request,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
// Delete the tensor specified by its handle in the session.
//
// Arguments:
@@ -21612,29 +21924,6 @@ func QuantizeDownAndShrinkRange(scope *Scope, input tf.Output, input_min tf.Outp
return op.Output(0), op.Output(1), op.Output(2)
}
-// Forwards the input to the output.
-//
-// This operator represents the loop termination condition used by the
-// "pivot" switches of a loop.
-//
-// Arguments:
-// input: A boolean scalar, representing the branch predicate of the Switch op.
-//
-// Returns The same tensor as `input`.
-func LoopCond(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "LoopCond",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Computes the sum along segments of a tensor.
//
// Read
@@ -24163,6 +24452,31 @@ func TensorSummary(scope *Scope, tensor tf.Output, optional ...TensorSummaryAttr
return op.Output(0)
}
+// Read an element from the TensorArray into output `value`.
+//
+// Arguments:
+// handle: The handle to a TensorArray.
+//
+// flow_in: A float scalar that enforces proper chaining of operations.
+// dtype: The type of the elem that is returned.
+//
+// Returns The tensor that is read from the TensorArray.
+func TensorArrayReadV3(scope *Scope, handle tf.Output, index tf.Output, flow_in tf.Output, dtype tf.DataType) (value tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayReadV3",
+ Input: []tf.Input{
+ handle, index, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Computes the gradient for the tanh of `x` wrt its input.
//
// Specifically, `grad = dy * (1 - y*y)`, where `y = tanh(x)`, and `dy`
@@ -27849,178 +28163,6 @@ func FakeParam(scope *Scope, dtype tf.DataType, shape tf.Shape) (output tf.Outpu
return op.Output(0)
}
-// EncodeProtoAttr is an optional argument to EncodeProto.
-type EncodeProtoAttr func(optionalAttr)
-
-// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
-// If not specified, defaults to "local://"
-func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
- return func(m optionalAttr) {
- m["descriptor_source"] = value
- }
-}
-
-// The op serializes protobuf messages provided in the input tensors.
-//
-// The types of the tensors in `values` must match the schema for the
-// fields specified in `field_names`. All the tensors in `values` must
-// have a common shape prefix, *batch_shape*.
-//
-// The `sizes` tensor specifies repeat counts for each field. The repeat
-// count (last dimension) of a each tensor in `values` must be greater
-// than or equal to corresponding repeat count in `sizes`.
-//
-// A `message_type` name must be provided to give context for the field
-// names. The actual message descriptor can be looked up either in the
-// linked-in descriptor pool or a filename provided by the caller using
-// the `descriptor_source` attribute.
-//
-// The `descriptor_source` attribute selects a source of protocol
-// descriptors to consult when looking up `message_type`. This may be a
-// filename containing a serialized `FileDescriptorSet` message,
-// or the special value `local://`, in which case only descriptors linked
-// into the code will be searched; the filename can be on any filesystem
-// accessible to TensorFlow.
-//
-// You can build a `descriptor_source` file using the `--descriptor_set_out`
-// and `--include_imports` options to the protocol compiler `protoc`.
-//
-// The `local://` database only covers descriptors linked into the
-// code via C++ libraries, not Python imports. You can link in a proto descriptor
-// by creating a cc_library target with alwayslink=1.
-//
-// There are a few special cases in the value mapping:
-//
-// Submessage and group fields must be pre-serialized as TensorFlow strings.
-//
-// TensorFlow lacks support for unsigned int64s, so they must be
-// represented as `tf.int64` with the same twos-complement bit pattern
-// (the obvious way).
-//
-// Unsigned int32 values can be represented exactly with `tf.int64`, or
-// with sign wrapping if the input is of type `tf.int32`.
-//
-// Arguments:
-// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
-// values: List of tensors containing values for the corresponding field.
-// field_names: List of strings containing proto field names.
-// message_type: Name of the proto message type to decode.
-//
-// Returns Tensor of serialized protos with shape `batch_shape`.
-func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EncodeProto",
- Input: []tf.Input{
- sizes, tf.OutputList(values),
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Creates a TensorArray for storing the gradients of values in the given handle.
-//
-// If the given TensorArray gradient already exists, returns a reference to it.
-//
-// Locks the size of the original TensorArray by disabling its dynamic size flag.
-//
-// **A note about the input flow_in:**
-//
-// The handle flow_in forces the execution of the gradient lookup to occur
-// only after certain other operations have occurred. For example, when
-// the forward TensorArray is dynamically sized, writes to this TensorArray
-// may resize the object. The gradient TensorArray is statically sized based
-// on the size of the forward TensorArray when this operation executes.
-// Furthermore, the size of the forward TensorArray is frozen by this call.
-// As a result, the flow is used to ensure that the call to generate the gradient
-// TensorArray only happens after all writes are executed.
-//
-// In the case of dynamically sized TensorArrays, gradient computation should
-// only be performed on read operations that have themselves been chained via
-// flow to occur only after all writes have executed. That way the final size
-// of the forward TensorArray is known when this operation is called.
-//
-// **A note about the source attribute:**
-//
-// TensorArray gradient calls use an accumulator TensorArray object. If
-// multiple gradients are calculated and run in the same session, the multiple
-// gradient nodes may accidentally flow through the same accumulator TensorArray.
-// This double counts and generally breaks the TensorArray gradient flow.
-//
-// The solution is to identify which gradient call this particular
-// TensorArray gradient is being called in. This is performed by identifying
-// a unique string (e.g. "gradients", "gradients_1", ...) from the input
-// gradient Tensor's name. This string is used as a suffix when creating
-// the TensorArray gradient object here (the attribute `source`).
-//
-// The attribute `source` is added as a suffix to the forward TensorArray's
-// name when performing the creation / lookup, so that each separate gradient
-// calculation gets its own TensorArray accumulator.
-//
-// Arguments:
-// handle: The handle to the forward TensorArray.
-// flow_in: A float scalar that enforces proper chaining of operations.
-// source: The gradient source string, used to decide which gradient TensorArray
-// to return.
-func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"source": source}
- opspec := tf.OpSpec{
- Type: "TensorArrayGradV3",
- Input: []tf.Input{
- handle, flow_in,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1)
-}
-
-// Creates a dataset that splits a SparseTensor into elements row-wise.
-func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseTensorSliceDataset",
- Input: []tf.Input{
- indices, values, dense_shape,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns x / y element-wise for real types.
-//
-// If `x` and `y` are reals, this will return the floating-point division.
-//
-// *NOTE*: `Div` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RealDiv",
- Input: []tf.Input{
- x, y,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds v into specified rows of x.
//
// Computes y = x; y[i, :] += v; return y.
@@ -28316,6 +28458,255 @@ func StackPushV2(scope *Scope, handle tf.Output, elem tf.Output, optional ...Sta
return op.Output(0)
}
+// StringSplitV2Attr is an optional argument to StringSplitV2.
+type StringSplitV2Attr func(optionalAttr)
+
+// StringSplitV2Maxsplit sets the optional maxsplit attribute to value.
+//
+// value: An `int`. If `maxsplit > 0`, limit of the split of the result.
+// If not specified, defaults to -1
+func StringSplitV2Maxsplit(value int64) StringSplitV2Attr {
+ return func(m optionalAttr) {
+ m["maxsplit"] = value
+ }
+}
+
+// Split elements of `source` based on `sep` into a `SparseTensor`.
+//
+// Let N be the size of source (typically N will be the batch size). Split each
+// element of `source` based on `sep` and return a `SparseTensor`
+// containing the split tokens. Empty tokens are ignored.
+//
+// For example, N = 2, source[0] is 'hello world' and source[1] is 'a b c',
+// then the output will be
+// ```
+// st.indices = [0, 0;
+// 0, 1;
+// 1, 0;
+// 1, 1;
+// 1, 2]
+// st.shape = [2, 3]
+// st.values = ['hello', 'world', 'a', 'b', 'c']
+// ```
+//
+// If `sep` is given, consecutive delimiters are not grouped together and are
+// deemed to delimit empty strings. For example, source of `"1<>2<><>3"` and
+// sep of `"<>"` returns `["1", "2", "", "3"]`. If `sep` is None or an empty
+// string, consecutive whitespace are regarded as a single separator, and the
+// result will contain no empty strings at the startor end if the string has
+// leading or trailing whitespace.
+//
+// Note that the above mentioned behavior matches python's str.split.
+//
+// Arguments:
+// input: `1-D` string `Tensor`, the strings to split.
+// sep: `0-D` string `Tensor`, the delimiter character.
+func StringSplitV2(scope *Scope, input tf.Output, sep tf.Output, optional ...StringSplitV2Attr) (indices tf.Output, values tf.Output, shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringSplitV2",
+ Input: []tf.Input{
+ input, sep,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Computes softsign: `features / (abs(features) + 1)`.
+func Softsign(scope *Scope, features tf.Output) (activations tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Softsign",
+ Input: []tf.Input{
+ features,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// EncodeProtoAttr is an optional argument to EncodeProto.
+type EncodeProtoAttr func(optionalAttr)
+
+// EncodeProtoDescriptorSource sets the optional descriptor_source attribute to value.
+// If not specified, defaults to "local://"
+func EncodeProtoDescriptorSource(value string) EncodeProtoAttr {
+ return func(m optionalAttr) {
+ m["descriptor_source"] = value
+ }
+}
+
+// The op serializes protobuf messages provided in the input tensors.
+//
+// The types of the tensors in `values` must match the schema for the
+// fields specified in `field_names`. All the tensors in `values` must
+// have a common shape prefix, *batch_shape*.
+//
+// The `sizes` tensor specifies repeat counts for each field. The repeat
+// count (last dimension) of a each tensor in `values` must be greater
+// than or equal to corresponding repeat count in `sizes`.
+//
+// A `message_type` name must be provided to give context for the field
+// names. The actual message descriptor can be looked up either in the
+// linked-in descriptor pool or a filename provided by the caller using
+// the `descriptor_source` attribute.
+//
+// The `descriptor_source` attribute selects a source of protocol
+// descriptors to consult when looking up `message_type`. This may be a
+// filename containing a serialized `FileDescriptorSet` message,
+// or the special value `local://`, in which case only descriptors linked
+// into the code will be searched; the filename can be on any filesystem
+// accessible to TensorFlow.
+//
+// You can build a `descriptor_source` file using the `--descriptor_set_out`
+// and `--include_imports` options to the protocol compiler `protoc`.
+//
+// The `local://` database only covers descriptors linked into the
+// code via C++ libraries, not Python imports. You can link in a proto descriptor
+// by creating a cc_library target with alwayslink=1.
+//
+// There are a few special cases in the value mapping:
+//
+// Submessage and group fields must be pre-serialized as TensorFlow strings.
+//
+// TensorFlow lacks support for unsigned int64s, so they must be
+// represented as `tf.int64` with the same twos-complement bit pattern
+// (the obvious way).
+//
+// Unsigned int32 values can be represented exactly with `tf.int64`, or
+// with sign wrapping if the input is of type `tf.int32`.
+//
+// Arguments:
+// sizes: Tensor of int32 with shape `[batch_shape, len(field_names)]`.
+// values: List of tensors containing values for the corresponding field.
+// field_names: List of strings containing proto field names.
+// message_type: Name of the proto message type to decode.
+//
+// Returns Tensor of serialized protos with shape `batch_shape`.
+func EncodeProto(scope *Scope, sizes tf.Output, values []tf.Output, field_names []string, message_type string, optional ...EncodeProtoAttr) (bytes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"field_names": field_names, "message_type": message_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "EncodeProto",
+ Input: []tf.Input{
+ sizes, tf.OutputList(values),
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Creates a TensorArray for storing the gradients of values in the given handle.
+//
+// If the given TensorArray gradient already exists, returns a reference to it.
+//
+// Locks the size of the original TensorArray by disabling its dynamic size flag.
+//
+// **A note about the input flow_in:**
+//
+// The handle flow_in forces the execution of the gradient lookup to occur
+// only after certain other operations have occurred. For example, when
+// the forward TensorArray is dynamically sized, writes to this TensorArray
+// may resize the object. The gradient TensorArray is statically sized based
+// on the size of the forward TensorArray when this operation executes.
+// Furthermore, the size of the forward TensorArray is frozen by this call.
+// As a result, the flow is used to ensure that the call to generate the gradient
+// TensorArray only happens after all writes are executed.
+//
+// In the case of dynamically sized TensorArrays, gradient computation should
+// only be performed on read operations that have themselves been chained via
+// flow to occur only after all writes have executed. That way the final size
+// of the forward TensorArray is known when this operation is called.
+//
+// **A note about the source attribute:**
+//
+// TensorArray gradient calls use an accumulator TensorArray object. If
+// multiple gradients are calculated and run in the same session, the multiple
+// gradient nodes may accidentally flow through the same accumulator TensorArray.
+// This double counts and generally breaks the TensorArray gradient flow.
+//
+// The solution is to identify which gradient call this particular
+// TensorArray gradient is being called in. This is performed by identifying
+// a unique string (e.g. "gradients", "gradients_1", ...) from the input
+// gradient Tensor's name. This string is used as a suffix when creating
+// the TensorArray gradient object here (the attribute `source`).
+//
+// The attribute `source` is added as a suffix to the forward TensorArray's
+// name when performing the creation / lookup, so that each separate gradient
+// calculation gets its own TensorArray accumulator.
+//
+// Arguments:
+// handle: The handle to the forward TensorArray.
+// flow_in: A float scalar that enforces proper chaining of operations.
+// source: The gradient source string, used to decide which gradient TensorArray
+// to return.
+func TensorArrayGradV3(scope *Scope, handle tf.Output, flow_in tf.Output, source string) (grad_handle tf.Output, flow_out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"source": source}
+ opspec := tf.OpSpec{
+ Type: "TensorArrayGradV3",
+ Input: []tf.Input{
+ handle, flow_in,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1)
+}
+
+// Creates a dataset that splits a SparseTensor into elements row-wise.
+func SparseTensorSliceDataset(scope *Scope, indices tf.Output, values tf.Output, dense_shape tf.Output) (handle tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseTensorSliceDataset",
+ Input: []tf.Input{
+ indices, values, dense_shape,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// Returns x / y element-wise for real types.
+//
+// If `x` and `y` are reals, this will return the floating-point division.
+//
+// *NOTE*: `Div` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func RealDiv(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RealDiv",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// Creates a dataset that concatenates `input_dataset` with `another_dataset`.
func ConcatenateDataset(scope *Scope, input_dataset tf.Output, another_dataset tf.Output, output_types []tf.DataType, output_shapes []tf.Shape) (handle tf.Output) {
if scope.Err() != nil {
@@ -32600,79 +32991,6 @@ func CudnnRNNParamsToCanonical(scope *Scope, num_layers tf.Output, num_units tf.
return weights, biases
}
-// UniformCandidateSamplerAttr is an optional argument to UniformCandidateSampler.
-type UniformCandidateSamplerAttr func(optionalAttr)
-
-// UniformCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// UniformCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func UniformCandidateSamplerSeed2(value int64) UniformCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a uniform distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to randomly sample.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-// range_max: The sampler will sample integers from the interval [0, range_max).
-//
-// Returns A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.A batch_size * num_true matrix, representing
-// the number of times each candidate is expected to occur in a batch
-// of sampled candidates. If unique=true, then this is a probability.A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected
-// to occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func UniformCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, range_max int64, optional ...UniformCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique, "range_max": range_max}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "UniformCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
// CTCLossAttr is an optional argument to CTCLoss.
type CTCLossAttr func(optionalAttr)
@@ -32823,321 +33141,3 @@ func Switch(scope *Scope, data tf.Output, pred tf.Output) (output_false tf.Outpu
op := scope.AddOperation(opspec)
return op.Output(0), op.Output(1)
}
-
-// Add all input tensors element wise.
-//
-// Arguments:
-// inputs: Must all be the same size and shape.
-func AddN(scope *Scope, inputs []tf.Output) (sum tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "AddN",
- Input: []tf.Input{
- tf.OutputList(inputs),
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// TryRpcAttr is an optional argument to TryRpc.
-type TryRpcAttr func(optionalAttr)
-
-// TryRpcProtocol sets the optional protocol attribute to value.
-//
-// value: RPC protocol to use. Empty string means use the default protocol.
-// Options include 'grpc'.
-// If not specified, defaults to ""
-func TryRpcProtocol(value string) TryRpcAttr {
- return func(m optionalAttr) {
- m["protocol"] = value
- }
-}
-
-// TryRpcFailFast sets the optional fail_fast attribute to value.
-//
-// value: `boolean`. If `true` (default), then failures to connect
-// (i.e., the server does not immediately respond) cause an RPC failure.
-// If not specified, defaults to true
-func TryRpcFailFast(value bool) TryRpcAttr {
- return func(m optionalAttr) {
- m["fail_fast"] = value
- }
-}
-
-// TryRpcTimeoutInMs sets the optional timeout_in_ms attribute to value.
-//
-// value: `int`. If `0` (default), then the kernel will run the RPC
-// request and only time out if the RPC deadline passes or the session times out.
-// If this value is greater than `0`, then the op will raise an exception if
-// the RPC takes longer than `timeout_in_ms`.
-// If not specified, defaults to 0
-func TryRpcTimeoutInMs(value int64) TryRpcAttr {
- return func(m optionalAttr) {
- m["timeout_in_ms"] = value
- }
-}
-
-// Perform batches of RPC requests.
-//
-// This op asynchronously performs either a single RPC request, or a batch
-// of requests. RPC requests are defined by three main parameters:
-//
-// - `address` (the host+port or BNS address of the request)
-// - `method` (the method name for the request)
-// - `request` (the serialized proto string, or vector of strings,
-// of the RPC request argument).
-//
-// For example, if you have an RPC service running on port localhost:2345,
-// and its interface is configured with the following proto declaration:
-//
-// ```
-// service MyService {
-// rpc MyMethod(MyRequestProto) returns (MyResponseProto) {
-// }
-// };
-// ```
-//
-// then call this op with arguments:
-//
-// ```
-// address = "localhost:2345"
-// method = "MyService/MyMethod"
-// ```
-//
-// The `request` tensor is a string tensor representing serialized `MyRequestProto`
-// strings; and the output string tensor `response` will have the same shape
-// and contain (upon successful completion) corresponding serialized
-// `MyResponseProto` strings.
-//
-// For example, to send a single, empty, `MyRequestProto`, call
-// this op with `request = ""`. To send 5 **parallel** empty requests,
-// call this op with `request = ["", "", "", "", ""]`.
-//
-// More generally, one can create a batch of `MyRequestProto` serialized protos
-// from regular batched tensors using the `encode_proto` op, and convert
-// the response `MyResponseProto` serialized protos to batched tensors
-// using the `decode_proto` op.
-//
-// **NOTE** Working with serialized proto strings is faster than instantiating
-// actual proto objects in memory, so no performance degradation is expected
-// compared to writing custom kernels for this workflow.
-//
-// Unlike the standard `Rpc` op, if the connection fails or the remote worker
-// returns an error status, this op does **not** reraise the exception.
-// Instead, the `status_code` and `status_message` entry for the corresponding RPC
-// call is set with the error returned from the RPC call. The `response` tensor
-// will contain valid response values for those minibatch entries whose RPCs did
-// not fail; the rest of the entries will have empty strings.
-//
-// Arguments:
-// address: `0-D` or `1-D`. The address (i.e. host_name:port) of the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `method` and `request`.
-// method: `0-D` or `1-D`. The method address on the RPC server.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `request`.
-// request: `0-D` or `1-D`. Serialized proto strings: the rpc request argument.
-// If this tensor has more than 1 element, then multiple parallel rpc requests
-// are sent. This argument broadcasts with `address` and `method`.
-//
-// Returns Same shape as `request`. Serialized proto strings: the rpc responses.Same shape as `request`. Values correspond to tensorflow Status enum codes.Same shape as `request`. Values correspond to Status messages
-// returned from the RPC calls.
-func TryRpc(scope *Scope, address tf.Output, method tf.Output, request tf.Output, optional ...TryRpcAttr) (response tf.Output, status_code tf.Output, status_message tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TryRpc",
- Input: []tf.Input{
- address, method, request,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// EnterAttr is an optional argument to Enter.
-type EnterAttr func(optionalAttr)
-
-// EnterIsConstant sets the optional is_constant attribute to value.
-//
-// value: If true, the output is constant within the child frame.
-// If not specified, defaults to false
-func EnterIsConstant(value bool) EnterAttr {
- return func(m optionalAttr) {
- m["is_constant"] = value
- }
-}
-
-// EnterParallelIterations sets the optional parallel_iterations attribute to value.
-//
-// value: The number of iterations allowed to run in parallel.
-// If not specified, defaults to 10
-func EnterParallelIterations(value int64) EnterAttr {
- return func(m optionalAttr) {
- m["parallel_iterations"] = value
- }
-}
-
-// Creates or finds a child frame, and makes `data` available to the child frame.
-//
-// This op is used together with `Exit` to create loops in the graph.
-// The unique `frame_name` is used by the `Executor` to identify frames. If
-// `is_constant` is true, `output` is a constant in the child frame; otherwise
-// it may be changed in the child frame. At most `parallel_iterations` iterations
-// are run in parallel in the child frame.
-//
-// Arguments:
-// data: The tensor to be made available to the child frame.
-// frame_name: The name of the child frame.
-//
-// Returns The same tensor as `data`.
-func Enter(scope *Scope, data tf.Output, frame_name string, optional ...EnterAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"frame_name": frame_name}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Enter",
- Input: []tf.Input{
- data,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Produce a string tensor that encodes the state of a Reader.
-//
-// Not all Readers support being serialized, so this can produce an
-// Unimplemented error.
-//
-// Arguments:
-// reader_handle: Handle to a Reader.
-func ReaderSerializeStateV2(scope *Scope, reader_handle tf.Output) (state tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ReaderSerializeStateV2",
- Input: []tf.Input{
- reader_handle,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Exits the current frame to its parent frame.
-//
-// Exit makes its input `data` available to the parent frame.
-//
-// Arguments:
-// data: The tensor to be made available to the parent frame.
-//
-// Returns The same tensor as `data`.
-func Exit(scope *Scope, data tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Exit",
- Input: []tf.Input{
- data,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a copy of the input tensor.
-func Snapshot(scope *Scope, input tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Snapshot",
- Input: []tf.Input{
- input,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Returns a tensor of zeros with the same shape and type as x.
-//
-// Arguments:
-// x: a tensor of type T.
-//
-// Returns a tensor of the same shape and type as x but filled with zeros.
-func ZerosLike(scope *Scope, x tf.Output) (y tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "ZerosLike",
- Input: []tf.Input{
- x,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// AbortAttr is an optional argument to Abort.
-type AbortAttr func(optionalAttr)
-
-// AbortErrorMsg sets the optional error_msg attribute to value.
-//
-// value: A string which is the message associated with the exception.
-// If not specified, defaults to ""
-func AbortErrorMsg(value string) AbortAttr {
- return func(m optionalAttr) {
- m["error_msg"] = value
- }
-}
-
-// AbortExitWithoutError sets the optional exit_without_error attribute to value.
-// If not specified, defaults to false
-func AbortExitWithoutError(value bool) AbortAttr {
- return func(m optionalAttr) {
- m["exit_without_error"] = value
- }
-}
-
-// Raise a exception to abort the process when called.
-//
-// If exit_without_error is true, the process will exit normally,
-// otherwise it will exit with a SIGABORT signal.
-//
-// Returns nothing but an exception.
-//
-// Returns the created operation.
-func Abort(scope *Scope, optional ...AbortAttr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Abort",
-
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
diff --git a/tensorflow/go/test.sh b/tensorflow/go/test.sh
index 6083608f22..47c3a68379 100755
--- a/tensorflow/go/test.sh
+++ b/tensorflow/go/test.sh
@@ -63,6 +63,9 @@ then
else
export DYLD_LIBRARY_PATH="${PWD}/tensorflow:${DYLD_LIBRARY_PATH}"
fi
+else
+ echo "Only support Linux/Darwin, System $OS is not supported"
+ exit 1
fi
# Document the Go version and run tests
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 410b3a553a..fe81254ef7 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -1639,6 +1639,15 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "experimental_dataset_ops_gen",
+ visibility = [
+ "//learning/brain/python/ops:__pkg__",
+ "//tensorflow:__subpackages__",
+ "//tensorflow/python/kernel_tests:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "image_ops_gen",
visibility = ["//learning/brain/python/ops:__pkg__"],
)
@@ -1731,6 +1740,14 @@ tf_gen_op_wrapper_private_py(
)
tf_gen_op_wrapper_private_py(
+ name = "stateless_random_ops_gen",
+ visibility = [
+ "//tensorflow/contrib/stateless:__pkg__",
+ "//tensorflow/python/data/experimental/ops:__pkg__",
+ ],
+)
+
+tf_gen_op_wrapper_private_py(
name = "list_ops_gen",
)
@@ -2008,6 +2025,7 @@ py_library(
":array_ops",
":cond_v2_impl",
":constant_op",
+ ":control_flow_ops",
":control_flow_util",
":framework_ops",
":function_def_to_graph",
@@ -3292,9 +3310,11 @@ py_library(
"training/checkpointable/**/*.py",
# The following targets have their own build rules (same name as the
# file):
+ "training/basic_session_run_hooks.py",
"training/checkpoint_management.py",
"training/saveable_object.py",
"training/saver.py",
+ "training/session_run_hook.py",
"training/training_util.py",
],
),
@@ -3302,6 +3322,7 @@ py_library(
deps = [
":array_ops",
":array_ops_gen",
+ ":basic_session_run_hooks",
":checkpoint_management",
":checkpoint_ops_gen",
":client",
@@ -3326,6 +3347,7 @@ py_library(
":saver",
":sdca_ops",
":session",
+ ":session_run_hook",
":sparse_ops",
":sparse_tensor",
":state_ops",
@@ -3370,6 +3392,28 @@ py_library(
)
py_library(
+ name = "session_run_hook",
+ srcs = ["training/session_run_hook.py"],
+ srcs_version = "PY2AND3",
+ deps = [":util"],
+)
+
+py_library(
+ name = "basic_session_run_hooks",
+ srcs = ["training/basic_session_run_hooks.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":client",
+ ":framework",
+ ":platform",
+ ":protos_all_py",
+ ":session_run_hook",
+ ":training_util",
+ ":util",
+ ],
+)
+
+py_library(
name = "saver",
srcs = ["training/saver.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/python/autograph/CONTRIBUTING.md b/tensorflow/python/autograph/CONTRIBUTING.md
index 1ded5ba5f6..f3587a4384 100644
--- a/tensorflow/python/autograph/CONTRIBUTING.md
+++ b/tensorflow/python/autograph/CONTRIBUTING.md
@@ -9,8 +9,6 @@ In preparation for TF 2.0, we moved the code base of AutoGraph from
does not impact functionality, and AutoGraph will remain accessible under
`tensorflow.contrib.autograph` until `tensorflow.contrib` is retired.
-When
-
## TensorFlow Code of Conduct
Please review and follow the [TensorFlow Code of Conduct](../../CODE_OF_CONDUCT.md).
diff --git a/tensorflow/python/autograph/converters/BUILD b/tensorflow/python/autograph/converters/BUILD
index 7b029de8ed..f06dc78f0e 100644
--- a/tensorflow/python/autograph/converters/BUILD
+++ b/tensorflow/python/autograph/converters/BUILD
@@ -27,10 +27,10 @@ py_library(
"decorators.py",
"directives.py",
"error_handlers.py",
+ "function_scopes.py",
"list_comprehensions.py",
"lists.py",
"logical_expressions.py",
- "name_scopes.py",
"return_statements.py",
"side_effect_guards.py",
"slices.py",
@@ -157,8 +157,8 @@ py_test(
)
py_test(
- name = "name_scopes_test",
- srcs = ["name_scopes_test.py"],
+ name = "function_scopes_test",
+ srcs = ["function_scopes_test.py"],
deps = [
":converters",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/autograph/converters/name_scopes.py b/tensorflow/python/autograph/converters/function_scopes.py
index a9c55ccff0..284b5b3519 100644
--- a/tensorflow/python/autograph/converters/name_scopes.py
+++ b/tensorflow/python/autograph/converters/function_scopes.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Wraps a function body with a `name_scope` of the function name."""
+"""Wraps the body of a converted function with auxiliary constructs."""
from __future__ import absolute_import
from __future__ import division
@@ -24,8 +24,8 @@ from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.pyct import templates
-class FunctionNameScopeTransformer(converter.Base):
- """Wrap a function body with a `name_scope` of the function name."""
+class FunctionBodyTransformer(converter.Base):
+ """Wraps function bodies around autograph-specific boilerplate."""
def _name_for_current_scope(self):
innermost = self.enclosing_entities[-1]
@@ -49,26 +49,28 @@ class FunctionNameScopeTransformer(converter.Base):
def visit_FunctionDef(self, node):
node = self.generic_visit(node)
- unscoped_body = []
- scoped_body = node.body
- if scoped_body:
- first = scoped_body[0]
- if isinstance(first, gast.Expr) and isinstance(first.value, gast.Str):
- # Skip any docstring.
- unscoped_body = scoped_body[:1]
- scoped_body = scoped_body[1:]
+ final_body = []
+ indented_body = node.body
+ if node.body:
+ first_statement = node.body[0]
+ # Skip the docstring, if any.
+ if (isinstance(first_statement, gast.Expr) and
+ isinstance(first_statement.value, gast.Str)):
+ indented_body = indented_body[1:]
+ final_body.append(first_statement)
template = """
- with tf.name_scope(scope_name):
+ with ag__.function_scope(scope_name):
body
"""
scoped_body = templates.replace(
template,
scope_name=gast.Str(self._name_for_current_scope()),
- body=scoped_body)
- node.body = unscoped_body + scoped_body
+ body=indented_body)
+ final_body.extend(scoped_body)
+ node.body = final_body
return node
def transform(node, ctx):
- return FunctionNameScopeTransformer(ctx).visit(node)
+ return FunctionBodyTransformer(ctx).visit(node)
diff --git a/tensorflow/python/autograph/converters/name_scopes_test.py b/tensorflow/python/autograph/converters/function_scopes_test.py
index 73933c1c4f..e5ce03a109 100644
--- a/tensorflow/python/autograph/converters/name_scopes_test.py
+++ b/tensorflow/python/autograph/converters/function_scopes_test.py
@@ -12,51 +12,51 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Tests for for_canonicalization module."""
+"""Tests for function_scopes module."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.autograph.converters import name_scopes
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.core import converter_testing
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class FunctionNameScopeTransformer(converter_testing.TestCase):
+class FunctionBodyTransformerTest(converter_testing.TestCase):
def test_basic(self):
def test_fn(l):
- """This should stay here."""
+ """Docstring."""
a = 1
l += a
return l
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}) as result:
result_op = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', result_op.op.name)
- self.assertEqual('This should stay here.', result.test_fn.__doc__)
+ self.assertEqual('Docstring.', result.test_fn.__doc__)
- def test_long_docstring(self):
+ def test_multiline_docstring(self):
- def test_fn(l):
- """Multi-line docstring.
+ tf = None
+
+ def test_fn():
+ """First sentence.
- Args:
- l: A thing.
- Returns:
- l
+ Second sentence.
"""
- return l + 1
+ return tf.constant(1)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
- result_op = result.test_fn(constant_op.constant(1))
+ with self.converted(test_fn, function_scopes, {},
+ constant_op.constant) as result:
+ result_op = result.test_fn()
self.assertIn('test_fn/', result_op.op.name)
- self.assertIn('Multi-line docstring.', result.test_fn.__doc__)
- self.assertIn('Returns:', result.test_fn.__doc__)
+ self.assertIn('First sentence.', result.test_fn.__doc__)
+ self.assertIn('Second sentence.', result.test_fn.__doc__)
def test_nested_functions(self):
@@ -68,7 +68,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
l += 1
return l, inner_fn(l)
- with self.converted(test_fn, name_scopes, {}, ops.name_scope) as result:
+ with self.converted(test_fn, function_scopes, {}, ops.name_scope) as result:
first, second = result.test_fn(constant_op.constant(1))
self.assertIn('test_fn/', first.op.name)
self.assertNotIn('inner_fn', first.op.name)
@@ -88,7 +88,7 @@ class FunctionNameScopeTransformer(converter_testing.TestCase):
ns = {'TestClass': TestClass}
node, ctx = self.prepare(TestClass, ns, owner_type=TestClass)
- node = name_scopes.transform(node, ctx)
+ node = function_scopes.transform(node, ctx)
with self.compiled(node, {}, ops.name_scope) as result:
first, second = result.TestClass().test_fn(constant_op.constant(1))
diff --git a/tensorflow/python/autograph/core/BUILD b/tensorflow/python/autograph/core/BUILD
index 85fecf084d..3ab2e7b1bc 100644
--- a/tensorflow/python/autograph/core/BUILD
+++ b/tensorflow/python/autograph/core/BUILD
@@ -20,17 +20,48 @@ py_library(
"config.py",
"converter.py",
"errors.py",
+ "function_wrapping.py",
"naming.py",
],
srcs_version = "PY2AND3",
visibility = ["//tensorflow:__subpackages__"],
deps = [
+ "//tensorflow/python:framework_ops",
"//tensorflow/python/autograph/pyct",
"//tensorflow/python/autograph/pyct/static_analysis",
"//tensorflow/python/autograph/utils",
],
)
+py_library(
+ name = "test_lib",
+ srcs = [
+ "converter_testing.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = ["//tensorflow:__subpackages__"],
+ deps = [
+ ":core",
+ "//tensorflow/python/autograph/operators",
+ "//tensorflow/python/autograph/pyct",
+ "//tensorflow/python/autograph/pyct/static_analysis",
+ "//tensorflow/python/autograph/utils",
+ "@gast_archive//:gast",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "converter_test",
+ srcs = ["converter_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":core",
+ ":test_lib",
+ "//tensorflow/python:client_testlib",
+ ],
+)
+
py_test(
name = "errors_test",
srcs = ["errors_test.py"],
@@ -47,8 +78,8 @@ py_test(
)
py_test(
- name = "naming_test",
- srcs = ["naming_test.py"],
+ name = "function_wrapping_test",
+ srcs = ["function_wrapping_test.py"],
srcs_version = "PY2AND3",
deps = [
":core",
@@ -56,20 +87,12 @@ py_test(
],
)
-py_library(
- name = "test_lib",
- srcs = [
- "converter_testing.py",
- ],
+py_test(
+ name = "naming_test",
+ srcs = ["naming_test.py"],
srcs_version = "PY2AND3",
- visibility = ["//tensorflow:__subpackages__"],
deps = [
":core",
- "//tensorflow/python/autograph/operators",
- "//tensorflow/python/autograph/pyct",
- "//tensorflow/python/autograph/pyct/static_analysis",
- "//tensorflow/python/autograph/utils",
- "@gast_archive//:gast",
- "@six_archive//:six",
+ "//tensorflow/python:client_testlib",
],
)
diff --git a/tensorflow/python/autograph/core/converter.py b/tensorflow/python/autograph/core/converter.py
index 80928ae7f4..408a573ad0 100644
--- a/tensorflow/python/autograph/core/converter.py
+++ b/tensorflow/python/autograph/core/converter.py
@@ -210,14 +210,22 @@ class Base(transformer.Base):
self._ast_depth = 0
def get_definition_directive(self, node, directive, arg, default):
- """Returns the unique directive for a symbol, or a default if none exist.
+ """Returns the unique directive argument for a symbol.
See lang/directives.py for details on directives.
+ Example:
+ # Given a directive in the code:
+ ag.foo_directive(bar, baz=1)
+
+ # One can write for an AST node Name(id='bar'):
+ get_definition_directive(node, ag.foo_directive, 'baz')
+
Args:
- node: ast.AST
- directive: Callable[..., Any]
- arg: str
+ node: ast.AST, the node representing the symbol for which the directive
+ argument is needed.
+ directive: Callable[..., Any], the directive to search.
+ arg: str, the directive argument to return.
default: Any
Raises:
@@ -227,27 +235,28 @@ class Base(transformer.Base):
if not defs:
return default
- # TODO(mdan): Simplify this.
- arg_values = []
+ arg_values_found = []
for def_ in defs:
- if (directive not in def_.directives or
- arg not in def_.directives[directive]):
- continue
- arg_value = def_.directives[directive][arg]
- for prev_value in arg_values:
- if not ast_util.matches(arg_value, prev_value):
- qn = anno.getanno(node, anno.Basic.QN)
- raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
- (qn, directive.__name__, arg,
- compiler.ast_to_source(arg_value).strip(),
- compiler.ast_to_source(prev_value).strip()))
- arg_values.append(arg_value)
-
- if not arg_values:
+ if (directive in def_.directives and arg in def_.directives[directive]):
+ arg_values_found.append(def_.directives[directive][arg])
+
+ if not arg_values_found:
return default
- arg_value, = arg_values
- return arg_value
+ if len(arg_values_found) == 1:
+ return arg_values_found[0]
+
+ # If multiple annotations reach the symbol, they must all match. If they do,
+ # return any of them.
+ first_value = arg_values_found[0]
+ for other_value in arg_values_found[1:]:
+ if not ast_util.matches(first_value, other_value):
+ qn = anno.getanno(node, anno.Basic.QN)
+ raise ValueError('%s has ambiguous annotations for %s(%s): %s, %s' %
+ (qn, directive.__name__, arg,
+ compiler.ast_to_source(other_value).strip(),
+ compiler.ast_to_source(first_value).strip()))
+ return first_value
def visit(self, node):
if not self._ast_depth:
diff --git a/tensorflow/python/autograph/core/converter_test.py b/tensorflow/python/autograph/core/converter_test.py
new file mode 100644
index 0000000000..b73c67e337
--- /dev/null
+++ b/tensorflow/python/autograph/core/converter_test.py
@@ -0,0 +1,124 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for lists module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import converter
+from tensorflow.python.autograph.core import converter_testing
+from tensorflow.python.autograph.pyct import anno
+from tensorflow.python.autograph.pyct import parser
+from tensorflow.python.platform import test
+
+
+class TestConverter(converter.Base):
+ pass
+
+
+class ConverterBaseTest(converter_testing.TestCase):
+
+ def test_get_definition_directive_basic(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ defs, = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs.directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_default(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[1].value
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ parser.parse_expression('default'))
+ self.assertEqual(value.id, 'default')
+
+ def test_get_definition_directive_multiple_consistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('bar'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ 'other_arg': parser.parse_expression('baz'),
+ }
+ c = TestConverter(ctx)
+ value = c.get_definition_directive(symbol_a, directive_key, 'test_arg',
+ None)
+ self.assertEqual(value.id, 'foo')
+
+ def test_get_definition_directive_multiple_inconsistent(self):
+
+ directive_key = object
+
+ def test_fn():
+ a = 1
+ if a:
+ a = 2
+ return a
+
+ ns = {}
+ node, ctx = self.prepare(test_fn, ns)
+ symbol_a = node.body[2].value
+ defs = anno.getanno(symbol_a, anno.Static.ORIG_DEFINITIONS)
+ defs[0].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('foo'),
+ }
+ defs[1].directives[directive_key] = {
+ 'test_arg': parser.parse_expression('bar'),
+ }
+ c = TestConverter(ctx)
+ with self.assertRaises(ValueError):
+ c.get_definition_directive(symbol_a, directive_key, 'test_arg', None)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py
index 7ce1b7c4c5..dc2d419d34 100644
--- a/tensorflow/python/autograph/core/converter_testing.py
+++ b/tensorflow/python/autograph/core/converter_testing.py
@@ -29,6 +29,7 @@ from tensorflow.python.autograph import utils
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import compiler
from tensorflow.python.autograph.pyct import parser
from tensorflow.python.autograph.pyct import pretty_printer
@@ -112,6 +113,7 @@ class TestCase(test.TestCase):
fake_ag.__dict__['utils'] = utils
fake_ag.__dict__['rewrite_graph_construction_error'] = (
errors.rewrite_graph_construction_error)
+ fake_ag.__dict__['function_scope'] = function_wrapping.function_scope
result.__dict__['ag__'] = fake_ag
for k, v in namespace.items():
result.__dict__[k] = v
diff --git a/tensorflow/python/autograph/core/function_wrapping.py b/tensorflow/python/autograph/core/function_wrapping.py
new file mode 100644
index 0000000000..21b66eff02
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping.py
@@ -0,0 +1,30 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Support for wrapping converted functions bodies with auxiliary logic."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+from tensorflow.python.framework import ops
+
+
+@contextlib.contextmanager
+def function_scope(function_name):
+ """Returns a context manager for the converted body of a function."""
+ with ops.name_scope(function_name):
+ yield
diff --git a/tensorflow/python/autograph/core/function_wrapping_test.py b/tensorflow/python/autograph/core/function_wrapping_test.py
new file mode 100644
index 0000000000..5e217055c7
--- /dev/null
+++ b/tensorflow/python/autograph/core/function_wrapping_test.py
@@ -0,0 +1,34 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for function_wrapping module."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.autograph.core import function_wrapping
+from tensorflow.python.framework import constant_op
+from tensorflow.python.platform import test
+
+
+class FunctionWrappingTest(test.TestCase):
+
+ def test_function_scope_name(self):
+ with function_wrapping.function_scope('test_name'):
+ t = constant_op.constant(1)
+ self.assertIn('test_name', t.name)
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py
index a0d13c82a8..52abd40626 100644
--- a/tensorflow/python/autograph/impl/conversion.py
+++ b/tensorflow/python/autograph/impl/conversion.py
@@ -34,15 +34,16 @@ from tensorflow.python.autograph.converters import control_flow
from tensorflow.python.autograph.converters import decorators
from tensorflow.python.autograph.converters import directives
from tensorflow.python.autograph.converters import error_handlers
+from tensorflow.python.autograph.converters import function_scopes
from tensorflow.python.autograph.converters import lists
from tensorflow.python.autograph.converters import logical_expressions
-from tensorflow.python.autograph.converters import name_scopes
from tensorflow.python.autograph.converters import return_statements
from tensorflow.python.autograph.converters import side_effect_guards
from tensorflow.python.autograph.converters import slices
from tensorflow.python.autograph.core import config
from tensorflow.python.autograph.core import converter
from tensorflow.python.autograph.core import errors
+from tensorflow.python.autograph.core import function_wrapping
from tensorflow.python.autograph.pyct import ast_util
from tensorflow.python.autograph.pyct import inspect_utils
from tensorflow.python.autograph.pyct import origin_info
@@ -257,6 +258,7 @@ def _add_self_references(namespace, autograph_module):
ag_internal.converted_call = autograph_module.converted_call
ag_internal.ConversionOptions = autograph_module.ConversionOptions
ag_internal.utils = utils
+ ag_internal.function_scope = function_wrapping.function_scope
ag_internal.rewrite_graph_construction_error = (
errors.rewrite_graph_construction_error)
# TODO(mdan): Add safeguards against name clashes.
@@ -346,7 +348,7 @@ def node_to_graph(node, context, rewrite_errors=True):
node = converter.apply_(node, context, conditional_expressions)
node = converter.apply_(node, context, logical_expressions)
node = converter.apply_(node, context, side_effect_guards)
- node = converter.apply_(node, context, name_scopes)
+ node = converter.apply_(node, context, function_scopes)
if rewrite_errors:
node = converter.apply_(node, context, error_handlers)
return node
diff --git a/tensorflow/python/autograph/lang/special_functions.py b/tensorflow/python/autograph/lang/special_functions.py
index e4838d1b6d..62ac018ac4 100644
--- a/tensorflow/python/autograph/lang/special_functions.py
+++ b/tensorflow/python/autograph/lang/special_functions.py
@@ -24,6 +24,26 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.autograph.operators import data_structures
+from tensorflow.python.framework import tensor_util
+
+
+def _validate_list_constructor(elements, element_dtype, element_shape):
+ """Validates the inputs of tensor_list."""
+ if element_dtype is not None and element_shape is not None:
+ return
+ if tensor_util.is_tensor(elements):
+ return
+ if isinstance(elements, (list, tuple)):
+ if elements:
+ return
+ else:
+ raise ValueError(
+ 'element_dtype and element_shape are required when elements are'
+ ' empty')
+
+ raise ValueError(
+ 'unknown type for elements: {}; only Tensor, list and tuple are'
+ ' allowed'.format(type(elements)))
def tensor_list(elements,
@@ -52,9 +72,7 @@ def tensor_list(elements,
Raises:
ValueError: for invalid arguments
"""
- if not (elements or (element_dtype and element_shape)):
- raise ValueError(
- 'element_dtype and element_shape are required for empty lists')
+ _validate_list_constructor(elements, element_dtype, element_shape)
if use_tensor_array:
return data_structures.tf_tensor_array_new(elements, element_dtype,
element_shape)
diff --git a/tensorflow/python/autograph/lang/special_functions_test.py b/tensorflow/python/autograph/lang/special_functions_test.py
index 545dd11729..206a32d07c 100644
--- a/tensorflow/python/autograph/lang/special_functions_test.py
+++ b/tensorflow/python/autograph/lang/special_functions_test.py
@@ -18,6 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.autograph.lang import special_functions
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,12 +30,43 @@ from tensorflow.python.platform import test
class SpecialFunctionsTest(test.TestCase):
+ def test_tensor_list_empty_list(self):
+ l = special_functions.tensor_list([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ l = special_functions.tensor_list((),
+ element_dtype=dtypes.int32,
+ element_shape=())
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_tensor(self):
+ l = special_functions.tensor_list(
+ constant_op.constant([], dtype=dtypes.int32))
+ sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.test_session() as sess:
+ self.assertAllEqual(sess.run(sl), [])
+
+ def test_tensor_list_unsupported_initializer(self):
+ with self.assertRaisesRegexp(ValueError, 'unknown type'):
+ special_functions.tensor_list(np.array([1, 2, 3]))
+
+ def test_tensor_list_empty_list_no_type(self):
+ with self.assertRaisesRegexp(
+ ValueError, 'element_dtype and element_shape are required'):
+ special_functions.tensor_list([])
+
def test_tensor_list_from_elements(self):
elements = [constant_op.constant([1, 2]), constant_op.constant([3, 4])]
l = special_functions.tensor_list(elements)
sl = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_tensor_list_array_from_elements(self):
@@ -41,7 +74,7 @@ class SpecialFunctionsTest(test.TestCase):
l = special_functions.tensor_list(elements, use_tensor_array=True)
sl = l.stack()
- with self.cached_session() as sess:
+ with self.test_session() as sess:
self.assertAllEqual(sess.run(sl), [[1, 2], [3, 4]])
def test_stack(self):
diff --git a/tensorflow/python/autograph/operators/data_structures.py b/tensorflow/python/autograph/operators/data_structures.py
index cc0a3c3544..b3a3851333 100644
--- a/tensorflow/python/autograph/operators/data_structures.py
+++ b/tensorflow/python/autograph/operators/data_structures.py
@@ -106,6 +106,14 @@ def tf_tensor_array_new(elements, element_dtype=None, element_shape=None):
def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
"""Overload of new_list that stages a Tensor list creation."""
+ if tensor_util.is_tensor(elements):
+ if element_shape is not None:
+ raise ValueError(
+ 'element shape may not be specified when creating list from tensor')
+ element_shape = array_ops.shape(elements)[1:]
+ l = list_ops.tensor_list_from_tensor(elements, element_shape=element_shape)
+ return l
+
elements = tuple(ops.convert_to_tensor(el) for el in elements)
all_dtypes = set(el.dtype for el in elements)
@@ -115,13 +123,15 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible dtype; specified: {}, inferred from {}: {}'.format(
element_dtype, elements, inferred_dtype))
- else:
+ elif all_dtypes:
# Heterogeneous lists are ok.
if element_dtype is not None:
raise ValueError(
'specified dtype {} is inconsistent with that of elements {}'.format(
element_dtype, elements))
inferred_dtype = dtypes.variant
+ else:
+ inferred_dtype = dtypes.variant
all_shapes = set(tuple(el.shape.as_list()) for el in elements)
if len(all_shapes) == 1:
@@ -130,19 +140,22 @@ def tf_tensor_list_new(elements, element_dtype=None, element_shape=None):
raise ValueError(
'incompatible shape; specified: {}, inferred from {}: {}'.format(
element_shape, elements, inferred_shape))
- else:
+ elif all_shapes:
# Heterogeneous lists are ok.
if element_shape is not None:
raise ValueError(
'specified shape {} is inconsistent with that of elements {}'.format(
element_shape, elements))
inferred_shape = constant_op.constant(-1) # unknown shape, by convention
+ else:
+ inferred_shape = constant_op.constant(-1) # unknown shape, by convention
if element_dtype is None:
element_dtype = inferred_dtype
if element_shape is None:
element_shape = inferred_shape
+ element_shape = ops.convert_to_tensor(element_shape, dtype=dtypes.int32)
l = list_ops.empty_tensor_list(
element_shape=element_shape, element_dtype=element_dtype)
for el in elements:
diff --git a/tensorflow/python/autograph/operators/data_structures_test.py b/tensorflow/python/autograph/operators/data_structures_test.py
index 8532dbe466..6039b07982 100644
--- a/tensorflow/python/autograph/operators/data_structures_test.py
+++ b/tensorflow/python/autograph/operators/data_structures_test.py
@@ -45,6 +45,20 @@ class ListTest(test.TestCase):
with self.cached_session() as sess:
self.assertAllEqual(sess.run(t), [3, 4, 5])
+ def test_tf_tensor_list_new_empty(self):
+ l = data_structures.tf_tensor_list_new([],
+ element_dtype=dtypes.int32,
+ element_shape=())
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [])
+
+ def test_tf_tensor_list_new_from_tensor(self):
+ l = data_structures.tf_tensor_list_new(constant_op.constant([3, 4, 5]))
+ t = list_ops.tensor_list_stack(l, element_dtype=dtypes.int32)
+ with self.cached_session() as sess:
+ self.assertAllEqual(sess.run(t), [3, 4, 5])
+
def test_tf_tensor_list_new_illegal_input(self):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4.0])
@@ -56,9 +70,8 @@ class ListTest(test.TestCase):
with self.assertRaises(ValueError):
data_structures.tf_tensor_list_new([3, 4], element_shape=(2,))
with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_shape=(2,))
- with self.assertRaises(ValueError):
- data_structures.tf_tensor_list_new([], element_dtype=dtypes.float32)
+ data_structures.tf_tensor_list_new(
+ constant_op.constant([1, 2, 3]), element_shape=[1])
def test_tf_tensor_array_new(self):
l = data_structures.tf_tensor_array_new([3, 4, 5])
@@ -141,6 +154,18 @@ class ListTest(test.TestCase):
t = data_structures.list_stack(l, opts)
self.assertAllEqual(sess.run(t), sess.run(initial_list))
+ def test_stack_tensor_list_empty(self):
+ l = list_ops.empty_tensor_list(
+ element_shape=-1,
+ element_dtype=dtypes.variant)
+
+ opts = data_structures.ListStackOpts(
+ element_dtype=dtypes.int32, original_call=None)
+
+ # TODO(mdan): Allow stacking empty lists if the dtype and shape are known.
+ with self.assertRaises(ValueError):
+ data_structures.list_stack(l, opts)
+
def test_stack_fallback(self):
def dummy_function(l):
diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py
index b74fce3a4c..3bb95b56c2 100644
--- a/tensorflow/python/compat/compat.py
+++ b/tensorflow/python/compat/compat.py
@@ -26,7 +26,7 @@ import datetime
from tensorflow.python.util import tf_contextlib
from tensorflow.python.util.tf_export import tf_export
-_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 9, 28)
+_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2018, 10, 2)
@tf_export("compat.forward_compatible")
diff --git a/tensorflow/python/data/BUILD b/tensorflow/python/data/BUILD
index 138141f4fc..e32eeecbb8 100644
--- a/tensorflow/python/data/BUILD
+++ b/tensorflow/python/data/BUILD
@@ -10,6 +10,7 @@ py_library(
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:util",
+ "//tensorflow/python/data/experimental",
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//tensorflow/python/data/ops:multi_device_iterator_ops",
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index f8b561205e..7536ba668a 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -22,6 +22,7 @@ from __future__ import division
from __future__ import print_function
# pylint: disable=unused-import
+from tensorflow.python.data import experimental
from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.data.ops.iterator_ops import Iterator
from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
diff --git a/tensorflow/python/data/experimental/BUILD b/tensorflow/python/data/experimental/BUILD
new file mode 100644
index 0000000000..84e761d376
--- /dev/null
+++ b/tensorflow/python/data/experimental/BUILD
@@ -0,0 +1,16 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+py_library(
+ name = "experimental",
+ srcs = ["__init__.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:dataset_ops",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py
new file mode 100644
index 0000000000..2ac159d38a
--- /dev/null
+++ b/tensorflow/python/data/experimental/__init__.py
@@ -0,0 +1,109 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for building input pipelines.
+
+This module contains experimental `Dataset` sources and transformations that can
+be used in conjunction with the `tf.data.Dataset` API. Note that the
+`tf.data.experimental` API is not subject to the same backwards compatibility
+guarantees as `tf.data`, but we will provide deprecation advice in advance of
+removing existing functionality.
+
+See [Importing Data](https://tensorflow.org/guide/datasets) for an overview.
+
+@@Counter
+@@CheckpointInputPipelineHook
+@@CsvDataset
+@@Optional
+@@RandomDataset
+@@Reducer
+@@SqlDataset
+@@TFRecordWriter
+
+@@bucket_by_sequence_length
+@@choose_from_datasets
+@@copy_to_device
+@@dense_to_sparse_batch
+@@enumerate_dataset
+@@get_next_as_optional
+@@get_single_element
+@@group_by_reducer
+@@group_by_window
+@@ignore_errors
+@@latency_stats
+@@make_batched_features_dataset
+@@make_csv_dataset
+@@make_saveable_from_iterator
+@@map_and_batch
+@@parallel_interleave
+@@parse_example_dataset
+@@prefetch_to_device
+@@rejection_resample
+@@sample_from_datasets
+@@scan
+@@set_stats_aggregator
+@@shuffle_and_repeat
+@@StatsAggregator
+@@unbatch
+@@unique
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=unused-import
+
+from tensorflow.python.data.experimental.ops.batching import dense_to_sparse_batch
+from tensorflow.python.data.experimental.ops.batching import map_and_batch
+from tensorflow.python.data.experimental.ops.batching import unbatch
+from tensorflow.python.data.experimental.ops.counter import Counter
+from tensorflow.python.data.experimental.ops.enumerate_ops import enumerate_dataset
+from tensorflow.python.data.experimental.ops.error_ops import ignore_errors
+from tensorflow.python.data.experimental.ops.get_single_element import get_single_element
+from tensorflow.python.data.experimental.ops.grouping import bucket_by_sequence_length
+from tensorflow.python.data.experimental.ops.grouping import group_by_reducer
+from tensorflow.python.data.experimental.ops.grouping import group_by_window
+from tensorflow.python.data.experimental.ops.grouping import Reducer
+from tensorflow.python.data.experimental.ops.interleave_ops import choose_from_datasets
+from tensorflow.python.data.experimental.ops.interleave_ops import parallel_interleave
+from tensorflow.python.data.experimental.ops.interleave_ops import sample_from_datasets
+from tensorflow.python.data.experimental.ops.iterator_ops import CheckpointInputPipelineHook
+from tensorflow.python.data.experimental.ops.iterator_ops import make_saveable_from_iterator
+
+# Optimization constant that can be used to enable auto-tuning.
+from tensorflow.python.data.experimental.ops.optimization import AUTOTUNE
+
+from tensorflow.python.data.experimental.ops.parsing_ops import parse_example_dataset
+from tensorflow.python.data.experimental.ops.prefetching_ops import copy_to_device
+from tensorflow.python.data.experimental.ops.prefetching_ops import prefetch_to_device
+from tensorflow.python.data.experimental.ops.random_ops import RandomDataset
+from tensorflow.python.data.experimental.ops.readers import CsvDataset
+from tensorflow.python.data.experimental.ops.readers import make_batched_features_dataset
+from tensorflow.python.data.experimental.ops.readers import make_csv_dataset
+from tensorflow.python.data.experimental.ops.readers import SqlDataset
+from tensorflow.python.data.experimental.ops.resampling import rejection_resample
+from tensorflow.python.data.experimental.ops.scan_ops import scan
+from tensorflow.python.data.experimental.ops.shuffle_ops import shuffle_and_repeat
+from tensorflow.python.data.experimental.ops.stats_ops import latency_stats
+from tensorflow.python.data.experimental.ops.stats_ops import set_stats_aggregator
+from tensorflow.python.data.experimental.ops.stats_ops import StatsAggregator
+from tensorflow.python.data.experimental.ops.unique import unique
+from tensorflow.python.data.experimental.ops.writers import TFRecordWriter
+from tensorflow.python.data.ops.iterator_ops import get_next_as_optional
+from tensorflow.python.data.ops.optional_ops import Optional
+# pylint: enable=unused-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+remove_undocumented(__name__)
diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD
new file mode 100644
index 0000000000..f56127f3ef
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -0,0 +1,662 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "cuda_py_test")
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+py_test(
+ name = "batch_dataset_op_test",
+ size = "medium",
+ srcs = ["batch_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss", # (b/79552534)
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "bucketing_test",
+ size = "medium",
+ srcs = ["bucketing_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "csv_dataset_op_test",
+ size = "medium",
+ srcs = ["csv_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:platform_test",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "dataset_constructor_op_test",
+ size = "medium",
+ srcs = ["dataset_constructor_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "manual",
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "nomac", # b/62040583
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_test(
+ name = "directed_interleave_dataset_test",
+ size = "medium",
+ srcs = ["directed_interleave_dataset_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "get_single_element_test",
+ size = "small",
+ srcs = ["get_single_element_test.py"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:get_single_element",
+ "//tensorflow/python/data/experimental/ops:grouping",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "indexed_dataset_ops_test",
+ srcs = ["indexed_dataset_ops_test.py"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/experimental/ops:indexed_dataset_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "interleave_dataset_op_test",
+ size = "medium",
+ srcs = ["interleave_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "notap",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python:sparse_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "iterator_ops_test",
+ size = "small",
+ srcs = ["iterator_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:training",
+ "//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/estimator:estimator_py",
+ ],
+)
+
+py_test(
+ name = "map_dataset_op_test",
+ size = "medium",
+ srcs = ["map_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "noasan", # times out
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:error_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "filter_dataset_op_test",
+ size = "medium",
+ srcs = ["filter_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:io_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "map_defun_op_test",
+ size = "small",
+ srcs = ["map_defun_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:data_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:map_defun",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "parsing_ops_test",
+ size = "small",
+ srcs = ["parsing_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:parsing_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+cuda_py_test(
+ name = "prefetching_ops_test",
+ size = "small",
+ srcs = ["prefetching_ops_test.py"],
+ additional_deps = [
+ "//tensorflow/python/data/experimental/ops:prefetching_ops",
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:function",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/compat:compat",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ ],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "no_windows_gpu",
+ ],
+)
+
+py_test(
+ name = "range_dataset_op_test",
+ size = "small",
+ srcs = ["range_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/experimental/ops:counter",
+ "//tensorflow/python/data/experimental/ops:enumerate_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "reader_dataset_ops_test_base",
+ testonly = 1,
+ srcs = [
+ "reader_dataset_ops_test_base.py",
+ ],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
+
+py_test(
+ name = "reader_dataset_ops_test",
+ size = "medium",
+ srcs = ["reader_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "resample_test",
+ size = "medium",
+ srcs = ["resample_test.py"],
+ shard_count = 2,
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "noasan",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:string_ops",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:resampling",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "scan_dataset_op_test",
+ size = "small",
+ srcs = ["scan_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "shuffle_dataset_op_test",
+ size = "medium",
+ srcs = ["shuffle_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ "optonly",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "sql_dataset_op_test_base",
+ srcs = ["sql_dataset_op_test_base.py"],
+ srcs_version = "PY2AND3",
+ visibility = [
+ "//tensorflow/python/data/experimental/kernel_tests:__pkg__",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:__pkg__",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/ops:readers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "@org_sqlite//:python",
+ ],
+)
+
+py_test(
+ name = "sql_dataset_op_test",
+ size = "small",
+ srcs = ["sql_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":sql_dataset_op_test_base",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ ],
+)
+
+py_test(
+ name = "stats_dataset_ops_test",
+ size = "medium",
+ srcs = ["stats_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ ":reader_dataset_ops_test_base",
+ ":stats_dataset_test_base",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "stats_dataset_test_base",
+ srcs = ["stats_dataset_test_base.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/core:protos_all_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ ],
+)
+
+py_test(
+ name = "threadpool_dataset_ops_test",
+ size = "small",
+ srcs = ["threadpool_dataset_ops_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:script_ops",
+ "//tensorflow/python/data/experimental/ops:threadpool",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ "@absl_py//absl/testing:parameterized",
+ ],
+)
+
+py_test(
+ name = "unique_dataset_op_test",
+ size = "small",
+ srcs = ["unique_dataset_op_test.py"],
+ srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:unique",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_test(
+ name = "writer_ops_test",
+ size = "small",
+ srcs = ["writer_ops_test.py"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/experimental/ops:writers",
+ "//tensorflow/python/data/kernel_tests:test_base",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ ],
+)
diff --git a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
index e2508de9e9..8703b2810e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/batch_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/batch_dataset_op_test.py
@@ -23,15 +23,15 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
-from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
@@ -40,12 +40,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class BatchDatasetTest(test.TestCase, parameterized.TestCase):
-
- def assertSparseValuesEqual(self, a, b):
- self.assertAllEqual(a.indices, b.indices)
- self.assertAllEqual(a.values, b.values)
- self.assertAllEqual(a.dense_shape, b.dense_shape)
+class BatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
def testDenseToSparseBatchDataset(self):
components = np.random.randint(12, size=(100,)).astype(np.int32)
@@ -305,128 +300,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
with self.assertRaises(errors.InvalidArgumentError):
sess.run(next_element)
- def testBatchAndDropRemainder(self):
- components = (np.arange(7),
- np.array([[1, 2, 3]]) * np.arange(7)[:, np.newaxis],
- np.array(37.0) * np.arange(7))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
-
- iterator = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size))
- .make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component, result_component in zip(components, result):
- for j in range(test_batch_size):
- self.assertAllEqual(component[(i * test_batch_size + j)],
- result_component[j])
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testBatchAndDropRemainderSparse(self):
-
- def _sparse(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0]], values=(i * [1]), dense_shape=[1])
-
- iterator = dataset_ops.Dataset.range(12).map(_sparse).apply(
- batching.batch_and_drop_remainder(5)).make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- sess.run(init_op)
- for i in range(2):
- actual = sess.run(get_next)
- expected = sparse_tensor.SparseTensorValue(
- indices=[[0, 0], [1, 0], [2, 0], [3, 0], [4, 0]],
- values=[i * 5, i * 5 + 1, i * 5 + 2, i * 5 + 3, i * 5 + 4],
- dense_shape=[5, 1])
- self.assertTrue(sparse_tensor.is_sparse(actual))
- self.assertSparseValuesEqual(actual, expected)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testPaddedBatchAndDropRemainder(self):
- els = []
- for length in [3, 6, 9, 4, 12, 10, 2]:
- els.append((np.array(length), np.arange(length) + 1,
- np.array(length * 2)))
-
- dataset = dataset_ops.Dataset.from_tensors(els[0])
- for el in els[1:]:
- dataset = dataset.concatenate(dataset_ops.Dataset.from_tensors(el))
-
- batch_size = array_ops.placeholder(dtypes.int64, shape=[])
- iterator = (
- dataset.apply(
- batching.padded_batch_and_drop_remainder(
- batch_size, ([], [None], []))).make_initializable_iterator())
-
- next_element = iterator.get_next()
-
- with self.cached_session() as sess:
- for test_batch_size in [1, 3, 7, 10]:
- sess.run(iterator.initializer, feed_dict={batch_size: test_batch_size})
- num_batches = 7 // test_batch_size
- for i in range(num_batches):
- result = sess.run(next_element)
- for component_idx, result_component in enumerate(result):
- for j in range(test_batch_size):
- data_idx = i * test_batch_size + j
- comp = result_component[j]
- unpadded = comp[comp > 0]
- if np.isscalar(comp):
- # The boolean mask indexing above adds a dim back. Rm it.
- unpadded = unpadded[0]
- self.assertAllEqual(els[data_idx][component_idx], unpadded)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(next_element)
-
- def testPaddedBatchAndDropRemainderSparseError(self):
-
- def _map_fn(i):
- return sparse_tensor.SparseTensorValue(
- indices=[[0, 0]], values=(i * [1]), dense_shape=[1, 1]), i
-
- with self.assertRaises(TypeError):
- _ = dataset_ops.Dataset.range(10).map(_map_fn).apply(
- batching.padded_batch_and_drop_remainder(5))
-
- def testBatchAndDropRemainderShapeInference(self):
- components = (array_ops.placeholder(dtypes.int32),
- (array_ops.placeholder(dtypes.int32, shape=[None]),
- array_ops.placeholder(dtypes.int32, shape=[20, 30])))
-
- # Test with a statically known batch size.
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(128)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([128], dataset.output_shapes[1][0].as_list())
- self.assertEqual([128, 30], dataset.output_shapes[1][1].as_list())
-
- # Test with a dynamic batch size: the static shape will be unknown, because
- # `batch_size` is a placeholder.
- batch_size = array_ops.placeholder(dtypes.int64)
- dataset = (
- dataset_ops.Dataset.from_tensor_slices(components).apply(
- batching.batch_and_drop_remainder(batch_size)))
-
- self.assertIs(None, dataset.output_shapes[0].ndims)
- self.assertEqual([None], dataset.output_shapes[1][0].as_list())
- self.assertEqual([None, 30], dataset.output_shapes[1][1].as_list())
-
@parameterized.named_parameters(
("Default", None, None),
("SequentialCalls", 1, None),
@@ -723,197 +596,6 @@ class BatchDatasetTest(test.TestCase, parameterized.TestCase):
self.assertAllEqual([element for _ in range(10)], sess.run(get_next))
-class RestructuredDatasetTest(test.TestCase):
-
- def test_assert_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(expected_shapes, dataset.output_shapes)
-
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
- def test_assert_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(5).map(create_dataset)
- partial_expected_shape = (tensor_shape.TensorShape(None), # Unknown shape
- tensor_shape.TensorShape((None, 4))) # Partial shape
- result = dataset.apply(
- batching.assert_element_shape(partial_expected_shape))
- # Partial shapes are merged with actual shapes:
- actual_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((3, 4)))
- self.assertEqual(actual_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape(self):
-
- def create_dataset(_):
- return (array_ops.ones(2, dtype=dtypes.float32),
- array_ops.zeros((3, 4), dtype=dtypes.int32))
-
- dataset = dataset_ops.Dataset.range(3).map(create_dataset)
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- with self.assertRaises(ValueError):
- dataset.apply(batching.assert_element_shape(wrong_shapes))
-
- def test_assert_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(5).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- expected_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 4)))
- result = dataset.apply(batching.assert_element_shape(expected_shapes))
- self.assertEqual(expected_shapes, result.output_shapes)
-
- iterator = result.make_initializable_iterator()
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- for _ in range(5):
- sess.run(get_next)
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def test_assert_wrong_partial_element_shape_on_unknown_shape_dataset(self):
-
- def create_unknown_shape_dataset(x):
- return script_ops.py_func(
- lambda _: ( # pylint: disable=g-long-lambda
- np.ones(2, dtype=np.float32),
- np.zeros((3, 4), dtype=np.int32)),
- [x],
- [dtypes.float32, dtypes.int32])
-
- dataset = dataset_ops.Dataset.range(3).map(create_unknown_shape_dataset)
- unknown_shapes = (tensor_shape.TensorShape(None),
- tensor_shape.TensorShape(None))
- self.assertEqual(unknown_shapes, dataset.output_shapes)
-
- wrong_shapes = (tensor_shape.TensorShape(2),
- tensor_shape.TensorShape((None, 10)))
- iterator = (
- dataset.apply(batching.assert_element_shape(wrong_shapes))
- .make_initializable_iterator())
- init_op = iterator.initializer
- get_next = iterator.get_next()
- with self.cached_session() as sess:
- sess.run(init_op)
- with self.assertRaises(errors.InvalidArgumentError):
- sess.run(get_next)
-
-
class UnbatchDatasetBenchmark(test.Benchmark):
def benchmarkNativeUnbatch(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
index 48971f2ccc..153a03989b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/bucketing_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/bucketing_test.py
@@ -21,7 +21,8 @@ import random
import numpy as np
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -35,7 +36,7 @@ from tensorflow.python.ops import string_ops
from tensorflow.python.platform import test
-class GroupByReducerTest(test.TestCase):
+class GroupByReducerTest(test_base.DatasetTestBase):
def checkResults(self, dataset, shapes, values):
self.assertEqual(shapes, dataset.output_shapes)
@@ -198,7 +199,7 @@ class GroupByReducerTest(test.TestCase):
self.assertEqual(y, 45)
-class GroupByWindowTest(test.TestCase):
+class GroupByWindowTest(test_base.DatasetTestBase):
def testSimple(self):
components = np.random.randint(100, size=(200,)).astype(np.int64)
@@ -345,7 +346,7 @@ class GroupByWindowTest(test.TestCase):
# NOTE(mrry): These tests are based on the tests in bucket_ops_test.py.
# Currently, they use a constant batch size, though should be made to use a
# different batch size per key.
-class BucketTest(test.TestCase):
+class BucketTest(test_base.DatasetTestBase):
def _dynamicPad(self, bucket, window, window_size):
# TODO(mrry): To match `tf.contrib.training.bucket()`, implement a
@@ -570,7 +571,7 @@ def _get_record_shape(sparse):
return tensor_shape.TensorShape([None])
-class BucketBySequenceLength(test.TestCase):
+class BucketBySequenceLength(test_base.DatasetTestBase):
def testBucket(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
index f8e74e4583..4ee1779710 100644
--- a/tensorflow/contrib/data/python/kernel_tests/csv_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/csv_dataset_op_test.py
@@ -27,9 +27,10 @@ import zlib
import numpy as np
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -43,37 +44,7 @@ from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
-class CsvDatasetOpTest(test.TestCase):
-
- def _get_next(self, dataset):
- # Returns a no argument function whose result is fed to self.evaluate to
- # yield the next element
- it = dataset.make_one_shot_iterator()
- if context.executing_eagerly():
- return it.get_next
- else:
- get_next = it.get_next()
- return lambda: get_next
-
- def _assert_datasets_equal(self, ds1, ds2):
- assert ds1.output_shapes == ds2.output_shapes, ('output_shapes differ: %s, '
- '%s') % (ds1.output_shapes,
- ds2.output_shapes)
- assert ds1.output_types == ds2.output_types
- assert ds1.output_classes == ds2.output_classes
- next1 = self._get_next(ds1)
- next2 = self._get_next(ds2)
- # Run through datasets and check that outputs match, or errors match.
- while True:
- try:
- op1 = self.evaluate(next1())
- except (errors.OutOfRangeError, ValueError) as e:
- # If op1 throws an exception, check that op2 throws same exception.
- with self.assertRaises(type(e)):
- self.evaluate(next2())
- break
- op2 = self.evaluate(next2())
- self.assertAllEqual(op1, op2)
+class CsvDatasetOpTest(test_base.DatasetTestBase):
def _setup_files(self, inputs, linebreak='\n', compression_type=None):
filenames = []
@@ -108,7 +79,7 @@ class CsvDatasetOpTest(test.TestCase):
"""Checks that CsvDataset is equiv to TextLineDataset->map(decode_csv)."""
dataset_actual, dataset_expected = self._make_test_datasets(
inputs, **kwargs)
- self._assert_datasets_equal(dataset_actual, dataset_expected)
+ self.assertDatasetsEqual(dataset_actual, dataset_expected)
def _verify_output_or_err(self,
dataset,
@@ -116,7 +87,7 @@ class CsvDatasetOpTest(test.TestCase):
expected_err_re=None):
if expected_err_re is None:
# Verify that output is expected, without errors
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
expected_output = [[
v.encode('utf-8') if isinstance(v, str) else v for v in op
] for op in expected_output]
@@ -128,7 +99,7 @@ class CsvDatasetOpTest(test.TestCase):
else:
# Verify that OpError is produced as expected
with self.assertRaisesOpError(expected_err_re):
- nxt = self._get_next(dataset)
+ nxt = self.getNext(dataset)
while True:
try:
self.evaluate(nxt())
@@ -354,7 +325,7 @@ class CsvDatasetOpTest(test.TestCase):
inputs = [['1,,3,4', '5,6,,8']]
ds_actual, ds_expected = self._make_test_datasets(
inputs, record_defaults=record_defaults)
- self._assert_datasets_equal(
+ self.assertDatasetsEqual(
ds_actual.repeat(5).prefetch(1),
ds_expected.repeat(5).prefetch(1))
@@ -377,7 +348,7 @@ class CsvDatasetOpTest(test.TestCase):
ds = readers.make_csv_dataset(
file_path, batch_size=1, shuffle=False, num_epochs=1)
- nxt = self._get_next(ds)
+ nxt = self.getNext(ds)
result = list(self.evaluate(nxt()).values())
diff --git a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
index a2ab3de52e..3fc7157bc5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_constructor_op_test.py
@@ -17,7 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -25,7 +26,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class DatasetConstructorTest(test.TestCase):
+class DatasetConstructorTest(test_base.DatasetTestBase):
def testRestructureDataset(self):
components = (array_ops.placeholder(dtypes.int32),
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
index 595cecef4d..7f435b8239 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_serialization_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/dataset_serialization_test_base.py
@@ -22,7 +22,7 @@ import os
import numpy as np
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
index eb110324d1..796a692c56 100644
--- a/tensorflow/contrib/data/python/kernel_tests/directed_interleave_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/directed_interleave_dataset_test.py
@@ -19,14 +19,15 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import random_seed
from tensorflow.python.platform import test
-class DirectedInterleaveDatasetTest(test.TestCase):
+class DirectedInterleaveDatasetTest(test_base.DatasetTestBase):
def testBasic(self):
selector_dataset = dataset_ops.Dataset.range(10).repeat(100)
@@ -83,7 +84,7 @@ class DirectedInterleaveDatasetTest(test.TestCase):
# Use chi-squared test to assert that the observed distribution matches the
# expected distribution. Based on the implementation in
- # "tensorflow/python/kernel_tests/multinomial_op_test.py".
+ # "third_party/tensorflow/python/kernel_tests/multinomial_op_test.py".
for probs in [[.85, .05, .1], rand_probs, [1.]]:
probs = np.asarray(probs)
classes = len(probs)
diff --git a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
index 6d01bf585c..c6ee88c676 100644
--- a/tensorflow/contrib/data/python/kernel_tests/filter_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/filter_dataset_op_test.py
@@ -21,8 +21,8 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
index f3968cdc15..8c07afbac5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/get_single_element_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/get_single_element_test.py
@@ -18,10 +18,9 @@ from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
-import numpy as np
-from tensorflow.contrib.data.python.ops import get_single_element
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +29,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class GetSingleElementTest(test.TestCase, parameterized.TestCase):
+class GetSingleElementTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("Zero", 0, 1),
@@ -68,32 +67,6 @@ class GetSingleElementTest(test.TestCase, parameterized.TestCase):
with self.assertRaisesRegexp(error, error_msg):
sess.run(element, feed_dict={skip_t: skip, take_t: take})
- @parameterized.named_parameters(
- ("SumZero", 0),
- ("SumOne", 1),
- ("SumFive", 5),
- ("SumTen", 10),
- )
- def testReduceDataset(self, stop):
- def init_fn(_):
- return np.int64(0)
-
- def reduce_fn(state, value):
- return state + value
-
- def finalize_fn(state):
- return state
-
- sum_reducer = grouping.Reducer(init_fn, reduce_fn, finalize_fn)
-
- stop_t = array_ops.placeholder(dtypes.int64, shape=[])
- dataset = dataset_ops.Dataset.range(stop_t)
- element = get_single_element.reduce_dataset(dataset, sum_reducer)
-
- with self.cached_session() as sess:
- value = sess.run(element, feed_dict={stop_t: stop})
- self.assertEqual(stop * (stop - 1) / 2, value)
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
index 9c508d686d..c93a8353ce 100644
--- a/tensorflow/contrib/data/python/kernel_tests/indexed_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/indexed_dataset_ops_test.py
@@ -19,29 +19,30 @@ from __future__ import print_function
import unittest
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
-from tensorflow.contrib.data.python.ops import indexed_dataset_ops
+from tensorflow.python.data.experimental.ops import indexed_dataset_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
from tensorflow.python.platform import test
-class IndexedDatasetOpsTest(test.TestCase):
+class IndexedDatasetOpsTest(test_base.DatasetTestBase):
def testLowLevelIndexedDatasetOps(self):
- identity = gen_dataset_ops.identity_indexed_dataset(
+ identity = ged_ops.experimental_identity_indexed_dataset(
ops.convert_to_tensor(16, dtype=dtypes.uint64))
- handle = gen_dataset_ops.materialized_index_dataset_handle(
+ handle = ged_ops.experimental_materialized_index_dataset_handle(
container="",
shared_name="",
output_types=[dtypes.uint64],
output_shapes=[[]])
- materialize = gen_dataset_ops.indexed_dataset_materialize(identity, handle)
+ materialize = ged_ops.experimental_indexed_dataset_materialize(
+ identity, handle)
index = array_ops.placeholder(dtypes.uint64)
- get_op = gen_dataset_ops.indexed_dataset_get(
+ get_op = ged_ops.experimental_indexed_dataset_get(
handle, index, output_types=[dtypes.uint64], output_shapes=[[]])
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
index b9e74dfddb..560902caad 100644
--- a/tensorflow/contrib/data/python/kernel_tests/interleave_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/interleave_dataset_op_test.py
@@ -24,7 +24,8 @@ import time
from six.moves import zip_longest
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -36,7 +37,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.platform import test
-class ParallelInterleaveDatasetTest(test.TestCase):
+class ParallelInterleaveDatasetTest(test_base.DatasetTestBase):
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
index 7e2326bd17..94393d6d4b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/iterator_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/iterator_ops_test.py
@@ -18,7 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn
@@ -33,7 +34,7 @@ from tensorflow.python.training import saver as saver_lib
from tensorflow.python.training import training_util
-class CheckpointInputPipelineHookTest(test.TestCase):
+class CheckpointInputPipelineHookTest(test_base.DatasetTestBase):
@staticmethod
def _model_fn(features, labels, mode, config):
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
index e8519381d6..2f0bd1456b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_dataset_op_test.py
@@ -24,11 +24,12 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import error_ops
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import error_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -41,7 +42,7 @@ from tensorflow.python.util import compat
_NUMPY_RANDOM_SEED = 42
-class MapDatasetTest(test.TestCase):
+class MapDatasetTest(test_base.DatasetTestBase):
def testMapIgnoreError(self):
components = np.array([1., 2., 3., np.nan, 5.]).astype(np.float32)
diff --git a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
index 25aea0393f..612ee332c4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/map_defun_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/map_defun_op_test.py
@@ -19,8 +19,9 @@ from __future__ import print_function
import time
-from tensorflow.contrib.data.python.ops import map_defun
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import map_defun
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -33,7 +34,8 @@ from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapDefunTest(test.TestCase):
+
+class MapDefunTest(test_base.DatasetTestBase):
def testMapDefunSimple(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
index 1ae92bdeff..c92bb8b9bc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/BUILD
@@ -11,10 +11,16 @@ py_test(
size = "medium",
srcs = ["assert_next_dataset_op_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -24,13 +30,19 @@ py_test(
size = "small",
srcs = ["hoist_random_uniform_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -41,12 +53,17 @@ py_test(
size = "small",
srcs = ["latency_all_edges_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -56,9 +73,12 @@ py_test(
size = "small",
srcs = ["map_vectorization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:check_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
@@ -67,6 +87,8 @@ py_test(
"//tensorflow/python:framework_ops",
"//tensorflow/python:math_ops",
"//tensorflow/python:session",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
"@absl_py//absl/testing:parameterized",
@@ -78,13 +100,19 @@ py_test(
size = "medium",
srcs = ["map_and_filter_fusion_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -95,13 +123,19 @@ py_test(
size = "small",
srcs = ["map_parallelization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"@absl_py//absl/testing:parameterized",
],
@@ -113,14 +147,18 @@ py_test(
srcs = ["model_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
"optonly",
],
deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -131,12 +169,18 @@ py_test(
size = "small",
srcs = ["noop_elimination_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:batching",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:batching",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -147,10 +191,16 @@ py_test(
size = "small",
srcs = ["optimize_dataset_op_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
"//tensorflow/python:errors",
+ "//tensorflow/python/data/experimental/ops:optimization",
+ "//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
index d10da80442..45b77b5c20 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/assert_next_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/assert_next_dataset_op_test.py
@@ -17,13 +17,14 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
-class AssertNextDatasetTest(test.TestCase):
+class AssertNextDatasetTest(test_base.DatasetTestBase):
def testAssertNext(self):
dataset = dataset_ops.Dataset.from_tensors(0).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
index 9518c2e1ad..81437c0aec 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/hoist_random_uniform_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/hoist_random_uniform_test.py
@@ -19,7 +19,8 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -31,7 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
+class HoistRandomUniformTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
@@ -63,7 +64,9 @@ class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
optimization.assert_next(
["Zip[0]", "Map"] if will_optimize else ["Map"])).map(function)
- dataset = dataset.apply(optimization.optimize(["hoist_random_uniform"]))
+ options = dataset_ops.Options()
+ options.experimental_hoist_random_uniform = True
+ dataset = dataset.with_options(options)
self._testDataset(dataset)
def testAdditionalInputs(self):
@@ -76,9 +79,10 @@ class HoistRandomUniformTest(test.TestCase, parameterized.TestCase):
[], minval=1, maxval=10, dtype=dtypes.float32, seed=42)
dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(
- ["Zip[0]", "Map"])).map(random_with_capture).apply(
- optimization.optimize(["hoist_random_uniform"]))
+ optimization.assert_next(["Zip[0]", "Map"])).map(random_with_capture)
+ options = dataset_ops.Options()
+ options.experimental_hoist_random_uniform = True
+ dataset = dataset.with_options(options)
self._testDataset(dataset)
def _testDataset(self, dataset):
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
index e4f18222fd..26fec0414e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/latency_all_edges_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/latency_all_edges_test.py
@@ -17,9 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import optimization
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
@@ -28,14 +28,15 @@ from tensorflow.python.platform import test
class OptimizeStatsDatasetTest(stats_dataset_test_base.StatsDatasetTestBase):
def testLatencyStatsOptimization(self):
-
stats_aggregator = stats_ops.StatsAggregator()
dataset = dataset_ops.Dataset.from_tensors(1).apply(
optimization.assert_next(
["LatencyStats", "Map", "LatencyStats", "Prefetch",
"LatencyStats"])).map(lambda x: x * x).prefetch(1).apply(
- stats_ops.set_stats_aggregator(stats_aggregator)).apply(
- optimization.optimize(["latency_all_edges"]))
+ stats_ops.set_stats_aggregator(stats_aggregator))
+ options = dataset_ops.Options()
+ options.experimental_latency_all_edges = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_initializable_iterator()
get_next = iterator.get_next()
summary_t = stats_aggregator.get_summary()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
index e75edf6086..7f8a4e6406 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_and_filter_fusion_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_and_filter_fusion_test.py
@@ -19,7 +19,8 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -28,7 +29,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
+class MapAndFilterFusionTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
@@ -71,7 +72,10 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for function in functions:
dataset = dataset.map(function)
- dataset = dataset.prefetch(0).apply(optimization.optimize(["map_fusion"]))
+ dataset = dataset.prefetch(0)
+ options = dataset_ops.Options()
+ options.experimental_map_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -123,9 +127,10 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
def testMapFilterFusion(self, function, predicate):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
- ["Map",
- "FilterByLastComponent"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
+ ["Map", "FilterByLastComponent"])).map(function).filter(predicate)
+ options = dataset_ops.Options()
+ options.experimental_map_and_filter_fusion = True
+ dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
def _testMapAndFilter(self, dataset, function, predicate):
@@ -155,10 +160,11 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
# We are currently not supporting functions with additional inputs.
dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Filter"])).map(function).filter(predicate).apply(
- optimization.optimize(["map_and_filter_fusion"]))
-
+ optimization.assert_next(["Map",
+ "Filter"])).map(function).filter(predicate)
+ options = dataset_ops.Options()
+ options.experimental_map_and_filter_fusion = True
+ dataset = dataset.with_options(options)
self._testMapAndFilter(dataset, function, predicate)
@staticmethod
@@ -196,8 +202,10 @@ class MapAndFilterFusionTest(test.TestCase, parameterized.TestCase):
for predicate in predicates:
dataset = dataset.filter(predicate)
- dataset = dataset.prefetch(0).apply(
- optimization.optimize(["filter_fusion"]))
+ dataset = dataset.prefetch(0)
+ options = dataset_ops.Options()
+ options.experimental_filter_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
index dd547db086..ce9c9bc47b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_parallelization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_parallelization_test.py
@@ -19,7 +19,8 @@ from __future__ import print_function
from absl.testing import parameterized
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -30,7 +31,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class MapParallelizationTest(test.TestCase, parameterized.TestCase):
+class MapParallelizationTest(test_base.DatasetTestBase, parameterized.TestCase):
@staticmethod
def map_functions():
@@ -61,8 +62,10 @@ class MapParallelizationTest(test.TestCase, parameterized.TestCase):
def testMapParallelization(self, function, should_optimize):
next_nodes = ["ParallelMap"] if should_optimize else ["Map"]
dataset = dataset_ops.Dataset.range(5).apply(
- optimization.assert_next(next_nodes)).map(function).apply(
- optimization.optimize(["map_parallelization"]))
+ optimization.assert_next(next_nodes)).map(function)
+ options = dataset_ops.Options()
+ options.experimental_map_parallelization = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
index 5b493f44c9..32ebc49c40 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/map_vectorization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/map_vectorization_test.py
@@ -22,9 +22,9 @@ import time
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import test_utils
-from tensorflow.contrib.data.python.ops import optimization
from tensorflow.python.client import session
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -36,7 +36,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
+class MapVectorizationTest(test_base.DatasetTestBase, parameterized.TestCase):
def _get_test_datasets(self,
base_dataset,
@@ -69,10 +69,11 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
map_fn, num_parallel_calls=num_parallel_calls).batch(batch_size)
unoptimized = _make_dataset([map_node_name, "Batch"])
- optimized = _make_dataset(["Batch", map_node_name] if expect_optimized else
- [map_node_name, "Batch"]).apply(
- optimization.optimize(["map_vectorization"]))
-
+ optimized = _make_dataset(["Batch", map_node_name]
+ if expect_optimized else [map_node_name, "Batch"])
+ options = dataset_ops.Options()
+ options.experimental_map_vectorization = True
+ optimized = optimized.with_options(options)
return unoptimized, optimized
@parameterized.named_parameters(
@@ -85,7 +86,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(base_dataset, map_fn,
num_parallel_calls)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationBadMapFn(self):
# Test map functions that give an error
@@ -112,7 +113,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
# TODO(rachelim): when this optimization works, turn on expect_optimized
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(optimized, unoptimized)
+ self.assertDatasetsEqual(optimized, unoptimized)
def testOptimizationIgnoreStateful(self):
@@ -124,7 +125,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
[3, 4]]).repeat(5)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -138,7 +139,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(3, drop_remainder=False)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_equal(unoptimized, optimized)
+ self.assertDatasetsEqual(unoptimized, optimized)
def testOptimizationIgnoreRaggedMap(self):
# Don't optimize when the output of the map fn shapes are unknown.
@@ -148,7 +149,7 @@ class MapVectorizationTest(test_utils.DatasetTestBase, parameterized.TestCase):
base_dataset = dataset_ops.Dataset.range(20).batch(1, drop_remainder=True)
unoptimized, optimized = self._get_test_datasets(
base_dataset, map_fn, expect_optimized=False)
- self._assert_datasets_raise_same_error(
+ self.assertDatasetsRaiseSameError(
unoptimized, optimized, errors.InvalidArgumentError,
[("OneShotIterator", "OneShotIterator_1", 1),
("IteratorGetNext", "IteratorGetNext_1", 1)])
@@ -179,7 +180,10 @@ class MapVectorizationBenchmark(test.Benchmark):
unoptimized = input_dataset.map(map_fn).batch(batch_size)
unoptimized_op = unoptimized.make_one_shot_iterator().get_next()
- optimized = unoptimized.apply(optimization.optimize(["map_vectorization"]))
+ optimized = input_dataset.map(map_fn).batch(batch_size)
+ options = dataset_ops.Options()
+ options.experimental_map_vectorization = True
+ optimized = optimized.with_options(options)
optimized_op = optimized.make_one_shot_iterator().get_next()
unoptimized_time = self._run(
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
index 3b62a7e468..82516356df 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/model_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/model_dataset_op_test.py
@@ -21,14 +21,15 @@ import time
import numpy as np
-from tensorflow.contrib.data.python.ops import batching
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class ModelDatasetTest(test.TestCase):
+class ModelDatasetTest(test_base.DatasetTestBase):
def testModelMap(self):
k = 1024 * 1024
@@ -36,7 +37,9 @@ class ModelDatasetTest(test.TestCase):
np.random.rand(4 * k,
1))).repeat()
dataset = dataset.map(math_ops.matmul)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
@@ -60,7 +63,9 @@ class ModelDatasetTest(test.TestCase):
1))).repeat()
dataset = dataset.map(
math_ops.matmul, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
@@ -88,7 +93,9 @@ class ModelDatasetTest(test.TestCase):
math_ops.matmul,
num_parallel_calls=optimization.AUTOTUNE,
batch_size=batch_size))
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
@@ -115,7 +122,9 @@ class ModelDatasetTest(test.TestCase):
lambda _: dataset,
cycle_length=10,
num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
@@ -160,7 +169,9 @@ class ModelDatasetTest(test.TestCase):
lambda _: dataset, cycle_length=2)
dataset = dataset.map(f3, num_parallel_calls=optimization.AUTOTUNE)
- iterator = dataset.apply(optimization.model()).make_one_shot_iterator()
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ iterator = dataset.with_options(options).make_one_shot_iterator()
get_next = iterator.get_next()
deltas = []
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
index 507feda3ad..fb0640fe9f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/noop_elimination_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/noop_elimination_test.py
@@ -17,7 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -26,7 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class NoopEliminationTest(test.TestCase):
+class NoopEliminationTest(test_base.DatasetTestBase):
def testNoopElimination(self):
a = constant_op.constant(1, dtype=dtypes.int64)
@@ -39,7 +40,9 @@ class NoopEliminationTest(test.TestCase):
["FiniteRepeat", "FiniteSkip", "Prefetch", "Prefetch"]))
dataset = dataset.repeat(some_tensor).skip(5).prefetch(0).take(-1).skip(
0).repeat(1).prefetch(0)
- dataset = dataset.apply(optimization.optimize(["noop_elimination"]))
+ options = dataset_ops.Options()
+ options.experimental_noop_elimination = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
index a3fb824ce9..760cd8cc4e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/optimization/optimize_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/optimization/optimize_dataset_op_test.py
@@ -19,7 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -28,27 +29,14 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
-class OptimizeDatasetTest(test.TestCase):
+class OptimizeDatasetTest(test_base.DatasetTestBase):
def testOptimizationDefault(self):
dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize())
- iterator = dataset.make_one_shot_iterator()
- get_next = iterator.get_next()
-
- with self.cached_session() as sess:
- self.assertAllEqual([x * x for x in range(10)], sess.run(get_next))
- with self.assertRaises(errors.OutOfRangeError):
- sess.run(get_next)
-
- def testOptimizationEmpty(self):
- dataset = dataset_ops.Dataset.range(10).apply(
- optimization.assert_next(
- ["Map", "Batch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize([]))
- iterator = dataset.make_one_shot_iterator()
+ optimization.assert_next(["Map",
+ "Batch"])).map(lambda x: x * x).batch(10)
+ iterator = dataset.with_options(
+ dataset_ops.Options()).make_one_shot_iterator()
get_next = iterator.get_next()
with self.cached_session() as sess:
@@ -59,8 +47,10 @@ class OptimizeDatasetTest(test.TestCase):
def testOptimizationFusion(self):
dataset = dataset_ops.Dataset.range(10).apply(
optimization.assert_next(
- ["MapAndBatch"])).map(lambda x: x * x).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
+ ["MapAndBatch"])).map(lambda x: x * x).batch(10)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -71,8 +61,10 @@ class OptimizeDatasetTest(test.TestCase):
def testOptimizationStatefulFunction(self):
dataset = dataset_ops.Dataset.range(10).map(
- lambda _: random_ops.random_uniform([])).batch(10).apply(
- optimization.optimize(["map_and_batch_fusion"]))
+ lambda _: random_ops.random_uniform([])).batch(10)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_one_shot_iterator()
get_next = iterator.get_next()
@@ -81,8 +73,10 @@ class OptimizeDatasetTest(test.TestCase):
def testOptimizationLargeInputFromTensor(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None))
- dataset = dataset_ops.Dataset.from_tensors(input_t).apply(
- optimization.optimize())
+ dataset = dataset_ops.Dataset.from_tensors(input_t)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -93,8 +87,10 @@ class OptimizeDatasetTest(test.TestCase):
def testOptimizationLargeInputFromTensorSlices(self):
input_t = array_ops.placeholder(dtypes.int32, (None, None, None, None))
- dataset = dataset_ops.Dataset.from_tensor_slices(input_t).apply(
- optimization.optimize())
+ dataset = dataset_ops.Dataset.from_tensor_slices(input_t)
+ options = dataset_ops.Options()
+ options.experimental_map_and_batch_fusion = True
+ dataset = dataset.with_options(options)
iterator = dataset.make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
diff --git a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
index c4623bca73..13f924b656 100644
--- a/tensorflow/contrib/data/python/kernel_tests/parsing_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/parsing_ops_test.py
@@ -22,9 +22,10 @@ import copy
import numpy as np
-from tensorflow.contrib.data.python.ops import parsing_ops as contrib_parsing_ops
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import parsing_ops as contrib_parsing_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.framework import dtypes
@@ -72,7 +73,7 @@ def _compare_output_to_expected(tester, dict_tensors, expected_tensors,
i += 1
-class ParseExampleTest(test.TestCase):
+class ParseExampleTest(test_base.DatasetTestBase):
def _test(self,
input_tensor,
@@ -845,6 +846,5 @@ class ParseExampleTest(test.TestCase):
"allow_missing to be True."))
-
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
index 33a64ea767..7d7b842c17 100644
--- a/tensorflow/contrib/data/python/kernel_tests/prefetching_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/prefetching_ops_test.py
@@ -19,9 +19,10 @@ from __future__ import print_function
import threading
-from tensorflow.contrib.data.python.ops import prefetching_ops
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.compat import compat
+from tensorflow.python.data.experimental.ops import prefetching_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -35,7 +36,7 @@ from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.platform import test
-class PrefetchingKernelsOpsTest(test.TestCase):
+class PrefetchingKernelsOpsTest(test_base.DatasetTestBase):
def setUp(self):
self._event = threading.Event()
@@ -244,7 +245,7 @@ class PrefetchingKernelsOpsTest(test.TestCase):
sess.run(destroy_op)
-class PrefetchToDeviceTest(test.TestCase):
+class PrefetchToDeviceTest(test_base.DatasetTestBase):
def testPrefetchToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
@@ -445,7 +446,7 @@ class PrefetchToDeviceTest(test.TestCase):
sess.run(next_element)
-class CopyToDeviceTest(test.TestCase):
+class CopyToDeviceTest(test_base.DatasetTestBase):
def testCopyToDevice(self):
host_dataset = dataset_ops.Dataset.range(10)
diff --git a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
index db8fe6aa1b..22412c3965 100644
--- a/tensorflow/contrib/data/python/kernel_tests/range_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/range_dataset_op_test.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import counter
-from tensorflow.contrib.data.python.ops import enumerate_ops
+from tensorflow.python.data.experimental.ops import counter
+from tensorflow.python.data.experimental.ops import enumerate_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -27,7 +28,7 @@ from tensorflow.python.framework import tensor_shape
from tensorflow.python.platform import test
-class RangeDatasetTest(test.TestCase):
+class RangeDatasetTest(test_base.DatasetTestBase):
def testEnumerateDataset(self):
components = (["a", "b"], [1, 2], [37.0, 38])
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
index ed75b27a44..a02f4bd14f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test.py
@@ -23,8 +23,9 @@ import zlib
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.data.util import nest
from tensorflow.python.framework import constant_op
@@ -242,7 +243,7 @@ class ReadBatchFeaturesTest(
self.assertEqual(32, shape[0])
-class MakeCsvDatasetTest(test.TestCase):
+class MakeCsvDatasetTest(test_base.DatasetTestBase):
def _make_csv_dataset(self, filenames, batch_size, num_epochs=1, **kwargs):
return readers.make_csv_dataset(
diff --git a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
index 08b9f03816..b6ab80d132 100644
--- a/tensorflow/contrib/data/python/kernel_tests/reader_dataset_ops_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/reader_dataset_ops_test_base.py
@@ -22,9 +22,10 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.ops import readers
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.framework import constant_op
@@ -32,11 +33,10 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import parsing_ops
-from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class FixedLengthRecordDatasetTestBase(test.TestCase):
+class FixedLengthRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing FixedLengthRecordDataset."""
def setUp(self):
@@ -63,7 +63,7 @@ class FixedLengthRecordDatasetTestBase(test.TestCase):
return filenames
-class ReadBatchFeaturesTestBase(test.TestCase):
+class ReadBatchFeaturesTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing `make_batched_feature_dataset`."""
def setUp(self):
@@ -273,7 +273,7 @@ class ReadBatchFeaturesTestBase(test.TestCase):
self.assertAllEqual(expected_batch[i], actual_batch[i])
-class TextLineDatasetTestBase(test.TestCase):
+class TextLineDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TextLineDataset."""
def _lineText(self, f, l):
@@ -313,7 +313,7 @@ class TextLineDatasetTestBase(test.TestCase):
return filenames
-class TFRecordDatasetTestBase(test.TestCase):
+class TFRecordDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing TFRecordDataset."""
def setUp(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/resample_test.py b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
index 16b1441baa..775648c943 100644
--- a/tensorflow/contrib/data/python/kernel_tests/resample_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/resample_test.py
@@ -23,7 +23,8 @@ from absl.testing import parameterized
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
-from tensorflow.contrib.data.python.ops import resampling
+from tensorflow.python.data.experimental.ops import resampling
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -57,7 +58,7 @@ def _time_resampling(
return end_time - start_time
-class ResampleTest(test.TestCase, parameterized.TestCase):
+class ResampleTest(test_base.DatasetTestBase, parameterized.TestCase):
@parameterized.named_parameters(
("InitialDistributionKnown", True),
diff --git a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
index dde678bd54..78ec80de23 100644
--- a/tensorflow/contrib/data/python/kernel_tests/scan_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/scan_dataset_op_test.py
@@ -21,7 +21,8 @@ import itertools
import numpy as np
-from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
@@ -33,7 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class ScanDatasetTest(test.TestCase):
+class ScanDatasetTest(test_base.DatasetTestBase):
def _counting_dataset(self, start, scan_fn):
return dataset_ops.Dataset.from_tensors(0).repeat().apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
index aa89674c6e..58a335ae4f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/BUILD
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/BUILD
@@ -13,7 +13,6 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
"//tensorflow/python:errors",
@@ -24,6 +23,7 @@ py_library(
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:iterator_ops",
"//third_party/py/numpy",
],
@@ -34,13 +34,17 @@ py_test(
size = "medium",
srcs = ["batch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -51,6 +55,11 @@ py_test(
size = "small",
srcs = ["cache_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -65,6 +74,11 @@ py_test(
size = "small",
srcs = ["concatenate_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -78,12 +92,16 @@ py_test(
size = "small",
srcs = ["csv_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:readers",
],
)
@@ -92,6 +110,11 @@ py_test(
size = "medium",
srcs = ["dataset_constructor_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -106,7 +129,11 @@ py_test(
size = "medium",
srcs = ["filter_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -123,11 +150,15 @@ py_test(
srcs = ["fixed_length_record_dataset_serialization_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -136,7 +167,11 @@ py_test(
name = "flat_map_dataset_serialization_test",
size = "medium",
srcs = ["flat_map_dataset_serialization_test.py"],
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -158,10 +193,15 @@ py_test(
size = "medium",
srcs = ["group_by_reducer_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -172,10 +212,15 @@ py_test(
size = "medium",
srcs = ["group_by_window_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:grouping",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:grouping",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -186,12 +231,16 @@ py_test(
size = "small",
srcs = ["ignore_errors_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:error_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:error_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -202,7 +251,11 @@ py_test(
size = "medium",
srcs = ["interleave_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -219,12 +272,16 @@ py_test(
size = "medium",
srcs = ["map_and_batch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:client_testlib",
"//tensorflow/python:math_ops",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -234,7 +291,11 @@ py_test(
size = "medium",
srcs = ["map_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -256,10 +317,15 @@ py_test(
size = "small",
srcs = ["optimize_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:optimization",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:optimization",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -269,7 +335,11 @@ py_test(
size = "medium",
srcs = ["padded_batch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:array_ops",
@@ -285,13 +355,17 @@ py_test(
size = "medium",
srcs = ["parallel_interleave_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:sparse_ops",
"//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -302,7 +376,11 @@ py_test(
size = "medium",
srcs = ["parallel_map_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -323,11 +401,15 @@ py_test(
size = "medium",
srcs = ["parse_example_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
],
)
@@ -336,7 +418,11 @@ py_test(
size = "small",
srcs = ["prefetch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -349,6 +435,11 @@ py_test(
size = "small",
srcs = ["range_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -368,10 +459,15 @@ py_test(
size = "medium",
srcs = ["sample_from_datasets_serialization_test.py"],
srcs_version = "PY2AND3",
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:interleave_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:interleave_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -381,11 +477,15 @@ py_test(
size = "small",
srcs = ["scan_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:scan_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:scan_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -395,7 +495,11 @@ py_test(
size = "medium",
srcs = ["sequence_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
@@ -409,12 +513,16 @@ py_test(
size = "small",
srcs = ["serialization_integration_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -424,11 +532,15 @@ py_test(
size = "medium",
srcs = ["shuffle_and_repeat_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:shuffle_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -438,13 +550,17 @@ py_test(
size = "medium",
srcs = ["shuffle_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
"//tensorflow/python:training",
+ "//tensorflow/python/data/experimental/ops:iterator_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -454,14 +570,18 @@ py_test(
size = "small",
srcs = ["sql_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:sql_dataset_op_test_base",
- "//tensorflow/contrib/data/python/ops:readers",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:dtypes",
+ "//tensorflow/python/data/experimental/kernel_tests:sql_dataset_op_test_base",
+ "//tensorflow/python/data/experimental/ops:readers",
],
)
@@ -470,13 +590,17 @@ py_test(
size = "medium",
srcs = ["stats_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:stats_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/experimental/ops:stats_ops",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -487,11 +611,15 @@ py_test(
srcs = ["textline_dataset_serialization_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -502,11 +630,15 @@ py_test(
srcs = ["tf_record_dataset_serialization_test.py"],
shard_count = 4,
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/kernel_tests:reader_dataset_ops_test_base",
"//tensorflow/python/data/ops:readers",
],
)
@@ -516,11 +648,15 @@ py_test(
size = "medium",
srcs = ["unbatch_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:batching",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:batching",
"//tensorflow/python/data/ops:dataset_ops",
"//third_party/py/numpy",
],
@@ -531,11 +667,15 @@ py_test(
size = "small",
srcs = ["unique_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/ops:unique",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python/data/experimental/ops:unique",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -545,7 +685,11 @@ py_test(
size = "small",
srcs = ["zip_dataset_serialization_test.py"],
srcs_version = "PY2AND3",
- tags = ["no_pip"],
+ tags = [
+ "no_oss",
+ "no_pip",
+ "no_windows",
+ ],
deps = [
":dataset_serialization_test_base",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
index af87d8b608..d72a6df14c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/batch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
index 1b6059ccbc..2bcf77f5d8 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/cache_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/cache_dataset_serialization_test.py
@@ -21,7 +21,7 @@ import os
from absl.testing import parameterized
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
index 96f13d75a3..c075dff8cb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/concatenate_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/concatenate_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
index 247f2046ea..d4983492e7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/csv_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/csv_dataset_serialization_test.py
@@ -20,8 +20,8 @@ from __future__ import print_function
import gzip
import os
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
index 2139b5c33d..41a095fb1a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/dataset_constructor_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_constructor_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
new file mode 100644
index 0000000000..7f435b8239
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/dataset_serialization_test_base.py
@@ -0,0 +1,692 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Base class for testing serializable datasets."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import lookup_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import gfile
+from tensorflow.python.platform import test
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.util import nest
+
+
+def remove_variants(get_next_op):
+ # TODO(b/72408568): Remove this once session.run can get
+ # variant tensors.
+ """Remove variants from a nest structure, so sess.run will execute."""
+
+ def _remove_variant(x):
+ if isinstance(x, ops.Tensor) and x.dtype == dtypes.variant:
+ return ()
+ else:
+ return x
+
+ return nest.map_structure(_remove_variant, get_next_op)
+
+
+class DatasetSerializationTestBase(test.TestCase):
+ """Base class for testing serializable datasets."""
+
+ def tearDown(self):
+ self._delete_ckpt()
+
+ # TODO(b/72657739): Remove sparse_tensor argument, which is to test the
+ # (deprecated) saveable `SparseTensorSliceDataset`, once the API
+ # `from_sparse_tensor_slices()`and related tests are deleted.
+ def run_core_tests(self, ds_fn1, ds_fn2, num_outputs, sparse_tensors=False):
+ """Runs the core tests.
+
+ Args:
+ ds_fn1: 0-argument function that returns a Dataset.
+ ds_fn2: 0-argument function that returns a Dataset different from
+ ds_fn1. If None, verify_restore_in_modified_graph test is not run.
+ num_outputs: Total number of outputs expected from this Dataset.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_unused_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_fully_used_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_exhausted_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_init_before_restore(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_multiple_breaks(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_reset_restored_iterator(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ self.verify_restore_in_empty_graph(
+ ds_fn1, num_outputs, sparse_tensors=sparse_tensors)
+ if ds_fn2:
+ self.verify_restore_in_modified_graph(
+ ds_fn1, ds_fn2, num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_unused_iterator(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that saving and restoring an unused iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [0],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_fully_used_iterator(self, ds_fn, num_outputs,
+ sparse_tensors=False):
+ """Verifies that saving and restoring a fully used iterator works.
+
+ Note that this only checks saving and restoring an iterator from which
+ `num_outputs` items have been produced but does not check for an
+ exhausted iterator, i.e., one from which an OutOfRange error has been
+ returned.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn, [num_outputs], num_outputs, sparse_tensors=sparse_tensors)
+
+ def verify_exhausted_iterator(self, ds_fn, num_outputs, sparse_tensors=False):
+ """Verifies that saving and restoring an exhausted iterator works.
+
+ An exhausted iterator is one which has returned an OutOfRange error.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ actual = self.gen_outputs(
+ ds_fn, [],
+ 0,
+ ckpt_saved=True,
+ verify_exhausted=True,
+ sparse_tensors=sparse_tensors)
+ self.assertEqual(len(actual), 0)
+
+ def verify_init_before_restore(self,
+ ds_fn,
+ num_outputs,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that restoring into an already initialized iterator works.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs),
+ num_outputs,
+ init_before_restore=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_multiple_breaks(self,
+ ds_fn,
+ num_outputs,
+ num_breaks=10,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to save/restore at multiple break points.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ num_breaks: The number of break points. These are uniformly spread in
+ [0, num_outputs] both inclusive.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ self.verify_run_with_breaks(
+ ds_fn,
+ self.gen_break_points(num_outputs, num_breaks),
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ def verify_reset_restored_iterator(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to re-initialize a restored iterator.
+
+ This is useful when restoring a training checkpoint during validation.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Collect ground truth containing all outputs.
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Skip some items and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Restore from checkpoint and then run init_op.
+ with ops.Graph().as_default() as g:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ self._initialize(init_op, sess)
+ for _ in range(num_outputs):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ self.match(expected, actual)
+
+ def verify_restore_in_modified_graph(self,
+ ds_fn1,
+ ds_fn2,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in a modified graph.
+
+ Builds an input pipeline using ds_fn1, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new graph using ds_fn2, restores
+ the checkpoint from ds_fn1 and verifies that the restore is successful.
+
+ Args:
+ ds_fn1: See `run_core_tests`.
+ ds_fn2: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn1
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn1, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn1 and save checkpoint.
+ self.gen_outputs(
+ ds_fn1, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build graph for ds_fn2 but load checkpoint for ds_fn1.
+ with ops.Graph().as_default() as g:
+ _, get_next_op, saver = self._build_graph(
+ ds_fn2, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_restore_in_empty_graph(self,
+ ds_fn,
+ num_outputs,
+ break_point=None,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Attempts to restore an iterator in an empty graph.
+
+ Builds an input pipeline using ds_fn, runs it for `break_point` steps
+ and saves a checkpoint. Then builds a new empty graph, restores
+ the checkpoint from ds_fn and verifies that the restore is successful.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ break_point = num_outputs // 2 if not break_point else break_point
+
+ # Skip `break_point` items and store the remaining produced from ds_fn
+ # in `expected`.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs - break_point,
+ ckpt_saved=True,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ # Generate `break_point` items from ds_fn and save checkpoint.
+ self.gen_outputs(
+ ds_fn, [],
+ break_point,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=False)
+
+ actual = []
+ # Build an empty graph but load checkpoint for ds_fn.
+ with ops.Graph().as_default() as g:
+ get_next_op, saver = self._build_empty_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._restore(saver, sess)
+ for _ in range(num_outputs - break_point):
+ actual.append(sess.run(get_next_op))
+ if verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+
+ self.match(expected, actual)
+
+ def verify_error_on_save(self,
+ ds_fn,
+ num_outputs,
+ error,
+ break_point=None,
+ sparse_tensors=False):
+ """Attempts to save a non-saveable iterator.
+
+ Args:
+ ds_fn: See `run_core_tests`.
+ num_outputs: See `run_core_tests`.
+ error: Declared error when trying to save iterator.
+ break_point: Break point. Optional. Defaults to num_outputs/2.
+ sparse_tensors: See `run_core_tests`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+
+ break_point = num_outputs // 2 if not break_point else break_point
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ self._initialize(init_op, sess)
+ for _ in range(break_point):
+ sess.run(get_next_op)
+ with self.assertRaises(error):
+ self._save(sess, saver)
+
+ def verify_run_with_breaks(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True):
+ """Verifies that ds_fn() produces the same outputs with and without breaks.
+
+ 1. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ *without* stopping at break points.
+ 2. Builds a Dataset using `ds_fn` and produces `num_outputs` items from it
+ with stopping at break points.
+
+ Deep matches outputs from 1 and 2.
+
+ Args:
+ ds_fn: See `gen_outputs`.
+ break_points: See `gen_outputs`.
+ num_outputs: See `gen_outputs`.
+ init_before_restore: See `gen_outputs`.
+ sparse_tensors: See `run_core_tests`.
+ verify_exhausted: See `gen_outputs`.
+
+ Raises:
+ AssertionError if any test fails.
+ """
+ expected = self.gen_outputs(
+ ds_fn, [],
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ actual = self.gen_outputs(
+ ds_fn,
+ break_points,
+ num_outputs,
+ init_before_restore=init_before_restore,
+ sparse_tensors=sparse_tensors,
+ verify_exhausted=verify_exhausted)
+
+ self.match(expected, actual)
+
+ def gen_outputs(self,
+ ds_fn,
+ break_points,
+ num_outputs,
+ ckpt_saved=False,
+ init_before_restore=False,
+ sparse_tensors=False,
+ verify_exhausted=True,
+ save_checkpoint_at_end=True):
+ """Generates elements from input dataset while stopping at break points.
+
+ Produces `num_outputs` outputs and saves the state of the iterator in the
+ Saver checkpoint.
+
+ Args:
+ ds_fn: 0-argument function that returns the dataset.
+ break_points: A list of integers. For each `break_point` in
+ `break_points`, we produce outputs till `break_point` number of items
+ have been produced and then checkpoint the state. The current graph
+ and session are destroyed and a new graph and session are used to
+ produce outputs till next checkpoint or till `num_outputs` elements
+ have been produced. `break_point` must be <= `num_outputs`.
+ num_outputs: The total number of outputs to produce from the iterator.
+ ckpt_saved: Whether a checkpoint already exists. If False, we build the
+ graph from ds_fn.
+ init_before_restore: Whether init should be called before saver.restore.
+ This is just so that we can verify that restoring an already initialized
+ iterator works.
+ sparse_tensors: Whether dataset is built from SparseTensor(s).
+ verify_exhausted: Whether to verify that the iterator has been exhausted
+ after producing `num_outputs` elements.
+ save_checkpoint_at_end: Whether to save a checkpoint after producing all
+ outputs. If False, checkpoints are saved each break point but not at the
+ end. Note that checkpoints overwrite each other so there is always only
+ a single checkpoint available. Defaults to True.
+
+ Returns:
+ A list of `num_outputs` items.
+ """
+ outputs = []
+
+ def get_ops():
+ if ckpt_saved:
+ saver = self._import_meta_graph()
+ init_op, get_next_op = self._get_iterator_ops_from_collection(
+ ds_fn, sparse_tensors=sparse_tensors)
+ else:
+ init_op, get_next_op, saver = self._build_graph(
+ ds_fn, sparse_tensors=sparse_tensors)
+ return init_op, get_next_op, saver
+
+ for i in range(len(break_points) + 1):
+ with ops.Graph().as_default() as g:
+ init_op, get_next_op, saver = get_ops()
+ get_next_op = remove_variants(get_next_op)
+ with self.session(graph=g) as sess:
+ if ckpt_saved:
+ if init_before_restore:
+ self._initialize(init_op, sess)
+ self._restore(saver, sess)
+ else:
+ self._initialize(init_op, sess)
+ start = break_points[i - 1] if i > 0 else 0
+ end = break_points[i] if i < len(break_points) else num_outputs
+ num_iters = end - start
+ for _ in range(num_iters):
+ outputs.append(sess.run(get_next_op))
+ if i == len(break_points) and verify_exhausted:
+ with self.assertRaises(errors.OutOfRangeError):
+ sess.run(get_next_op)
+ if save_checkpoint_at_end or i < len(break_points):
+ self._save(sess, saver)
+ ckpt_saved = True
+
+ return outputs
+
+ def match(self, expected, actual):
+ """Matches nested structures.
+
+ Recursively matches shape and values of `expected` and `actual`.
+ Handles scalars, numpy arrays and other python sequence containers
+ e.g. list, dict.
+
+ Args:
+ expected: Nested structure 1.
+ actual: Nested structure 2.
+
+ Raises:
+ AssertionError if matching fails.
+ """
+ if isinstance(expected, np.ndarray):
+ expected = expected.tolist()
+ if isinstance(actual, np.ndarray):
+ actual = actual.tolist()
+ self.assertEqual(type(expected), type(actual))
+
+ if nest.is_sequence(expected):
+ self.assertEqual(len(expected), len(actual))
+ if isinstance(expected, dict):
+ for key1, key2 in zip(sorted(expected), sorted(actual)):
+ self.assertEqual(key1, key2)
+ self.match(expected[key1], actual[key2])
+ else:
+ for item1, item2 in zip(expected, actual):
+ self.match(item1, item2)
+ else:
+ self.assertEqual(expected, actual)
+
+ def does_not_match(self, expected, actual):
+ with self.assertRaises(AssertionError):
+ self.match(expected, actual)
+
+ def gen_break_points(self, num_outputs, num_samples=10):
+ """Generates `num_samples` breaks points in [0, num_outputs]."""
+ return np.linspace(0, num_outputs, num_samples, dtype=int)
+
+ def _build_graph(self, ds_fn, sparse_tensors=False):
+ iterator = ds_fn().make_initializable_iterator()
+
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ init_op = iterator.initializer
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ self._add_iterator_ops_to_collection(init_op, get_next, ds_fn,
+ sparse_tensors)
+ saver = saver_lib.Saver(allow_empty=True)
+ return init_op, get_next, saver
+
+ def _build_empty_graph(self, ds_fn, sparse_tensors=False):
+ iterator = iterator_ops.Iterator.from_structure(
+ self._get_output_types(ds_fn),
+ output_shapes=self._get_output_shapes(ds_fn),
+ output_classes=self._get_output_classes(ds_fn))
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ if sparse_tensors:
+ get_next = sparse_tensor.SparseTensor(*iterator.get_next())
+ else:
+ get_next = iterator.get_next()
+ saver = saver_lib.Saver(allow_empty=True)
+ return get_next, saver
+
+ def _add_iterator_ops_to_collection(self,
+ init_op,
+ get_next,
+ ds_fn,
+ sparse_tensors=False):
+ ops.add_to_collection("iterator_ops", init_op)
+ # `get_next` may be a tuple e.g. in TensorSliceDataset. Since Collections
+ # do not support tuples we flatten the tensors and restore the shape in
+ # `_get_iterator_ops_from_collection`.
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ ops.add_to_collection("iterator_ops", get_next.indices)
+ ops.add_to_collection("iterator_ops", get_next.values)
+ ops.add_to_collection("iterator_ops", get_next.dense_shape)
+ return
+
+ get_next_list = nest.flatten(get_next)
+ for i, output_class in enumerate(
+ nest.flatten(self._get_output_classes(ds_fn))):
+ if output_class is sparse_tensor.SparseTensor:
+ ops.add_to_collection("iterator_ops", get_next_list[i].indices)
+ ops.add_to_collection("iterator_ops", get_next_list[i].values)
+ ops.add_to_collection("iterator_ops", get_next_list[i].dense_shape)
+ else:
+ ops.add_to_collection("iterator_ops", get_next_list[i])
+
+ def _get_iterator_ops_from_collection(self, ds_fn, sparse_tensors=False):
+ all_ops = ops.get_collection("iterator_ops")
+ if sparse_tensors: # specific for deprecated `from_sparse_tensor_slices`.
+ init_op, indices, values, dense_shape = all_ops
+ return init_op, sparse_tensor.SparseTensor(indices, values, dense_shape)
+ get_next_list = []
+ i = 1
+ for output_class in nest.flatten(self._get_output_classes(ds_fn)):
+ if output_class is sparse_tensor.SparseTensor:
+ indices, values, dense_shape = all_ops[i:i + 3]
+ i += 3
+ get_next_list.append(
+ sparse_tensor.SparseTensor(indices, values, dense_shape))
+ else:
+ get_next_list.append(all_ops[i])
+ i += 1
+ return all_ops[0], nest.pack_sequence_as(
+ self._get_output_types(ds_fn), get_next_list)
+
+ def _get_output_types(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_types
+
+ def _get_output_shapes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_shapes
+
+ def _get_output_classes(self, ds_fn):
+ with ops.Graph().as_default():
+ return ds_fn().output_classes
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def _latest_ckpt(self):
+ return checkpoint_management.latest_checkpoint(self.get_temp_dir())
+
+ def _save(self, sess, saver):
+ saver.save(sess, self._ckpt_path())
+
+ def _restore(self, saver, sess):
+ sess.run(lookup_ops.tables_initializer())
+ saver.restore(sess, self._latest_ckpt())
+
+ def _initialize(self, init_op, sess):
+ sess.run(variables.global_variables_initializer())
+ sess.run(lookup_ops.tables_initializer())
+ sess.run(init_op)
+
+ def _import_meta_graph(self):
+ meta_file_path = self._ckpt_path() + ".meta"
+ return saver_lib.import_meta_graph(meta_file_path)
+
+ def _delete_ckpt(self):
+ # Remove all checkpoint files.
+ prefix = self._ckpt_path()
+ pattern = prefix + "*"
+ files = gfile.Glob(pattern)
+ map(gfile.Remove, files)
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
index 7c170078a1..225f6cbac0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/filter_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/filter_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import math_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
index 34392d88d4..70caf3e0d5 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/fixed_length_record_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
index 16051ffd3f..c30534a9e9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/flat_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/flat_map_dataset_serialization_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
index 571e0899bb..169c8845d0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_reducer_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_reducer_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
index f86af4084e..e5bc76288e 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/group_by_window_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/group_by_window_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import grouping
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import grouping
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
index 65ae9923b8..df1f43129a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/ignore_errors_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/ignore_errors_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import error_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import error_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
index 243f6405a1..0c1d40ce39 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/interleave_dataset_serialization_test.py
@@ -20,7 +20,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
index c9cd211328..166ffa99ca 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_and_batch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import math
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
index ab783e5cce..b93156a96c 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/map_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
index d5c03495e3..ed4a1da596 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/optimize_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/optimize_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import optimization
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import optimization
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
index 9ac42a461a..6f72b24673 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/padded_batch_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import string_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
index 1f8a584df9..b8f38e8a28 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_interleave_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import sparse_ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
index 3fb7605be1..a0bdd4fa59 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parallel_map_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -65,7 +65,7 @@ class ParallelMapDatasetSerializationTest(
for ds_fn in [self._build_ds, self._build_ds_with_prefetch]:
self.run_core_tests(
ds_fn,
- lambda: ds_fn(multiplier=15.0),
+ lambda: ds_fn(multiplier=15.0), # pylint: disable=cell-var-from-loop
self._num_outputs)
def testSaveStatefulFunction(self):
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
index d3fa84e74c..a0dd6960b0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/parse_example_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/parse_example_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
index c802402461..00d74c0025 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/prefetch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/prefetch_dataset_serialization_test.py
@@ -17,7 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
index 6341190847..ef99d01c73 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/range_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/range_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
index fdb35ea624..c23c1ecdfb 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sample_from_datasets_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sample_from_datasets_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import interleave_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import interleave_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
index af9ef48c0f..5f50160619 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/scan_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/scan_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import scan_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import scan_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
index 2afebca0f5..fe99a3d3d9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sequence_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sequence_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
index 6aac50ecd9..88d5c896c9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/serialization_integration_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/serialization_integration_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
index f199ec835e..f847ac19f9 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_and_repeat_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
index a59fa94d66..a04f1ddafc 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/shuffle_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/shuffle_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
index 93b26ed58a..b179770ce3 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/sql_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/sql_dataset_serialization_test.py
@@ -19,9 +19,9 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
index 14cd3e9c4a..ef7061b190 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/stats_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py
@@ -17,9 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@@ -90,6 +91,16 @@ class StatsDatasetSerializationTest(
lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
None, num_outputs)
+ def _build_dataset_stats_aggregator(self):
+ stats_aggregator = stats_ops.StatsAggregator()
+ return dataset_ops.Dataset.range(10).apply(
+ stats_ops.set_stats_aggregator(stats_aggregator))
+
+ def test_set_stats_aggregator_not_support_checkpointing(self):
+ with self.assertRaisesRegexp(errors.UnimplementedError,
+ "does not support checkpointing"):
+ self.run_core_tests(self._build_dataset_stats_aggregator, None, 10)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
index 2483787f44..c87a7443a7 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/textline_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/textline_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
index 55a6257a27..f0dcc131d4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/tf_record_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/tf_record_dataset_serialization_test.py
@@ -21,8 +21,8 @@ import gzip
import os
import zlib
-from tensorflow.contrib.data.python.kernel_tests import reader_dataset_ops_test_base
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests import reader_dataset_ops_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import readers as core_readers
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
index b2a5a8a20d..528598dfe4 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unbatch_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unbatch_dataset_serialization_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import batching
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import batching
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
index 22f15b8846..e2862af4d6 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/unique_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/unique_dataset_serialization_test.py
@@ -17,8 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.ops import unique
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
index 340a6ff72e..4ea6131c22 100644
--- a/tensorflow/contrib/data/python/kernel_tests/serialization/zip_dataset_serialization_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization/zip_dataset_serialization_test.py
@@ -19,7 +19,7 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests.serialization import dataset_serialization_test_base
+from tensorflow.python.data.experimental.kernel_tests.serialization import dataset_serialization_test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.platform import test
diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
new file mode 100644
index 0000000000..88d5c896c9
--- /dev/null
+++ b/tensorflow/python/data/experimental/kernel_tests/serialization_integration_test.py
@@ -0,0 +1,85 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Integration test for dataset serialization."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from tensorflow.python.data.experimental.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import saver as saver_lib
+
+
+class SerializationIntegrationTest(test.TestCase):
+
+ def _build_input_pipeline(self, name, num_outputs):
+ with ops.name_scope(name):
+ ds = dataset_ops.Dataset.range(num_outputs).shuffle(
+ 10, reshuffle_each_iteration=False).prefetch(10)
+ iterator = ds.make_initializable_iterator()
+ saveable = contrib_iterator_ops.make_saveable_from_iterator(iterator)
+ ops.add_to_collection(ops.GraphKeys.SAVEABLE_OBJECTS, saveable)
+ return iterator.initializer, iterator.get_next()
+
+ def _build_graph(self, num_pipelines, num_outputs):
+ init_ops = []
+ get_next_ops = []
+ for i in range(num_pipelines):
+ name = "input_pipeline_%d" % i
+ init_op, get_next_op = self._build_input_pipeline(name, num_outputs)
+ init_ops.append(init_op)
+ get_next_ops.append(get_next_op)
+ saver = saver_lib.Saver()
+ return init_ops, get_next_ops, saver
+
+ def _ckpt_path(self):
+ return os.path.join(self.get_temp_dir(), "iterator")
+
+ def testConcurrentSaves(self):
+ num_pipelines = 100
+ num_outputs = 100
+ break_point = 10
+ all_outputs = [[] for _ in range(num_pipelines)]
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ sess.run(init_ops)
+ for _ in range(break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+ saver.save(sess, self._ckpt_path())
+
+ with ops.Graph().as_default() as g:
+ init_ops, get_next_ops, saver = self._build_graph(num_pipelines,
+ num_outputs)
+ with self.session(graph=g) as sess:
+ saver.restore(sess, self._ckpt_path())
+ for _ in range(num_outputs - break_point):
+ output = sess.run(get_next_ops)
+ for i in range(num_pipelines):
+ all_outputs[i].append(output[i])
+
+ for output in all_outputs:
+ self.assertSequenceEqual(sorted(output), range(num_outputs))
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
index 440e48db30..50895b5945 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/shuffle_dataset_op_test.py
@@ -19,14 +19,15 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.ops import shuffle_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
-class ShuffleAndRepeatTest(test.TestCase):
+class ShuffleAndRepeatTest(test_base.DatasetTestBase):
def _build_ds(self, seed, count=5, num_elements=20):
return dataset_ops.Dataset.range(num_elements).apply(
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
index 52823d3fca..301f75488a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test.py
@@ -18,7 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.kernel_tests import sql_dataset_op_test_base
+from tensorflow.python.data.experimental.kernel_tests import sql_dataset_op_test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.platform import test
diff --git a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
index 1f5c725a92..a135c357f0 100644
--- a/tensorflow/contrib/data/python/kernel_tests/sql_dataset_op_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/sql_dataset_op_test_base.py
@@ -23,13 +23,14 @@ import os
import sqlite3
-from tensorflow.contrib.data.python.ops import readers
+from tensorflow.python.data.experimental.ops import readers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
-class SqlDatasetTestBase(test.TestCase):
+class SqlDatasetTestBase(test_base.DatasetTestBase):
"""Base class for setting up and testing SqlDataset."""
def _createSqlDataset(self, output_types, num_repeats=1):
@@ -92,5 +93,3 @@ class SqlDatasetTestBase(test.TestCase):
9007199254740992.0)])
conn.commit()
conn.close()
-
-
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
index be8ae5e955..6761fbd16b 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_ops_test.py
@@ -19,8 +19,8 @@ from __future__ import print_function
import numpy as np
-from tensorflow.contrib.data.python.kernel_tests import stats_dataset_test_base
-from tensorflow.contrib.data.python.ops import stats_ops
+from tensorflow.python.data.experimental.kernel_tests import stats_dataset_test_base
+from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
diff --git a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
index b1b4c23510..80f2625927 100644
--- a/tensorflow/contrib/data/python/kernel_tests/stats_dataset_test_base.py
+++ b/tensorflow/python/data/experimental/kernel_tests/stats_dataset_test_base.py
@@ -19,10 +19,10 @@ from __future__ import print_function
from tensorflow.core.framework import summary_pb2
-from tensorflow.python.platform import test
+from tensorflow.python.data.kernel_tests import test_base
-class StatsDatasetTestBase(test.TestCase):
+class StatsDatasetTestBase(test_base.DatasetTestBase):
"""Base class for testing statistics gathered in `StatsAggregator`."""
def _assertSummaryContains(self, summary_str, tag):
diff --git a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
index 8d335e87d5..4432dcb05a 100644
--- a/tensorflow/contrib/data/python/kernel_tests/threadpool_dataset_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/threadpool_dataset_ops_test.py
@@ -22,8 +22,9 @@ import threading
from absl.testing import parameterized
import numpy as np
-from tensorflow.contrib.data.python.ops import threadpool
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.ops import threadpool
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -31,7 +32,8 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.platform import test
-class OverrideThreadpoolDatasetTest(test.TestCase, parameterized.TestCase):
+class OverrideThreadpoolDatasetTest(test_base.DatasetTestBase,
+ parameterized.TestCase):
@parameterized.named_parameters(
("1", 1, None),
diff --git a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
index f994c8563f..b5a0b20f3f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/unique_dataset_op_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/unique_dataset_op_test.py
@@ -17,7 +17,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import unique
+from tensorflow.python.data.experimental.ops import unique
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -25,7 +26,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class UniqueDatasetTest(test.TestCase):
+class UniqueDatasetTest(test_base.DatasetTestBase):
def _testSimpleHelper(self, dtype, test_cases):
"""Test the `unique()` transformation on a list of test cases.
diff --git a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
index 867ee2ba37..25a2e63ba1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/writer_ops_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/writer_ops_test.py
@@ -19,7 +19,8 @@ from __future__ import print_function
import os
-from tensorflow.contrib.data.python.ops import writers
+from tensorflow.python.data.experimental.ops import writers
+from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
@@ -30,7 +31,7 @@ from tensorflow.python.platform import test
from tensorflow.python.util import compat
-class TFRecordWriterTest(test.TestCase):
+class TFRecordWriterTest(test_base.DatasetTestBase):
def setUp(self):
super(TFRecordWriterTest, self).setUp()
diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD
new file mode 100644
index 0000000000..915d399f1b
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/BUILD
@@ -0,0 +1,377 @@
+package(default_visibility = ["//tensorflow:internal"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load(
+ "//tensorflow:tensorflow.bzl",
+ "tf_gen_op_wrapper_py",
+ "tf_kernel_library",
+)
+load("//tensorflow:tensorflow.bzl", "tf_custom_op_py_library")
+
+py_library(
+ name = "counter",
+ srcs = ["counter.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":scan_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "get_single_element",
+ srcs = ["get_single_element.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "iterator_ops",
+ srcs = [
+ "iterator_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:basic_session_run_hooks",
+ "//tensorflow/python:checkpoint_management",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:saver",
+ "//tensorflow/python:session_run_hook",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/ops:optional_ops",
+ ],
+)
+
+py_library(
+ name = "random_ops",
+ srcs = [
+ "random_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "readers",
+ srcs = [
+ "readers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":optimization",
+ ":parsing_ops",
+ ":shuffle_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:lib",
+ "//tensorflow/python:platform",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "shuffle_ops",
+ srcs = [
+ "shuffle_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "batching",
+ srcs = ["batching.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":get_single_element",
+ ":grouping",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:tensor_util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:convert",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "enumerate_ops",
+ srcs = ["enumerate_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "error_ops",
+ srcs = ["error_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "grouping",
+ srcs = ["grouping.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:check_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "interleave_ops",
+ srcs = ["interleave_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:stateless_random_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:readers",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "optimization",
+ srcs = ["optimization.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "parsing_ops",
+ srcs = ["parsing_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:parsing_ops",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
+
+py_library(
+ name = "map_defun",
+ srcs = ["map_defun.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:tensor_shape",
+ ],
+)
+
+py_library(
+ name = "resampling",
+ srcs = ["resampling.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":batching",
+ ":interleave_ops",
+ ":scan_ops",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:logging_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_library(
+ name = "scan_ops",
+ srcs = ["scan_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:function",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "stats_ops",
+ srcs = ["stats_ops.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/ops:iterator_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "threadpool",
+ srcs = ["threadpool.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ "//tensorflow/python/eager:context",
+ ],
+)
+
+py_library(
+ name = "unique",
+ srcs = [
+ "unique.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "writers",
+ srcs = [
+ "writers.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
+ name = "indexed_dataset_ops",
+ srcs = ["indexed_dataset_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "prefetching_ops",
+ srcs = ["prefetching_ops.py"],
+ deps = [
+ "//tensorflow/python:experimental_dataset_ops_gen",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
+ name = "dataset_ops",
+ deps = [
+ ":batching",
+ ":counter",
+ ":enumerate_ops",
+ ":error_ops",
+ ":get_single_element",
+ ":grouping",
+ ":indexed_dataset_ops",
+ ":interleave_ops",
+ ":map_defun",
+ ":optimization",
+ ":prefetching_ops",
+ ":readers",
+ ":resampling",
+ ":scan_ops",
+ ":shuffle_ops",
+ ":stats_ops",
+ ":threadpool",
+ ":unique",
+ ":writers",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:util",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ ],
+)
diff --git a/tensorflow/python/data/experimental/ops/batching.py b/tensorflow/python/data/experimental/ops/batching.py
new file mode 100644
index 0000000000..d42af9e7e9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/batching.py
@@ -0,0 +1,669 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Batching dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import get_single_element
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import gen_array_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def batch_window(dataset):
+ """Batches a window of tensors.
+
+ Args:
+ dataset: the input dataset.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+ """
+ if isinstance(dataset.output_classes, tuple):
+ raise TypeError("Input dataset expected to have a single component")
+ if dataset.output_classes is ops.Tensor:
+ return _batch_dense_window(dataset)
+ elif dataset.output_classes is sparse_tensor.SparseTensor:
+ return _batch_sparse_window(dataset)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _batch_dense_window(dataset):
+ """Batches a window of dense tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return array_ops.shape(first_element)
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, array_ops.shape(value))
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat([[0], shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _batch_sparse_window(dataset):
+ """Batches a window of sparse tensors."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def shape_init_fn(_):
+ return first_element.dense_shape
+
+ def shape_reduce_fn(state, value):
+ check_ops.assert_equal(state, value.dense_shape)
+ return state
+
+ def finalize_fn(state):
+ return state
+
+ if dataset.output_shapes.is_fully_defined():
+ shape = dataset.output_shapes
+ else:
+ first_element = get_single_element.get_single_element(dataset.take(1))
+ shape_reducer = grouping.Reducer(shape_init_fn, shape_reduce_fn,
+ finalize_fn)
+ shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, shape_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(shape) + 1]], 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64),
+ math_ops.cast(shape, dtypes.int64)], 0))
+
+ def batch_reduce_fn(state, value):
+ return sparse_ops.sparse_concat(0, [state, value])
+
+ def reshape_fn(value):
+ return sparse_ops.sparse_reshape(
+ value,
+ array_ops.concat([np.array([1], dtype=np.int64), value.dense_shape], 0))
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(reshape_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+@tf_export("data.experimental.dense_to_sparse_batch")
+def dense_to_sparse_batch(batch_size, row_shape):
+ """A transformation that batches ragged elements into `tf.SparseTensor`s.
+
+ Like `Dataset.padded_batch()`, this transformation combines multiple
+ consecutive elements of the dataset, which might have different
+ shapes, into a single element. The resulting element has three
+ components (`indices`, `values`, and `dense_shape`), which
+ comprise a `tf.SparseTensor` that represents the same data. The
+ `row_shape` represents the dense shape of each row in the
+ resulting `tf.SparseTensor`, to which the effective batch size is
+ prepended. For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.dense_to_sparse_batch(
+ batch_size=2, row_shape=[6])) ==
+ {
+ ([[0, 0], [0, 1], [0, 2], [1, 0], [1, 1]], # indices
+ ['a', 'b', 'c', 'a', 'b'], # values
+ [2, 6]), # dense_shape
+ ([[0, 0], [0, 1], [0, 2], [0, 3]],
+ ['a', 'b', 'c', 'd'],
+ [1, 6])
+ }
+ ```
+
+ Args:
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ number of consecutive elements of this dataset to combine in a
+ single batch.
+ row_shape: A `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the equivalent dense shape of a row in the
+ resulting `tf.SparseTensor`. Each element of this dataset must
+ have the same rank as `row_shape`, and must have size less
+ than or equal to `row_shape` in each dimension.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _DenseToSparseBatchDataset(dataset, batch_size, row_shape)
+
+ return _apply_fn
+
+
+def padded_batch_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of tensors with padding.
+
+ Args:
+ dataset: the input dataset.
+ padded_shape: (Optional.) `tf.TensorShape` or `tf.int64` vector tensor-like
+ object representing the shape to which the input elements should be padded
+ prior to batching. Any unknown dimensions (e.g. `tf.Dimension(None)` in a
+ `tf.TensorShape` or `-1` in a tensor-like object) will be padded to the
+ maximum size of that dimension in each batch.
+ padding_value: (Optional.) A scalar-shaped `tf.Tensor`, representing the
+ padding value to use. Defaults are `0` for numeric types and the empty
+ string for string types. If `dataset` contains `tf.SparseTensor`, this
+ value is ignored.
+
+ Returns:
+ A `Tensor` representing the batch of the entire input dataset.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
+ if not issubclass(dataset.output_classes,
+ (ops.Tensor, sparse_tensor.SparseTensor)):
+ raise TypeError("Input dataset expected to have a single tensor component")
+ if issubclass(dataset.output_classes, (ops.Tensor)):
+ return _padded_batch_dense_window(dataset, padded_shape, padding_value)
+ elif issubclass(dataset.output_classes, (sparse_tensor.SparseTensor)):
+ if padding_value is not None:
+ raise ValueError("Padding value not allowed for sparse tensors")
+ return _padded_batch_sparse_window(dataset, padded_shape)
+ else:
+ raise TypeError("Unsupported dataset type: %s" % dataset.output_classes)
+
+
+def _padded_batch_dense_window(dataset, padded_shape, padding_value=None):
+ """Batches a window of dense tensors with padding."""
+
+ padded_shape = math_ops.cast(
+ convert.partial_shape_to_tensor(padded_shape), dtypes.int32)
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return padded_shape
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(array_ops.shape(value), padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ",
+ array_ops.shape(value), padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, array_ops.shape(value))
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ if padding_value is None:
+ if dataset.output_types == dtypes.string:
+ padding_value = ""
+ elif dataset.output_types == dtypes.bool:
+ padding_value = False
+ elif dataset.output_types == dtypes.variant:
+ raise TypeError("Unable to create padding for field of type 'variant'")
+ else:
+ padding_value = 0
+
+ def batch_init_fn(_):
+ batch_shape = array_ops.concat(
+ [np.array([0], dtype=np.int32), padded_shape], 0)
+ return gen_array_ops.empty(batch_shape, dtype=dataset.output_types)
+
+ def batch_reduce_fn(state, value):
+ return array_ops.concat([state, [value]], 0)
+
+ def pad_fn(value):
+ shape = array_ops.shape(value)
+ left = array_ops.zeros_like(shape)
+ right = padded_shape - shape
+ return array_ops.pad(
+ value, array_ops.stack([left, right], 1), constant_values=padding_value)
+
+ batch_reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.map(pad_fn).apply(
+ grouping.group_by_reducer(key_fn, batch_reducer)))
+
+
+def _padded_batch_sparse_window(dataset, padded_shape):
+ """Batches a window of sparse tensors with padding."""
+
+ def key_fn(_):
+ return np.int64(0)
+
+ def max_init_fn(_):
+ return convert.partial_shape_to_tensor(padded_shape)
+
+ def max_reduce_fn(state, value):
+ """Computes the maximum shape to pad to."""
+ condition = math_ops.reduce_all(
+ math_ops.logical_or(
+ math_ops.less_equal(value.dense_shape, padded_shape),
+ math_ops.equal(padded_shape, -1)))
+ assert_op = control_flow_ops.Assert(condition, [
+ "Actual shape greater than padded shape: ", value.dense_shape,
+ padded_shape
+ ])
+ with ops.control_dependencies([assert_op]):
+ return math_ops.maximum(state, value.dense_shape)
+
+ def finalize_fn(state):
+ return state
+
+ # Compute the padded shape.
+ max_reducer = grouping.Reducer(max_init_fn, max_reduce_fn, finalize_fn)
+ padded_shape = get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, max_reducer)))
+
+ def batch_init_fn(_):
+ indices_shape = array_ops.concat([[0], [array_ops.size(padded_shape) + 1]],
+ 0)
+ return sparse_tensor.SparseTensor(
+ indices=gen_array_ops.empty(indices_shape, dtype=dtypes.int64),
+ values=constant_op.constant([], shape=[0], dtype=dataset.output_types),
+ dense_shape=array_ops.concat(
+ [np.array([0], dtype=np.int64), padded_shape], 0))
+
+ def batch_reduce_fn(state, value):
+ padded_value = sparse_tensor.SparseTensor(
+ indices=value.indices, values=value.values, dense_shape=padded_shape)
+ reshaped_value = sparse_ops.sparse_reshape(
+ padded_value,
+ array_ops.concat(
+ [np.array([1], dtype=np.int64), padded_value.dense_shape], 0))
+ return sparse_ops.sparse_concat(0, [state, reshaped_value])
+
+ reducer = grouping.Reducer(batch_init_fn, batch_reduce_fn, finalize_fn)
+ return get_single_element.get_single_element(
+ dataset.apply(grouping.group_by_reducer(key_fn, reducer)))
+
+
+class _UnbatchDataset(dataset_ops.UnaryDataset):
+ """A dataset that splits the elements of its input into multiple elements."""
+
+ def __init__(self, input_dataset):
+ """See `unbatch()` for more details."""
+ super(_UnbatchDataset, self).__init__(input_dataset)
+ flat_shapes = nest.flatten(input_dataset.output_shapes)
+ if any(s.ndims == 0 for s in flat_shapes):
+ raise ValueError("Cannot unbatch an input with scalar components.")
+ known_batch_dim = tensor_shape.Dimension(None)
+ for s in flat_shapes:
+ try:
+ known_batch_dim = known_batch_dim.merge_with(s[0])
+ except ValueError:
+ raise ValueError("Cannot unbatch an input whose components have "
+ "different batch sizes.")
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.unbatch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda s: s[1:],
+ self._input_dataset.output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.unbatch")
+def unbatch():
+ """Splits elements of a dataset into multiple elements on the batch dimension.
+
+ For example, if elements of the dataset are shaped `[B, a0, a1, ...]`,
+ where `B` may vary for each input element, then for each element in the
+ dataset, the unbatched dataset will contain `B` consecutive elements
+ of shape `[a0, a1, ...]`.
+
+ ```python
+ # NOTE: The following example uses `{ ... }` to represent the contents
+ # of a dataset.
+ a = { ['a', 'b', 'c'], ['a', 'b'], ['a', 'b', 'c', 'd'] }
+
+ a.apply(tf.data.experimental.unbatch()) == {
+ 'a', 'b', 'c', 'a', 'b', 'a', 'b', 'c', 'd'}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ if not sparse.any_sparse(dataset.output_classes):
+ return _UnbatchDataset(dataset)
+
+ # NOTE(mrry): We must ensure that any SparseTensors in `dataset`
+ # are normalized to the rank-1 dense representation, so that the
+ # sparse-oblivious unbatching logic will slice them
+ # appropriately. This leads to a somewhat inefficient re-encoding step
+ # for all SparseTensor components.
+ # TODO(mrry): Consider optimizing this in future
+ # if it turns out to be a bottleneck.
+ def normalize(arg, *rest):
+ if rest:
+ return sparse.serialize_many_sparse_tensors((arg,) + rest)
+ else:
+ return sparse.serialize_many_sparse_tensors(arg)
+
+ normalized_dataset = dataset.map(normalize)
+
+ # NOTE(mrry): Our `map()` has lost information about the sparseness
+ # of any SparseTensor components, so re-apply the structure of the
+ # original dataset.
+ restructured_dataset = _RestructuredDataset(
+ normalized_dataset,
+ dataset.output_types,
+ dataset.output_shapes,
+ dataset.output_classes,
+ allow_unsafe_cast=True)
+ return _UnbatchDataset(restructured_dataset)
+
+ return _apply_fn
+
+
+class _DenseToSparseBatchDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that batches ragged dense elements into `tf.SparseTensor`s."""
+
+ def __init__(self, input_dataset, batch_size, row_shape):
+ """See `Dataset.dense_to_sparse_batch()` for more details."""
+ super(_DenseToSparseBatchDataset, self).__init__(input_dataset)
+ if not isinstance(input_dataset.output_types, dtypes.DType):
+ raise TypeError("DenseToSparseDataset requires an input whose elements "
+ "have a single component, whereas the input has %r." %
+ input_dataset.output_types)
+ self._input_dataset = input_dataset
+ self._batch_size = batch_size
+ self._row_shape = row_shape
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.dense_to_sparse_batch_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._batch_size,
+ row_shape=convert.partial_shape_to_tensor(self._row_shape),
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return sparse_tensor.SparseTensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.vector(None).concatenate(self._row_shape)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _RestructuredDataset(dataset_ops.UnaryDataset):
+ """An internal helper for changing the structure and shape of a dataset."""
+
+ def __init__(self,
+ dataset,
+ output_types,
+ output_shapes=None,
+ output_classes=None,
+ allow_unsafe_cast=False):
+ """Creates a new dataset with the given output types and shapes.
+
+ The given `dataset` must have a structure that is convertible:
+ * `dataset.output_types` must be the same as `output_types` module nesting.
+ * Each shape in `dataset.output_shapes` must be compatible with each shape
+ in `output_shapes` (if given).
+
+ Note: This helper permits "unsafe casts" for shapes, equivalent to using
+ `tf.Tensor.set_shape()` where domain-specific knowledge is available.
+
+ Args:
+ dataset: A `Dataset` object.
+ output_types: A nested structure of `tf.DType` objects.
+ output_shapes: (Optional.) A nested structure of `tf.TensorShape` objects.
+ If omitted, the shapes will be inherited from `dataset`.
+ output_classes: (Optional.) A nested structure of class types.
+ If omitted, the class types will be inherited from `dataset`.
+ allow_unsafe_cast: (Optional.) If `True`, the caller may switch the
+ reported output types and shapes of the restructured dataset, e.g. to
+ switch a sparse tensor represented as `tf.variant` to its user-visible
+ type and shape.
+
+ Raises:
+ ValueError: If either `output_types` or `output_shapes` is not compatible
+ with the structure of `dataset`.
+ """
+ super(_RestructuredDataset, self).__init__(dataset)
+ self._input_dataset = dataset
+
+ if not allow_unsafe_cast:
+ # Validate that the types are compatible.
+ output_types = nest.map_structure(dtypes.as_dtype, output_types)
+ flat_original_types = nest.flatten(dataset.output_types)
+ flat_new_types = nest.flatten(output_types)
+ if flat_original_types != flat_new_types:
+ raise ValueError(
+ "Dataset with output types %r cannot be restructured to have "
+ "output types %r" % (dataset.output_types, output_types))
+
+ self._output_types = output_types
+
+ if output_shapes is None:
+ # Inherit shapes from the original `dataset`.
+ self._output_shapes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_shapes))
+ else:
+ if not allow_unsafe_cast:
+ # Validate that the shapes are compatible.
+ nest.assert_same_structure(output_types, output_shapes)
+ flat_original_shapes = nest.flatten(dataset.output_shapes)
+ flat_new_shapes = nest.flatten_up_to(output_types, output_shapes)
+
+ for original_shape, new_shape in zip(flat_original_shapes,
+ flat_new_shapes):
+ if not original_shape.is_compatible_with(new_shape):
+ raise ValueError(
+ "Dataset with output shapes %r cannot be restructured to have "
+ "incompatible output shapes %r" % (dataset.output_shapes,
+ output_shapes))
+ self._output_shapes = nest.map_structure_up_to(
+ output_types, tensor_shape.as_shape, output_shapes)
+ if output_classes is None:
+ # Inherit class types from the original `dataset`.
+ self._output_classes = nest.pack_sequence_as(output_types,
+ nest.flatten(
+ dataset.output_classes))
+ else:
+ self._output_classes = output_classes
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+
+class _MapAndBatchDataset(dataset_ops.MapDataset):
+ """A `Dataset` that maps a function over a batch of elements."""
+
+ def __init__(self, input_dataset, map_func, batch_size, num_parallel_calls,
+ drop_remainder):
+ """See `Dataset.map()` for details."""
+ super(_MapAndBatchDataset, self).__init__(input_dataset, map_func)
+ self._batch_size_t = ops.convert_to_tensor(
+ batch_size, dtype=dtypes.int64, name="batch_size")
+ self._num_parallel_calls_t = ops.convert_to_tensor(
+ num_parallel_calls, dtype=dtypes.int64, name="num_parallel_calls")
+ self._drop_remainder_t = ops.convert_to_tensor(
+ drop_remainder, dtype=dtypes.bool, name="drop_remainder")
+
+ self._batch_size = batch_size
+ self._drop_remainder = drop_remainder
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.map_and_batch_dataset_v2(
+ input_resource,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ batch_size=self._batch_size_t,
+ num_parallel_calls=self._num_parallel_calls_t,
+ drop_remainder=self._drop_remainder_t,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_shapes(self):
+ dim = self._batch_size if self._drop_remainder else None
+ return nest.pack_sequence_as(self._output_shapes, [
+ tensor_shape.vector(dim).concatenate(s)
+ for s in nest.flatten(self._output_shapes)
+ ])
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.map_and_batch")
+def map_and_batch(map_func,
+ batch_size,
+ num_parallel_batches=None,
+ drop_remainder=False,
+ num_parallel_calls=None):
+ """Fused implementation of `map` and `batch`.
+
+ Maps `map_func` across `batch_size` consecutive elements of this dataset
+ and then combines them into a batch. Functionally, it is equivalent to `map`
+ followed by `batch`. However, by fusing the two transformations together, the
+ implementation can be more efficient. Surfacing this transformation in the API
+ is temporary. Once automatic input pipeline optimization is implemented,
+ the fusing of `map` and `batch` will happen automatically and this API will be
+ deprecated.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to another
+ nested structure of tensors.
+ batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements of this dataset to combine in a single batch.
+ num_parallel_batches: (Optional.) A `tf.int64` scalar `tf.Tensor`,
+ representing the number of batches to create in parallel. On one hand,
+ higher values can help mitigate the effect of stragglers. On the other
+ hand, higher values can increase contention if CPU is scarce.
+ drop_remainder: (Optional.) A `tf.bool` scalar `tf.Tensor`, representing
+ whether the last batch should be dropped in case its size is smaller than
+ desired; the default behavior is not to drop the smaller batch.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of elements to process in parallel. If not
+ specified, `batch_size * num_parallel_batches` elements will be
+ processed in parallel.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: If both `num_parallel_batches` and `num_parallel_calls` are
+ specified.
+ """
+
+ if num_parallel_batches is None and num_parallel_calls is None:
+ num_parallel_calls = batch_size
+ elif num_parallel_batches is not None and num_parallel_calls is None:
+ num_parallel_calls = batch_size * num_parallel_batches
+ elif num_parallel_batches is not None and num_parallel_calls is not None:
+ raise ValueError("The `num_parallel_batches` and `num_parallel_calls` "
+ "arguments are mutually exclusive.")
+
+ def _apply_fn(dataset):
+ return _MapAndBatchDataset(dataset, map_func, batch_size,
+ num_parallel_calls, drop_remainder)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/counter.py b/tensorflow/python/data/experimental/ops/counter.py
new file mode 100644
index 0000000000..42200eaef9
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/counter.py
@@ -0,0 +1,55 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The Counter Dataset."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import scan_ops
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.Counter")
+def Counter(start=0, step=1, dtype=dtypes.int64):
+ """Creates a `Dataset` that counts from `start` in steps of size `step`.
+
+ For example:
+
+ ```python
+ Dataset.count() == [0, 1, 2, ...)
+ Dataset.count(2) == [2, 3, ...)
+ Dataset.count(2, 5) == [2, 7, 12, ...)
+ Dataset.count(0, -1) == [0, -1, -2, ...)
+ Dataset.count(10, -1) == [10, 9, ...)
+ ```
+
+ Args:
+ start: (Optional.) The starting value for the counter. Defaults to 0.
+ step: (Optional.) The step size for the counter. Defaults to 1.
+ dtype: (Optional.) The data type for counter elements. Defaults to
+ `tf.int64`.
+
+ Returns:
+ A `Dataset` of scalar `dtype` elements.
+ """
+ with ops.name_scope("counter"):
+ start = ops.convert_to_tensor(start, dtype=dtype, name="start")
+ step = ops.convert_to_tensor(step, dtype=dtype, name="step")
+ return dataset_ops.Dataset.from_tensors(0).repeat(None).apply(
+ scan_ops.scan(start, lambda state, _: (state + step, state)))
diff --git a/tensorflow/python/data/experimental/ops/enumerate_ops.py b/tensorflow/python/data/experimental/ops/enumerate_ops.py
new file mode 100644
index 0000000000..a1af98f552
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/enumerate_ops.py
@@ -0,0 +1,60 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Enumerate dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.enumerate_dataset")
+def enumerate_dataset(start=0):
+ """A transformation that enumerate the elements of a dataset.
+
+ It is Similar to python's `enumerate`.
+ For example:
+
+ ```python
+ # NOTE: The following examples use `{ ... }` to represent the
+ # contents of a dataset.
+ a = { 1, 2, 3 }
+ b = { (7, 8), (9, 10) }
+
+ # The nested structure of the `datasets` argument determines the
+ # structure of elements in the resulting dataset.
+ a.apply(tf.data.experimental.enumerate(start=5)) == { (5, 1), (6, 2), (7, 3) }
+ b.apply(tf.data.experimental.enumerate()) == { (0, (7, 8)), (1, (9, 10)) }
+ ```
+
+ Args:
+ start: A `tf.int64` scalar `tf.Tensor`, representing the start
+ value for enumeration.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ max_value = np.iinfo(dtypes.int64.as_numpy_dtype).max
+ return dataset_ops.Dataset.zip((dataset_ops.Dataset.range(start, max_value),
+ dataset))
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/error_ops.py b/tensorflow/python/data/experimental/ops/error_ops.py
new file mode 100644
index 0000000000..82e274b70c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/error_ops.py
@@ -0,0 +1,78 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Ignore_errors dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.ignore_errors")
+def ignore_errors():
+ """Creates a `Dataset` from another `Dataset` and silently ignores any errors.
+
+ Use this transformation to produce a dataset that contains the same elements
+ as the input, but silently drops any elements that caused an error. For
+ example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1., 2., 0., 4.])
+
+ # Computing `tf.check_numerics(1. / 0.)` will raise an InvalidArgumentError.
+ dataset = dataset.map(lambda x: tf.check_numerics(1. / x, "error"))
+
+ # Using `ignore_errors()` will drop the element that causes an error.
+ dataset =
+ dataset.apply(tf.data.experimental.ignore_errors()) # ==> {1., 0.5, 0.2}
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _IgnoreErrorsDataset(dataset)
+
+ return _apply_fn
+
+
+class _IgnoreErrorsDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that silently ignores errors when computing its input."""
+
+ def __init__(self, input_dataset):
+ """See `Dataset.ignore_errors()` for details."""
+ super(_IgnoreErrorsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_ignore_errors_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/get_single_element.py b/tensorflow/python/data/experimental/ops/get_single_element.py
new file mode 100644
index 0000000000..132526166c
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/get_single_element.py
@@ -0,0 +1,72 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for Datasets and Iterators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.get_single_element")
+def get_single_element(dataset):
+ """Returns the single element in `dataset` as a nested structure of tensors.
+
+ This function enables you to use a `tf.data.Dataset` in a stateless
+ "tensor-in tensor-out" expression, without creating a `tf.data.Iterator`.
+ This can be useful when your preprocessing transformations are expressed
+ as a `Dataset`, and you want to use the transformation at serving time.
+ For example:
+
+ ```python
+ input_batch = tf.placeholder(tf.string, shape=[BATCH_SIZE])
+
+ def preprocessing_fn(input_str):
+ # ...
+ return image, label
+
+ dataset = (tf.data.Dataset.from_tensor_slices(input_batch)
+ .map(preprocessing_fn, num_parallel_calls=BATCH_SIZE)
+ .batch(BATCH_SIZE))
+
+ image_batch, label_batch = tf.data.experimental.get_single_element(dataset)
+ ```
+
+ Args:
+ dataset: A `tf.data.Dataset` object containing a single element.
+
+ Returns:
+ A nested structure of `tf.Tensor` objects, corresponding to the single
+ element of `dataset`.
+
+ Raises:
+ TypeError: if `dataset` is not a `tf.data.Dataset` object.
+ InvalidArgumentError (at runtime): if `dataset` does not contain exactly
+ one element.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+
+ nested_ret = nest.pack_sequence_as(
+ dataset.output_types, gen_dataset_ops.dataset_to_single_element(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(dataset)))
+ return sparse.deserialize_sparse_tensors(
+ nested_ret, dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
diff --git a/tensorflow/python/data/experimental/ops/grouping.py b/tensorflow/python/data/experimental/ops/grouping.py
new file mode 100644
index 0000000000..18ba583220
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/grouping.py
@@ -0,0 +1,551 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Grouping dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.group_by_reducer")
+def group_by_reducer(key_func, reducer):
+ """A transformation that groups elements and performs a reduction.
+
+ This transformation maps element of a dataset to a key using `key_func` and
+ groups the elements by key. The `reducer` is used to process each group; its
+ `init_func` is used to initialize state for each group when it is created, the
+ `reduce_func` is used to update the state every time an element is mapped to
+ the matching group, and the `finalize_func` is used to map the final state to
+ an output value.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reducer: An instance of `Reducer`, which captures the reduction logic using
+ the `init_func`, `reduce_func`, and `finalize_func` functions.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByReducerDataset(dataset, key_func, reducer)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.group_by_window")
+def group_by_window(key_func,
+ reduce_func,
+ window_size=None,
+ window_size_func=None):
+ """A transformation that groups windows of elements by key and reduces them.
+
+ This transformation maps each consecutive element in a dataset to a key
+ using `key_func` and groups the elements by key. It then applies
+ `reduce_func` to at most `window_size_func(key)` elements matching the same
+ key. All except the final window for each key will contain
+ `window_size_func(key)` elements; the final window may be smaller.
+
+ You may provide either a constant `window_size` or a window size determined by
+ the key through `window_size_func`.
+
+ Args:
+ key_func: A function mapping a nested structure of tensors
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to a scalar `tf.int64` tensor.
+ reduce_func: A function mapping a key and a dataset of up to `window_size`
+ consecutive elements matching that key to another dataset.
+ window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
+ consecutive elements matching the same key to combine in a single
+ batch, which will be passed to `reduce_func`. Mutually exclusive with
+ `window_size_func`.
+ window_size_func: A function mapping a key to a `tf.int64` scalar
+ `tf.Tensor`, representing the number of consecutive elements matching
+ the same key to combine in a single batch, which will be passed to
+ `reduce_func`. Mutually exclusive with `window_size`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if neither or both of {`window_size`, `window_size_func`} are
+ passed.
+ """
+ if (window_size is not None and window_size_func or
+ not (window_size is not None or window_size_func)):
+ raise ValueError("Must pass either window_size or window_size_func.")
+
+ if window_size is not None:
+
+ def constant_window_func(unused_key):
+ return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
+
+ window_size_func = constant_window_func
+
+ assert window_size_func is not None
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _GroupByWindowDataset(dataset, key_func, reduce_func,
+ window_size_func)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.bucket_by_sequence_length")
+def bucket_by_sequence_length(element_length_func,
+ bucket_boundaries,
+ bucket_batch_sizes,
+ padded_shapes=None,
+ padding_values=None,
+ pad_to_bucket_boundary=False,
+ no_padding=False):
+ """A transformation that buckets elements in a `Dataset` by length.
+
+ Elements of the `Dataset` are grouped together by length and then are padded
+ and batched.
+
+ This is useful for sequence tasks in which the elements have variable length.
+ Grouping together elements that have similar lengths reduces the total
+ fraction of padding in a batch which increases training step efficiency.
+
+ Args:
+ element_length_func: function from element in `Dataset` to `tf.int32`,
+ determines the length of the element, which will determine the bucket it
+ goes into.
+ bucket_boundaries: `list<int>`, upper length boundaries of the buckets.
+ bucket_batch_sizes: `list<int>`, batch size per bucket. Length should be
+ `len(bucket_boundaries) + 1`.
+ padded_shapes: Nested structure of `tf.TensorShape` to pass to
+ `tf.data.Dataset.padded_batch`. If not provided, will use
+ `dataset.output_shapes`, which will result in variable length dimensions
+ being padded out to the maximum length in each batch.
+ padding_values: Values to pad with, passed to
+ `tf.data.Dataset.padded_batch`. Defaults to padding with 0.
+ pad_to_bucket_boundary: bool, if `False`, will pad dimensions with unknown
+ size to maximum length in batch. If `True`, will pad dimensions with
+ unknown size to bucket boundary minus 1 (i.e., the maximum length in each
+ bucket), and caller must ensure that the source `Dataset` does not contain
+ any elements with length longer than `max(bucket_boundaries)`.
+ no_padding: `bool`, indicates whether to pad the batch features (features
+ need to be either of type `tf.SparseTensor` or of same shape).
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if `len(bucket_batch_sizes) != len(bucket_boundaries) + 1`.
+ """
+ with ops.name_scope("bucket_by_seq_length"):
+ if len(bucket_batch_sizes) != (len(bucket_boundaries) + 1):
+ raise ValueError(
+ "len(bucket_batch_sizes) must equal len(bucket_boundaries) + 1")
+
+ batch_sizes = constant_op.constant(bucket_batch_sizes, dtype=dtypes.int64)
+
+ def element_to_bucket_id(*args):
+ """Return int64 id of the length bucket for this element."""
+ seq_length = element_length_func(*args)
+
+ boundaries = list(bucket_boundaries)
+ buckets_min = [np.iinfo(np.int32).min] + boundaries
+ buckets_max = boundaries + [np.iinfo(np.int32).max]
+ conditions_c = math_ops.logical_and(
+ math_ops.less_equal(buckets_min, seq_length),
+ math_ops.less(seq_length, buckets_max))
+ bucket_id = math_ops.reduce_min(array_ops.where(conditions_c))
+
+ return bucket_id
+
+ def window_size_fn(bucket_id):
+ # The window size is set to the batch size for this bucket
+ window_size = batch_sizes[bucket_id]
+ return window_size
+
+ def make_padded_shapes(shapes, none_filler=None):
+ padded = []
+ for shape in nest.flatten(shapes):
+ shape = tensor_shape.TensorShape(shape)
+ shape = [
+ none_filler if d.value is None else d
+ for d in shape
+ ]
+ padded.append(shape)
+ return nest.pack_sequence_as(shapes, padded)
+
+ def batching_fn(bucket_id, grouped_dataset):
+ """Batch elements in dataset."""
+ batch_size = window_size_fn(bucket_id)
+ if no_padding:
+ return grouped_dataset.batch(batch_size)
+ none_filler = None
+ if pad_to_bucket_boundary:
+ err_msg = ("When pad_to_bucket_boundary=True, elements must have "
+ "length < max(bucket_boundaries).")
+ check = check_ops.assert_less(
+ bucket_id,
+ constant_op.constant(len(bucket_batch_sizes) - 1,
+ dtype=dtypes.int64),
+ message=err_msg)
+ with ops.control_dependencies([check]):
+ boundaries = constant_op.constant(bucket_boundaries,
+ dtype=dtypes.int64)
+ bucket_boundary = boundaries[bucket_id]
+ none_filler = bucket_boundary - 1
+ shapes = make_padded_shapes(
+ padded_shapes or grouped_dataset.output_shapes,
+ none_filler=none_filler)
+ return grouped_dataset.padded_batch(batch_size, shapes, padding_values)
+
+ def _apply_fn(dataset):
+ return dataset.apply(
+ group_by_window(element_to_bucket_id, batching_fn,
+ window_size_func=window_size_fn))
+
+ return _apply_fn
+
+
+def _map_x_dataset(map_func):
+ """A transformation that maps `map_func` across its input.
+
+ This transformation is similar to `tf.data.Dataset.map`, but in addition to
+ supporting dense and sparse tensor inputs, it also supports dataset inputs.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors and/or datasets
+ (having shapes and types defined by `self.output_shapes` and
+ `self.output_types`) to another nested structure of tensors and/or
+ datasets.
+
+ Returns:
+ Dataset: A `Dataset`.
+ """
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ return _MapXDataset(dataset, map_func)
+
+ return _apply_fn
+
+
+class _GroupByReducerDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a reduction."""
+
+ def __init__(self, input_dataset, key_func, reducer):
+ """See `group_by_reducer()` for details."""
+ super(_GroupByReducerDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_init_func(reducer.init_func)
+ self._make_reduce_func(reducer.reduce_func, input_dataset)
+ self._make_finalize_func(reducer.finalize_func)
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func, "tf.data.experimental.group_by_reducer()", input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 tensor. "
+ "Got type=%s and shape=%s"
+ % (wrapped_func.output_types, wrapped_func.output_shapes))
+ self._key_func = wrapped_func.function
+
+ def _make_init_func(self, init_func):
+ """Make wrapping Defun for init_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ init_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ self._init_func = wrapped_func.function
+ self._state_classes = wrapped_func.output_classes
+ self._state_shapes = wrapped_func.output_shapes
+ self._state_types = wrapped_func.output_types
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+
+ # Iteratively rerun the reduce function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(wrapped_func.output_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, wrapped_func.output_classes))
+
+ # Extract and validate type information from the returned values.
+ for new_state_type, state_type in zip(
+ nest.flatten(wrapped_func.output_types),
+ nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, wrapped_func.output_types))
+
+ # Extract shape information from the returned values.
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(wrapped_func.output_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._reduce_func = wrapped_func.function
+ self._reduce_func.add_to_graph(ops.get_default_graph())
+
+ def _make_finalize_func(self, finalize_func):
+ """Make wrapping Defun for finalize_func."""
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ finalize_func,
+ "tf.data.experimental.group_by_reducer()",
+ input_classes=self._state_classes,
+ input_shapes=self._state_shapes,
+ input_types=self._state_types)
+ self._finalize_func = wrapped_func.function
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_reducer_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._init_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._finalize_func.captured_inputs,
+ key_func=self._key_func,
+ init_func=self._init_func,
+ reduce_func=self._reduce_func,
+ finalize_func=self._finalize_func,
+ **dataset_ops.flat_structure(self))
+
+
+class _GroupByWindowDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that groups its input and performs a windowed reduction."""
+
+ def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
+ """See `group_by_window()` for details."""
+ super(_GroupByWindowDataset, self).__init__(input_dataset)
+
+ self._input_dataset = input_dataset
+
+ self._make_key_func(key_func, input_dataset)
+ self._make_reduce_func(reduce_func, input_dataset)
+ self._make_window_size_func(window_size_func)
+
+ def _make_window_size_func(self, window_size_func):
+ """Make wrapping Defun for window_size_func."""
+ def window_size_func_wrapper(key):
+ return ops.convert_to_tensor(window_size_func(key), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ window_size_func_wrapper,
+ "tf.data.experimental.group_by_window()",
+ input_classes=ops.Tensor,
+ input_shapes=tensor_shape.scalar(),
+ input_types=dtypes.int64)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`window_size_func` must return a single tf.int64 scalar tensor.")
+ self._window_size_func = wrapped_func.function
+
+ def _make_key_func(self, key_func, input_dataset):
+ """Make wrapping Defun for key_func."""
+ def key_func_wrapper(*args):
+ return ops.convert_to_tensor(key_func(*args), dtype=dtypes.int64)
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ key_func_wrapper, "tf.data.experimental.group_by_window()",
+ input_dataset)
+ if not (
+ wrapped_func.output_types == dtypes.int64 and
+ wrapped_func.output_shapes.is_compatible_with(tensor_shape.scalar())):
+ raise ValueError(
+ "`key_func` must return a single tf.int64 scalar tensor.")
+ self._key_func = wrapped_func.function
+
+ def _make_reduce_func(self, reduce_func, input_dataset):
+ """Make wrapping Defun for reduce_func."""
+ nested_dataset = dataset_ops._NestedDatasetComponent(input_dataset) # pylint: disable=protected-access
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ reduce_func,
+ "tf.data.experimental.reduce_by_window()",
+ input_classes=(ops.Tensor, nested_dataset),
+ input_shapes=(tensor_shape.scalar(), nested_dataset),
+ input_types=(dtypes.int64, nested_dataset),
+ experimental_nested_dataset_support=True)
+ if not isinstance(
+ wrapped_func.output_classes, dataset_ops._NestedDatasetComponent): # pylint: disable=protected-access
+ raise TypeError("`reduce_func` must return a `Dataset` object.")
+ self._output_classes = wrapped_func.output_classes.output_classes
+ self._output_types = wrapped_func.output_types.output_types
+ self._output_shapes = wrapped_func.output_shapes.output_shapes
+ self._reduce_func = wrapped_func.function
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.group_by_window_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._key_func.captured_inputs,
+ self._reduce_func.captured_inputs,
+ self._window_size_func.captured_inputs,
+ key_func=self._key_func,
+ reduce_func=self._reduce_func,
+ window_size_func=self._window_size_func,
+ **dataset_ops.flat_structure(self))
+
+
+@tf_export("data.experimental.Reducer")
+class Reducer(object):
+ """A reducer is used for reducing a set of elements.
+
+ A reducer is represented as a tuple of the three functions:
+ 1) initialization function: key => initial state
+ 2) reduce function: (old state, input) => new state
+ 3) finalization function: state => result
+ """
+
+ def __init__(self, init_func, reduce_func, finalize_func):
+ self._init_func = init_func
+ self._reduce_func = reduce_func
+ self._finalize_func = finalize_func
+
+ @property
+ def init_func(self):
+ return self._init_func
+
+ @property
+ def reduce_func(self):
+ return self._reduce_func
+
+ @property
+ def finalize_func(self):
+ return self._finalize_func
+
+
+class _MapXDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that maps a function over elements in its input."""
+
+ def __init__(self, input_dataset, map_func):
+ """See `map_x_dataset()` for details."""
+ super(_MapXDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ map_func,
+ "tf.data.experimental.map_x_dataset()",
+ input_dataset,
+ experimental_nested_dataset_support=True)
+ self._output_classes = wrapped_func.output_classes
+ self._output_shapes = wrapped_func.output_shapes
+ self._output_types = wrapped_func.output_types
+ self._map_func = wrapped_func.function
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.map_dataset(
+ input_t,
+ self._map_func.captured_inputs,
+ f=self._map_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
index cc76ab0850..9c06474a2f 100644
--- a/tensorflow/contrib/data/python/ops/indexed_dataset_ops.py
+++ b/tensorflow/python/data/experimental/ops/indexed_dataset_ops.py
@@ -19,14 +19,13 @@ from __future__ import print_function
import abc
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
class MaterializedIndexedDataset(object):
@@ -57,7 +56,7 @@ class MaterializedIndexedDataset(object):
A tensor containing the values corresponding to `index`.
"""
# TODO(saeta): nest.pack_sequence_as(...)
- return gen_dataset_ops.indexed_dataset_get(
+ return ged_ops.experimental_indexed_dataset_get(
self._materialized_resource,
index,
output_types=nest.flatten(
@@ -90,16 +89,18 @@ class IndexedDataset(dataset_ops.Dataset):
container = ""
if shared_name is None:
shared_name = ""
- materialized_resource = gen_dataset_ops.materialized_index_dataset_handle(
- container=container,
- shared_name=shared_name,
- output_types=nest.flatten(
- sparse.as_dense_types(self.output_types, self.output_classes)),
- output_shapes=nest.flatten(
- sparse.as_dense_types(self.output_shapes, self.output_classes)))
+ materialized_resource = (
+ ged_ops.experimental_materialized_index_dataset_handle(
+ container=container,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ output_shapes=nest.flatten(
+ sparse.as_dense_types(self.output_shapes,
+ self.output_classes))))
with ops.colocate_with(materialized_resource):
- materializer = gen_dataset_ops.indexed_dataset_materialize(
+ materializer = ged_ops.experimental_indexed_dataset_materialize(
self._as_variant_tensor(), materialized_resource)
return MaterializedIndexedDataset(materialized_resource, materializer,
self.output_classes, self.output_types,
@@ -170,7 +171,7 @@ class IdentityIndexedDataset(IndexedDataset):
return tensor_shape.scalar()
def _as_variant_tensor(self):
- return gen_dataset_ops.identity_indexed_dataset(self._size)
+ return ged_ops.experimental_identity_indexed_dataset(self._size)
def _inputs(self):
return []
diff --git a/tensorflow/python/data/experimental/ops/interleave_ops.py b/tensorflow/python/data/experimental/ops/interleave_ops.py
new file mode 100644
index 0000000000..a3c094859e
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/interleave_ops.py
@@ -0,0 +1,262 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Non-deterministic dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.experimental.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.ops import gen_stateless_random_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.parallel_interleave")
+def parallel_interleave(map_func,
+ cycle_length,
+ block_length=1,
+ sloppy=False,
+ buffer_output_elements=None,
+ prefetch_input_elements=None):
+ """A parallel version of the `Dataset.interleave()` transformation.
+
+ `parallel_interleave()` maps `map_func` across its input to produce nested
+ datasets, and outputs their elements interleaved. Unlike
+ `tf.data.Dataset.interleave`, it gets elements from `cycle_length` nested
+ datasets in parallel, which increases the throughput, especially in the
+ presence of stragglers. Furthermore, the `sloppy` argument can be used to
+ improve performance, by relaxing the requirement that the outputs are produced
+ in a deterministic order, and allowing the implementation to skip over nested
+ datasets whose elements are not readily available when requested.
+
+ Example usage:
+
+ ```python
+ # Preprocess 4 files concurrently.
+ filenames = tf.data.Dataset.list_files("/path/to/data/train*.tfrecords")
+ dataset = filenames.apply(
+ tf.data.experimental.parallel_interleave(
+ lambda filename: tf.data.TFRecordDataset(filename),
+ cycle_length=4))
+ ```
+
+ WARNING: If `sloppy` is `True`, the order of produced elements is not
+ deterministic.
+
+ Args:
+ map_func: A function mapping a nested structure of tensors to a `Dataset`.
+ cycle_length: The number of input `Dataset`s to interleave from in parallel.
+ block_length: The number of consecutive elements to pull from an input
+ `Dataset` before advancing to the next input `Dataset`.
+ sloppy: If false, elements are produced in deterministic order. Otherwise,
+ the implementation is allowed, for the sake of expediency, to produce
+ elements in a non-deterministic order.
+ buffer_output_elements: The number of elements each iterator being
+ interleaved should buffer (similar to the `.prefetch()` transformation for
+ each interleaved iterator).
+ prefetch_input_elements: The number of input elements to transform to
+ iterators before they are needed for interleaving.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return readers.ParallelInterleaveDataset(
+ dataset, map_func, cycle_length, block_length, sloppy,
+ buffer_output_elements, prefetch_input_elements)
+
+ return _apply_fn
+
+
+class _DirectedInterleaveDataset(dataset_ops.Dataset):
+ """A substitute for `Dataset.interleave()` on a fixed list of datasets."""
+
+ def __init__(self, selector_input, data_inputs):
+ self._selector_input = selector_input
+ self._data_inputs = list(data_inputs)
+
+ for data_input in data_inputs[1:]:
+ if (data_input.output_types != data_inputs[0].output_types or
+ data_input.output_classes != data_inputs[0].output_classes):
+ raise TypeError("All datasets must have the same type and class.")
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ return (
+ gen_experimental_dataset_ops.experimental_directed_interleave_dataset(
+ self._selector_input._as_variant_tensor(), [
+ data_input._as_variant_tensor()
+ for data_input in self._data_inputs
+ ], **dataset_ops.flat_structure(self)))
+ # pylint: enable=protected-access
+
+ def _inputs(self):
+ return [self._selector_input] + self._data_inputs
+
+ @property
+ def output_classes(self):
+ return self._data_inputs[0].output_classes
+
+ @property
+ def output_shapes(self):
+ ret = self._data_inputs[0].output_shapes
+ for data_input in self._data_inputs[1:]:
+ ret = nest.pack_sequence_as(ret, [
+ ts1.most_specific_compatible_shape(ts2) for (ts1, ts2) in zip(
+ nest.flatten(ret), nest.flatten(data_input.output_shapes))
+ ])
+ return ret
+
+ @property
+ def output_types(self):
+ return self._data_inputs[0].output_types
+
+
+@tf_export("data.experimental.sample_from_datasets")
+def sample_from_datasets(datasets, weights=None, seed=None):
+ """Samples elements at random from the datasets in `datasets`.
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ weights: (Optional.) A list of `len(datasets)` floating-point values where
+ `weights[i]` represents the probability with which an element should be
+ sampled from `datasets[i]`, or a `tf.data.Dataset` object where each
+ element is such a list. Defaults to a uniform distribution across
+ `datasets`.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` at random, according to
+ `weights` if provided, otherwise with uniform probability.
+
+ Raises:
+ TypeError: If the `datasets` or `weights` arguments have the wrong type.
+ ValueError: If the `weights` argument is specified and does not match the
+ length of the `datasets` element.
+ """
+ num_datasets = len(datasets)
+ if not isinstance(weights, dataset_ops.Dataset):
+ if weights is None:
+ # Select inputs with uniform probability.
+ logits = [[1.0] * num_datasets]
+
+ else:
+ # Use the given `weights` as the probability of choosing the respective
+ # input.
+ weights = ops.convert_to_tensor(weights, name="weights")
+ if weights.dtype not in (dtypes.float32, dtypes.float64):
+ raise TypeError("`weights` must be convertible to a tensor of "
+ "`tf.float32` or `tf.float64` elements.")
+ if not weights.shape.is_compatible_with([num_datasets]):
+ raise ValueError(
+ "`weights` must be a vector of length `len(datasets)`.")
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed
+ # to weights.
+ logits = array_ops.expand_dims(math_ops.log(weights, name="logits"), 0)
+
+ # NOTE(mrry): We only specialize when `weights` is not a `Dataset`. When it
+ # is a `Dataset`, it is possible that evaluating it has a side effect the
+ # user depends on.
+ if len(datasets) == 1:
+ return datasets[0]
+
+ def select_dataset_constant_logits(seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ selector_input = dataset_ops.MapDataset(
+ random_ops.RandomDataset(seed).batch(2),
+ select_dataset_constant_logits,
+ use_inter_op_parallelism=False)
+
+ else:
+ # Use each element of the given `weights` dataset as the probability of
+ # choosing the respective input.
+
+ # The `stateless_multinomial()` op expects log-probabilities, as opposed to
+ # weights.
+ logits_ds = weights.map(lambda *p: math_ops.log(p, name="logits"))
+
+ def select_dataset_varying_logits(logits, seed):
+ return array_ops.squeeze(
+ gen_stateless_random_ops.stateless_multinomial(logits, 1, seed=seed),
+ axis=[0, 1])
+
+ logits_and_seeds = dataset_ops.Dataset.zip(
+ (logits_ds, random_ops.RandomDataset(seed).batch(2)))
+ selector_input = dataset_ops.MapDataset(
+ logits_and_seeds,
+ select_dataset_varying_logits,
+ use_inter_op_parallelism=False)
+
+ return _DirectedInterleaveDataset(selector_input, datasets)
+
+
+@tf_export("data.experimental.choose_from_datasets")
+def choose_from_datasets(datasets, choice_dataset):
+ """Creates a dataset that deterministically chooses elements from `datasets`.
+
+ For example, given the following datasets:
+
+ ```python
+ datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
+ tf.data.Dataset.from_tensors("bar").repeat(),
+ tf.data.Dataset.from_tensors("baz").repeat()]
+
+ # Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
+ choice_dataset = tf.data.Dataset.range(3).repeat(3)
+
+ result = tf.data.experimental.choose_from_datasets(datasets, choice_dataset)
+ ```
+
+ The elements of `result` will be:
+
+ ```
+ "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz"
+ ```
+
+ Args:
+ datasets: A list of `tf.data.Dataset` objects with compatible structure.
+ choice_dataset: A `tf.data.Dataset` of scalar `tf.int64` tensors between
+ `0` and `len(datasets) - 1`.
+
+ Returns:
+ A dataset that interleaves elements from `datasets` according to the values
+ of `choice_dataset`.
+
+ Raises:
+ TypeError: If the `datasets` or `choice_dataset` arguments have the wrong
+ type.
+ """
+ if not (choice_dataset.output_types == dtypes.int64
+ and choice_dataset.output_shapes.is_compatible_with(
+ tensor_shape.scalar())
+ and choice_dataset.output_classes == ops.Tensor):
+ raise TypeError("`choice_dataset` must be a dataset of scalar "
+ "`tf.int64` tensors.")
+ return _DirectedInterleaveDataset(choice_dataset, datasets)
diff --git a/tensorflow/python/data/experimental/ops/iterator_ops.py b/tensorflow/python/data/experimental/ops/iterator_ops.py
new file mode 100644
index 0000000000..72d7d58f06
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/iterator_ops.py
@@ -0,0 +1,268 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Iterator ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.ops import optional_ops
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import saver as saver_lib
+from tensorflow.python.training import session_run_hook
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.make_saveable_from_iterator")
+def make_saveable_from_iterator(iterator):
+ """Returns a SaveableObject for saving/restore iterator state using Saver.
+
+ Args:
+ iterator: Iterator.
+
+ For example:
+
+ ```python
+ with tf.Graph().as_default():
+ ds = tf.data.Dataset.range(10)
+ iterator = ds.make_initializable_iterator()
+ # Build the iterator SaveableObject.
+ saveable_obj = tf.data.experimental.make_saveable_from_iterator(iterator)
+ # Add the SaveableObject to the SAVEABLE_OBJECTS collection so
+ # it can be automatically saved using Saver.
+ tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)
+ saver = tf.train.Saver()
+
+ while continue_training:
+ ... Perform training ...
+ if should_save_checkpoint:
+ saver.save()
+ ```
+
+ Note: When restoring the iterator, the existing iterator state is completely
+ discarded. This means that any changes you may have made to the Dataset
+ graph will be discarded as well! This includes the new Dataset graph
+ that you may have built during validation. So, while running validation,
+ make sure to run the initializer for the validation input pipeline after
+ restoring the checkpoint.
+
+ Note: Not all iterators support checkpointing yet. Attempting to save the
+ state of an unsupported iterator will throw an error.
+ """
+ return _Saveable(iterator._iterator_resource) # pylint: disable=protected-access
+
+
+class _Saveable(saver_lib.BaseSaverBuilder.SaveableObject):
+ """SaveableObject for saving/restoring iterator state."""
+
+ def __init__(self, iterator_resource):
+ serialized_iterator = gen_dataset_ops.serialize_iterator(iterator_resource)
+ specs = [
+ saver_lib.BaseSaverBuilder.SaveSpec(serialized_iterator, "",
+ iterator_resource.name + "-state")
+ ]
+ super(_Saveable, self).__init__(iterator_resource, specs,
+ iterator_resource.name)
+
+ def restore(self, restored_tensors, unused_restored_shapes):
+ with ops.colocate_with(self.op):
+ return gen_dataset_ops.deserialize_iterator(self.op, restored_tensors[0])
+
+
+@tf_export("data.experimental.CheckpointInputPipelineHook")
+class CheckpointInputPipelineHook(session_run_hook.SessionRunHook):
+ """Checkpoints input pipeline state every N steps or seconds.
+
+ This hook saves the state of the iterators in the `Graph` so that when
+ training is resumed the input pipeline continues from where it left off.
+ This could potentially avoid overfitting in certain pipelines where the
+ number of training steps per eval are small compared to the dataset
+ size or if the training pipeline is pre-empted.
+
+ Differences from `CheckpointSaverHook`:
+ 1. Saves only the input pipelines in the "iterators" collection and not the
+ global variables or other saveable objects.
+ 2. Does not write the `GraphDef` and `MetaGraphDef` to the summary.
+
+ Example of checkpointing the training pipeline:
+
+ ```python
+ est = tf.estimator.Estimator(model_fn)
+ while True:
+ est.train(
+ train_input_fn,
+ hooks=[tf.data.experimental.CheckpointInputPipelineHook(est)],
+ steps=train_steps_per_eval)
+ # Note: We do not pass the hook here.
+ metrics = est.evaluate(eval_input_fn)
+ if should_stop_the_training(metrics):
+ break
+ ```
+
+ This hook should be used if the input pipeline state needs to be saved
+ separate from the model checkpoint. Doing so may be useful for a few reasons:
+ 1. The input pipeline checkpoint may be large, if there are large shuffle
+ or prefetch buffers for instance, and may bloat the checkpoint size.
+ 2. If the input pipeline is shared between training and validation, restoring
+ the checkpoint during validation may override the validation input
+ pipeline.
+
+ For saving the input pipeline checkpoint alongside the model weights use
+ `tf.data.experimental.make_saveable_from_iterator` directly to create a
+ `SaveableObject` and add to the `SAVEABLE_OBJECTS` collection. Note, however,
+ that you will need to be careful not to restore the training iterator during
+ eval. You can do that by not adding the iterator to the SAVEABLE_OBJECTS
+ collector when building the eval graph.
+ """
+
+ def __init__(self, estimator):
+ """Initializes a `CheckpointInputPipelineHook`.
+
+ Args:
+ estimator: Estimator.
+
+ Raises:
+ ValueError: One of `save_steps` or `save_secs` should be set.
+ ValueError: At most one of saver or scaffold should be set.
+ """
+ # `checkpoint_basename` is "input.ckpt" for non-distributed pipelines or
+ # of the form "input_<task_type>_<task_id>.ckpt" for distributed pipelines.
+ # Note: The default `checkpoint_basename` used by `CheckpointSaverHook` is
+ # "model.ckpt". We intentionally choose the input pipeline checkpoint prefix
+ # to be different to avoid conflicts with the model checkpoint.
+
+ # pylint: disable=protected-access
+ checkpoint_prefix = "input"
+ if estimator._config.num_worker_replicas > 1:
+ # Distributed setting.
+ suffix = "_{}_{}".format(estimator._config.task_type,
+ estimator._config.task_id)
+ checkpoint_prefix += suffix
+ # pylint: enable=protected-access
+
+ # We use a composition paradigm instead of inheriting from
+ # `CheckpointSaverHook` because `Estimator` does an `isinstance` check
+ # to check whether a `CheckpointSaverHook` is already present in the list
+ # of hooks and if not, adds one. Inheriting from `CheckpointSaverHook`
+ # would thwart this behavior. This hook checkpoints *only the iterators*
+ # and not the graph variables.
+ self._checkpoint_saver_hook = basic_session_run_hooks.CheckpointSaverHook(
+ estimator.model_dir,
+ save_secs=estimator._config.save_checkpoints_secs, # pylint: disable=protected-access
+ save_steps=estimator._config.save_checkpoints_steps, # pylint: disable=protected-access
+ checkpoint_basename=checkpoint_prefix + ".ckpt")
+
+ # Name for the protocol buffer file that will contain the list of most
+ # recent checkpoints stored as a `CheckpointState` protocol buffer.
+ # This file, kept in the same directory as the checkpoint files, is
+ # automatically managed by the `Saver` to keep track of recent checkpoints.
+ # The default name used by the `Saver` for this file is "checkpoint". Here
+ # we use the name "checkpoint_<checkpoint_prefix>" so that in case the
+ # `checkpoint_dir` is the same as the model checkpoint directory, there are
+ # no conflicts during restore.
+ self._latest_filename = "checkpoint_" + checkpoint_prefix
+ self._first_run = True
+
+ def begin(self):
+ # Build a Saver that saves all iterators in the `GLOBAL_ITERATORS`
+ # collection if no `Saver` or `Scaffold` is provided.
+ # pylint: disable=protected-access
+ if (self._checkpoint_saver_hook._saver is None and
+ self._checkpoint_saver_hook._scaffold is None):
+ iterators = ops.get_collection(iterator_ops.GLOBAL_ITERATORS)
+ saveables = [_Saveable(i) for i in iterators]
+ self._checkpoint_saver_hook._saver = _CustomSaver(saveables,
+ self._latest_filename)
+ # pylint: enable=protected-access
+ self._checkpoint_saver_hook.begin()
+
+ def _restore_or_save_initial_ckpt(self, session):
+ # Ideally this should be run in after_create_session but is not for the
+ # following reason:
+ # Currently there is no way of enforcing an order of running the
+ # `SessionRunHooks`. Hence it is possible that the `_DatasetInitializerHook`
+ # is run *after* this hook. That is troublesome because
+ # 1. If a checkpoint exists and this hook restores it, the initializer hook
+ # will override it.
+ # 2. If no checkpoint exists, this hook will try to save an initialized
+ # iterator which will result in an exception.
+ #
+ # As a temporary fix we enter the following implicit contract between this
+ # hook and the _DatasetInitializerHook.
+ # 1. The _DatasetInitializerHook initializes the iterator in the call to
+ # after_create_session.
+ # 2. This hook saves the iterator on the first call to `before_run()`, which
+ # is guaranteed to happen after `after_create_session()` of all hooks
+ # have been run.
+
+ # Check if there is an existing checkpoint. If so, restore from it.
+ # pylint: disable=protected-access
+ latest_checkpoint_path = checkpoint_management.latest_checkpoint(
+ self._checkpoint_saver_hook._checkpoint_dir,
+ latest_filename=self._latest_filename)
+ if latest_checkpoint_path:
+ self._checkpoint_saver_hook._get_saver().restore(session,
+ latest_checkpoint_path)
+ else:
+ # The checkpoint saved here is the state at step "global_step".
+ # Note: We do not save the GraphDef or MetaGraphDef here.
+ global_step = session.run(self._checkpoint_saver_hook._global_step_tensor)
+ self._checkpoint_saver_hook._save(session, global_step)
+ self._checkpoint_saver_hook._timer.update_last_triggered_step(global_step)
+ # pylint: enable=protected-access
+
+ def before_run(self, run_context):
+ if self._first_run:
+ self._restore_or_save_initial_ckpt(run_context.session)
+ self._first_run = False
+ return self._checkpoint_saver_hook.before_run(run_context)
+
+ def after_run(self, run_context, run_values):
+ self._checkpoint_saver_hook.after_run(run_context, run_values)
+
+ def end(self, session):
+ self._checkpoint_saver_hook.end(session)
+
+
+class _CustomSaver(saver_lib.Saver):
+ """`Saver` with a different default `latest_filename`.
+
+ This is used in the `CheckpointInputPipelineHook` to avoid conflicts with
+ the model ckpt saved by the `CheckpointSaverHook`.
+ """
+
+ def __init__(self, var_list, latest_filename):
+ super(_CustomSaver, self).__init__(var_list)
+ self._latest_filename = latest_filename
+
+ def save(self,
+ sess,
+ save_path,
+ global_step=None,
+ latest_filename=None,
+ meta_graph_suffix="meta",
+ write_meta_graph=True,
+ write_state=True,
+ strip_default_attrs=False):
+ return super(_CustomSaver, self).save(
+ sess, save_path, global_step, latest_filename or self._latest_filename,
+ meta_graph_suffix, write_meta_graph, write_state, strip_default_attrs)
+
+
+tf_export("data.experimental.Optional")(optional_ops.Optional)
+tf_export("data.experimental.get_next_as_optional")(
+ iterator_ops.get_next_as_optional)
diff --git a/tensorflow/contrib/data/python/ops/map_defun.py b/tensorflow/python/data/experimental/ops/map_defun.py
index 3d0d0993c9..3d0d0993c9 100644
--- a/tensorflow/contrib/data/python/ops/map_defun.py
+++ b/tensorflow/python/data/experimental/ops/map_defun.py
diff --git a/tensorflow/contrib/data/python/ops/optimization.py b/tensorflow/python/data/experimental/ops/optimization.py
index 3eb172acd5..276dde8383 100644
--- a/tensorflow/contrib/data/python/ops/optimization.py
+++ b/tensorflow/python/data/experimental/ops/optimization.py
@@ -17,12 +17,10 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.data.python.ops import contrib_op_loader # pylint: disable=unused-import
-from tensorflow.contrib.data.python.ops import gen_dataset_ops as contrib_gen_dataset_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
# A constant that can be used to enable auto-tuning.
AUTOTUNE = -1
@@ -54,12 +52,12 @@ def model():
Returns:
A `Dataset` transformation function, which can be passed to
- @{tf.data.Dataset.apply}.
+ `tf.data.Dataset.apply`.
"""
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return _ModelDataset(dataset)
+ return dataset_ops._ModelDataset(dataset) # pylint: disable=protected-access
return _apply_fn
@@ -79,7 +77,7 @@ def optimize(optimizations=None):
def _apply_fn(dataset):
"""Function from `Dataset` to `Dataset` that applies the transformation."""
- return _OptimizeDataset(dataset, optimizations)
+ return dataset_ops._OptimizeDataset(dataset, optimizations) # pylint: disable=protected-access
return _apply_fn
@@ -97,7 +95,7 @@ class _AssertNextDataset(dataset_ops.UnaryDataset):
transformations, dtype=dtypes.string, name="transformations")
def _as_variant_tensor(self):
- return contrib_gen_dataset_ops.assert_next_dataset(
+ return gen_experimental_dataset_ops.experimental_assert_next_dataset(
self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
self._transformations,
**dataset_ops.flat_structure(self))
@@ -114,59 +112,3 @@ class _AssertNextDataset(dataset_ops.UnaryDataset):
def output_types(self):
return self._input_dataset.output_types
-
-class _ModelDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and models performance."""
-
- def __init__(self, input_dataset):
- """See `optimize()` for details."""
- super(_ModelDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.model_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
-
-
-class _OptimizeDataset(dataset_ops.UnaryDataset):
- """A `Dataset` that acts as an identity, and applies optimizations."""
-
- def __init__(self, input_dataset, optimizations):
- """See `optimize()` for details."""
- super(_OptimizeDataset, self).__init__(input_dataset)
- self._input_dataset = input_dataset
- if optimizations is None:
- optimizations = []
- self._optimizations = ops.convert_to_tensor(
- optimizations, dtype=dtypes.string, name="optimizations")
-
- def _as_variant_tensor(self):
- return gen_dataset_ops.optimize_dataset(
- self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
- self._optimizations,
- **dataset_ops.flat_structure(self))
-
- @property
- def output_classes(self):
- return self._input_dataset.output_classes
-
- @property
- def output_shapes(self):
- return self._input_dataset.output_shapes
-
- @property
- def output_types(self):
- return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/parsing_ops.py b/tensorflow/python/data/experimental/ops/parsing_ops.py
new file mode 100644
index 0000000000..6615b9022a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/parsing_ops.py
@@ -0,0 +1,152 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental `dataset` API for parsing example."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import parsing_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ParseExampleDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that parses `example` dataset into a `dict` dataset."""
+
+ def __init__(self, input_dataset, features, num_parallel_calls):
+ super(_ParseExampleDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if not all(types == dtypes.string
+ for types in nest.flatten(input_dataset.output_types)):
+ raise TypeError("Input dataset should be a dataset of vectors of strings")
+ self._num_parallel_calls = num_parallel_calls
+ # pylint: disable=protected-access
+ self._features = parsing_ops._prepend_none_dimension(features)
+ # sparse_keys and dense_keys come back sorted here.
+ (sparse_keys, sparse_types, dense_keys, dense_types, dense_defaults,
+ dense_shapes) = parsing_ops._features_to_raw_params(
+ self._features, [
+ parsing_ops.VarLenFeature, parsing_ops.SparseFeature,
+ parsing_ops.FixedLenFeature, parsing_ops.FixedLenSequenceFeature
+ ])
+ # TODO(b/112859642): Pass sparse_index and sparse_values for SparseFeature.
+ (_, dense_defaults_vec, sparse_keys, sparse_types, dense_keys, dense_shapes,
+ dense_shape_as_shape) = parsing_ops._process_raw_parameters(
+ None, dense_defaults, sparse_keys, sparse_types, dense_keys,
+ dense_types, dense_shapes)
+ # pylint: enable=protected-access
+ self._sparse_keys = sparse_keys
+ self._sparse_types = sparse_types
+ self._dense_keys = dense_keys
+ self._dense_defaults = dense_defaults_vec
+ self._dense_shapes = dense_shapes
+ self._dense_types = dense_types
+ dense_output_shapes = [
+ self._input_dataset.output_shapes.concatenate(shape)
+ for shape in dense_shape_as_shape
+ ]
+ sparse_output_shapes = [
+ self._input_dataset.output_shapes.concatenate([None])
+ for _ in range(len(sparse_keys))
+ ]
+
+ self._output_shapes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ dense_output_shapes + sparse_output_shapes))
+ self._output_types = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ self._dense_types + self._sparse_types))
+ self._output_classes = dict(
+ zip(self._dense_keys + self._sparse_keys,
+ [ops.Tensor for _ in range(len(self._dense_defaults))] +
+ [sparse_tensor.SparseTensor for _ in range(len(self._sparse_keys))
+ ]))
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.parse_example_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._num_parallel_calls,
+ self._dense_defaults,
+ self._sparse_keys,
+ self._dense_keys,
+ self._sparse_types,
+ self._dense_shapes,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+# TODO(b/111553342): add arguments names and example names as well.
+@tf_export("data.experimental.parse_example_dataset")
+def parse_example_dataset(features, num_parallel_calls=1):
+ """A transformation that parses `Example` protos into a `dict` of tensors.
+
+ Parses a number of serialized `Example` protos given in `serialized`. We refer
+ to `serialized` as a batch with `batch_size` many entries of individual
+ `Example` protos.
+
+ This op parses serialized examples into a dictionary mapping keys to `Tensor`
+ and `SparseTensor` objects. `features` is a dict from keys to `VarLenFeature`,
+ `SparseFeature`, and `FixedLenFeature` objects. Each `VarLenFeature`
+ and `SparseFeature` is mapped to a `SparseTensor`, and each
+ `FixedLenFeature` is mapped to a `Tensor`. See `tf.parse_example` for more
+ details about feature dictionaries.
+
+ Args:
+ features: A `dict` mapping feature keys to `FixedLenFeature`,
+ `VarLenFeature`, and `SparseFeature` values.
+ num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
+ representing the number of parsing processes to call in parallel.
+
+ Returns:
+ A dataset transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+
+ Raises:
+ ValueError: if features argument is None.
+ """
+ if features is None:
+ raise ValueError("Missing: features was %s." % features)
+
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ out_dataset = _ParseExampleDataset(dataset, features, num_parallel_calls)
+ if any([
+ isinstance(feature, parsing_ops.SparseFeature)
+ for _, feature in features.items()
+ ]):
+ # pylint: disable=protected-access
+ # pylint: disable=g-long-lambda
+ out_dataset = out_dataset.map(
+ lambda x: parsing_ops._construct_sparse_tensors_for_sparse_features(
+ features, x), num_parallel_calls=num_parallel_calls)
+ return out_dataset
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py
new file mode 100644
index 0000000000..48d7136f95
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py
@@ -0,0 +1,531 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrapper for prefetching_ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import warnings
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.eager import context
+from tensorflow.python.framework import device as framework_device
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import function
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+def function_buffering_resource(string_arg,
+ target_device,
+ f,
+ buffer_size,
+ output_types,
+ container="",
+ shared_name=None,
+ name=None):
+ """Creates a FunctionBufferingResource.
+
+ A FunctionBufferingResource fills up a buffer by calling a function `f` on
+ `target_device`. `f` should take in only a single string argument as input.
+
+ Args:
+ string_arg: The single string argument to the function.
+ target_device: The device to run `f` on.
+ f: The function to be executed.
+ buffer_size: Size of the buffer to be populated.
+ output_types: The output types generated by the function.
+ container: (Optional) string. Defaults to "".
+ shared_name: (Optional) string.
+ name: (Optional) string to name the op.
+
+ Returns:
+ Handle to a FunctionBufferingResource.
+ """
+ if shared_name is None:
+ shared_name = ""
+ return ged_ops.experimental_function_buffering_resource(
+ string_arg=string_arg,
+ target_device=target_device,
+ shared_name=shared_name,
+ f=f,
+ buffer_size=buffer_size,
+ container=container,
+ name=name,
+ output_types=output_types)
+
+
+def function_buffering_resource_get_next(function_buffer_resource,
+ output_types,
+ name=None):
+ return ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=function_buffer_resource,
+ output_types=output_types,
+ name=name)
+
+
+def function_buffering_resource_reset(function_buffer_resource, name=None):
+ return ged_ops.experimental_function_buffering_resource_reset(
+ function_buffer_resource=function_buffer_resource, name=name)
+
+
+# pylint: disable=protected-access
+class _PrefetchToDeviceIterator(object):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ one_shot,
+ device,
+ buffer_size,
+ shared_name=None):
+ self._input_dataset = input_dataset
+ self._get_next_call_count = 0
+ self._one_shot = one_shot
+ if shared_name is None:
+ shared_name = ""
+
+ if self._one_shot:
+ self._input_iterator = input_dataset.make_one_shot_iterator()
+ else:
+ self._input_iterator = iterator_ops.Iterator.from_structure(
+ self._input_dataset.output_types, self._input_dataset.output_shapes,
+ shared_name, self._input_dataset.output_classes)
+ input_iterator_handle = self._input_iterator.string_handle()
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self._input_iterator.output_types,
+ self._input_iterator.output_shapes,
+ self._input_iterator.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ iterator_device = ged_ops.experimental_iterator_get_device(
+ self._input_iterator._iterator_resource)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ target_device=iterator_device,
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=shared_name,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes)))
+
+ if not self._one_shot:
+ reset_op = function_buffering_resource_reset(self._buffering_resource)
+ with ops.control_dependencies([reset_op]):
+ self._initializer = self._input_iterator.make_initializer(
+ self._input_dataset)
+
+ def get_next(self, name=None):
+ """See `tf.data.Iterator.get_next`."""
+ self._get_next_call_count += 1
+ if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
+ warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
+
+ flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
+ self._buffering_resource,
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)),
+ name=name)
+
+ ret = sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self.output_types, flat_ret),
+ self.output_types, self.output_shapes, self.output_classes)
+
+ for tensor, shape in zip(
+ nest.flatten(ret), nest.flatten(self.output_shapes)):
+ if isinstance(tensor, ops.Tensor):
+ tensor.set_shape(shape)
+
+ return ret
+
+ @property
+ def initializer(self):
+ if self._one_shot:
+ raise NotImplementedError("Can't initialize a one_shot_iterator")
+ return self._initializer
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _PrefetchToDeviceEagerIterator(iterator_ops.EagerIterator):
+ """A replacement for `tf.data.Iterator` that prefetches to another device.
+
+ Args:
+ input_dataset: The input dataset
+ one_shot: If true, we make a one shot iterator that's already initialized.
+ device: A fully specified device string where we want to prefetch to
+ buffer_size: Size of the prefetching buffer.
+ shared_name: (Optional.) If non-empty, the returned iterator will be
+ shared under the given name across multiple sessions that share the
+ same devices (e.g. when using a remote server).
+
+ Returns:
+ An Iterator type object.
+ """
+
+ def __init__(self,
+ input_dataset,
+ device,
+ buffer_size):
+ with ops.device("/device:CPU:0"):
+ super(_PrefetchToDeviceEagerIterator, self).__init__(input_dataset)
+ input_iterator_handle = gen_dataset_ops.iterator_to_string_handle(
+ self._resource)
+
+ self._device = device
+
+ @function.Defun(dtypes.string)
+ def _prefetch_fn(handle):
+ """Prefetches one element from `input_iterator`."""
+ remote_iterator = iterator_ops.Iterator.from_string_handle(
+ handle, self.output_types, self.output_shapes, self.output_classes)
+ ret = remote_iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ _prefetch_fn.add_to_graph(None)
+
+ with ops.device(device):
+ self._buffering_resource = function_buffering_resource(
+ f=_prefetch_fn,
+ output_types=self._flat_output_types,
+ target_device=ged_ops.experimental_iterator_get_device(
+ self._resource),
+ string_arg=input_iterator_handle,
+ buffer_size=buffer_size,
+ shared_name=iterator_ops._generate_shared_name(
+ "function_buffer_resource"))
+
+ def _next_internal(self):
+ """Returns a nested structure of `tf.Tensor`s containing the next element.
+ """
+ # This runs in sync mode as iterators use an error status to communicate
+ # that there is no more data to iterate over.
+ # TODO(b/77291417): Fix
+ with context.execution_mode(context.SYNC):
+ with ops.device(self._device):
+ ret = ged_ops.experimental_function_buffering_resource_get_next(
+ function_buffer_resource=self._buffering_resource,
+ output_types=self._flat_output_types)
+ return sparse.deserialize_sparse_tensors(
+ nest.pack_sequence_as(self._output_types, ret), self._output_types,
+ self._output_shapes, self._output_classes)
+# pylint: enable=protected-access
+
+
+class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` whose iterator prefetches elements to another device."""
+
+ def __init__(self, input_dataset, device, buffer_size):
+ super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._device = device
+ self._buffer_size = buffer_size if buffer_size is not None else 1
+
+ # The static analysis cannot tell that the eager iterator's superclass has
+ # a `next()` method.
+ # pylint: disable=non-iterator-returned
+ def __iter__(self):
+ """Creates an `Iterator` for enumerating the elements of this dataset.
+
+ The returned iterator implements the Python iterator protocol and therefore
+ can only be used in eager mode.
+
+ Returns:
+ An `Iterator` over the elements of this dataset.
+
+ Raises:
+ RuntimeError: If eager execution is enabled.
+ """
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ raise RuntimeError("dataset.__iter__() is only supported when eager "
+ "execution is enabled.")
+ # pylint: enable=non-iterator-returned
+
+ def make_one_shot_iterator(self):
+ if context.executing_eagerly():
+ return _PrefetchToDeviceEagerIterator(self._input_dataset, self._device,
+ self._buffer_size)
+ else:
+ return _PrefetchToDeviceIterator(self._input_dataset, one_shot=True,
+ device=self._device,
+ buffer_size=self._buffer_size)
+
+ def make_initializable_iterator(self, shared_name=None):
+ return _PrefetchToDeviceIterator(
+ self._input_dataset,
+ one_shot=False,
+ device=self._device,
+ buffer_size=self._buffer_size,
+ shared_name=shared_name)
+
+ def _as_variant_tensor(self):
+ # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
+ # transformation methods is called.
+ # TODO(mrry): Investigate support for chaining further transformations after
+ # the prefetch, including GPU support.
+ raise NotImplementedError("`prefetch_to_device()` must be the last "
+ "transformation in a dataset pipeline.")
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+@tf_export("data.experimental.prefetch_to_device")
+def prefetch_to_device(device, buffer_size=None):
+ """A transformation that prefetches dataset values to the given `device`.
+
+ NOTE: Although the transformation creates a `tf.data.Dataset`, the
+ transformation must be the final `Dataset` in the input pipeline.
+
+ Args:
+ device: A string. The name of a device to which elements will be prefetched.
+ buffer_size: (Optional.) The number of elements to buffer on `device`.
+ Defaults to an automatically chosen value.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _PrefetchToDeviceDataset(dataset, device, buffer_size)
+
+ return _apply_fn
+
+
+@tf_export("data.experimental.copy_to_device")
+def copy_to_device(target_device, source_device="/cpu:0"):
+ """A transformation that copies dataset elements to the given `target_device`.
+
+ Args:
+ target_device: The name of a device to which elements will be copied.
+ source_device: The original device on which `input_dataset` will be placed.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _CopyToDeviceDataset(
+ dataset, target_device=target_device, source_device=source_device)
+
+ return _apply_fn
+
+
+# TODO(rohanj): Use the _input_hostmem attr on the RemoteCall ops to indicate
+# all inputs to the Op are in host memory, thereby avoiding some unnecessary
+# Sends and Recvs.
+class _CopyToDeviceDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that copies elements to another device."""
+
+ def __init__(self, input_dataset, target_device, source_device="/cpu:0"):
+ """Constructs a _CopyToDeviceDataset.
+
+ Args:
+ input_dataset: `Dataset` to be copied
+ target_device: The name of the device to which elements would be copied.
+ source_device: Device where input_dataset would be placed.
+ """
+ super(_CopyToDeviceDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._target_device = target_device
+ spec = framework_device.DeviceSpec().from_string(self._target_device)
+ self._is_gpu_target = (spec.device_type == "GPU")
+ self._source_device_string = source_device
+ self._source_device = ops.convert_to_tensor(source_device)
+
+ self._flat_output_shapes = nest.flatten(
+ sparse.as_dense_shapes(self._input_dataset.output_shapes,
+ self._input_dataset.output_classes))
+ self._flat_output_types = nest.flatten(
+ sparse.as_dense_types(self._input_dataset.output_types,
+ self._input_dataset.output_classes))
+
+ @function.Defun()
+ def _init_func():
+ """Creates an iterator for the input dataset.
+
+ Returns:
+ A `string` tensor that encapsulates the iterator created.
+ """
+ # pylint: disable=protected-access
+ ds_variant = self._input_dataset._as_variant_tensor()
+ resource = gen_dataset_ops.anonymous_iterator(
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies(
+ [gen_dataset_ops.make_iterator(ds_variant, resource)]):
+ return gen_dataset_ops.iterator_to_string_handle(resource)
+
+ @function.Defun()
+ def _remote_init_func():
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=_init_func.captured_inputs,
+ Tout=[dtypes.string],
+ f=_init_func)
+
+ self._init_func = _remote_init_func
+ self._init_captured_args = _remote_init_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _next_func(string_handle):
+ """Calls get_next for created iterator.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ The elements generated from `input_dataset`
+ """
+ with ops.device(self._source_device_string):
+ iterator = iterator_ops.Iterator.from_string_handle(
+ string_handle, self.output_types, self.output_shapes,
+ self.output_classes)
+ ret = iterator.get_next()
+ return nest.flatten(sparse.serialize_sparse_tensors(ret))
+
+ @function.Defun(dtypes.string)
+ def _remote_next_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _next_func.captured_inputs,
+ Tout=self._flat_output_types,
+ f=_next_func)
+
+ self._next_func = _remote_next_func
+ self._next_captured_args = _remote_next_func.captured_inputs
+
+ @function.Defun(dtypes.string)
+ def _finalize_func(string_handle):
+ """Destroys the iterator resource created.
+
+ Args:
+ string_handle: An iterator string handle created by _init_func
+ Returns:
+ Tensor constant 0
+ """
+ iterator_resource = gen_dataset_ops.iterator_from_string_handle_v2(
+ string_handle,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+ with ops.control_dependencies([
+ resource_variable_ops.destroy_resource_op(
+ iterator_resource, ignore_lookup_error=True)]):
+ return array_ops.constant(0, dtypes.int64)
+
+ @function.Defun(dtypes.string)
+ def _remote_finalize_func(string_handle):
+ return functional_ops.remote_call(
+ target=self._source_device,
+ args=[string_handle] + _finalize_func.captured_inputs,
+ Tout=[dtypes.int64],
+ f=_finalize_func)
+
+ self._finalize_func = _remote_finalize_func
+ self._finalize_captured_args = _remote_finalize_func.captured_inputs
+
+ g = ops.get_default_graph()
+ _remote_init_func.add_to_graph(g)
+ _remote_next_func.add_to_graph(g)
+ _remote_finalize_func.add_to_graph(g)
+ # pylint: enable=protected-scope
+
+ # The one_shot_iterator implementation needs a 0 arg _make_dataset function
+ # that thereby captures all the inputs required to create the dataset. Since
+ # there are strings that are inputs to the GeneratorDataset which can't be
+ # placed on a GPU, this fails for the GPU case. Therefore, disabling it for
+ # GPU
+ def make_one_shot_iterator(self):
+ if self._is_gpu_target:
+ raise ValueError("Cannot create a one shot iterator when using "
+ "`tf.data.experimental.copy_to_device()` on GPU. Please "
+ "use `Dataset.make_initializable_iterator()` instead.")
+ else:
+ return super(_CopyToDeviceDataset, self).make_one_shot_iterator()
+
+ def _as_variant_tensor(self):
+ with ops.device(self._target_device):
+ return gen_dataset_ops.generator_dataset(
+ self._init_captured_args,
+ self._next_captured_args,
+ self._finalize_captured_args,
+ init_func=self._init_func,
+ next_func=self._next_func,
+ finalize_func=self._finalize_func,
+ output_types=self._flat_output_types,
+ output_shapes=self._flat_output_shapes)
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
diff --git a/tensorflow/python/data/experimental/ops/random_ops.py b/tensorflow/python/data/experimental/ops/random_ops.py
new file mode 100644
index 0000000000..e3a2aeab31
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/random_ops.py
@@ -0,0 +1,54 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Datasets for random number generators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.RandomDataset")
+class RandomDataset(dataset_ops.DatasetSource):
+ """A `Dataset` of pseudorandom values."""
+
+ def __init__(self, seed=None):
+ """A `Dataset` of pseudorandom values."""
+ super(RandomDataset, self).__init__()
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.random_dataset(
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.int64
diff --git a/tensorflow/python/data/experimental/ops/readers.py b/tensorflow/python/data/experimental/ops/readers.py
new file mode 100644
index 0000000000..3b2d094514
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/readers.py
@@ -0,0 +1,904 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for reader Datasets."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import csv
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import optimization
+from tensorflow.python.data.experimental.ops import parsing_ops
+from tensorflow.python.data.experimental.ops import shuffle_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers as core_readers
+from tensorflow.python.data.util import convert
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.lib.io import file_io
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.util.tf_export import tf_export
+
+_ACCEPTABLE_CSV_TYPES = (dtypes.float32, dtypes.float64, dtypes.int32,
+ dtypes.int64, dtypes.string)
+
+
+def _is_valid_int32(str_val):
+ try:
+ # Checks equality to prevent int32 overflow
+ return dtypes.int32.as_numpy_dtype(str_val) == dtypes.int64.as_numpy_dtype(
+ str_val)
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_int64(str_val):
+ try:
+ dtypes.int64.as_numpy_dtype(str_val)
+ return True
+ except (ValueError, OverflowError):
+ return False
+
+
+def _is_valid_float(str_val, float_dtype):
+ try:
+ return float_dtype.as_numpy_dtype(str_val) < np.inf
+ except ValueError:
+ return False
+
+
+def _infer_type(str_val, na_value, prev_type):
+ """Given a string, infers its tensor type.
+
+ Infers the type of a value by picking the least 'permissive' type possible,
+ while still allowing the previous type inference for this column to be valid.
+
+ Args:
+ str_val: String value to infer the type of.
+ na_value: Additional string to recognize as a NA/NaN CSV value.
+ prev_type: Type previously inferred based on values of this column that
+ we've seen up till now.
+ Returns:
+ Inferred dtype.
+ """
+ if str_val in ("", na_value):
+ # If the field is null, it gives no extra information about its type
+ return prev_type
+
+ type_list = [
+ dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64, dtypes.string
+ ] # list of types to try, ordered from least permissive to most
+
+ type_functions = [
+ _is_valid_int32,
+ _is_valid_int64,
+ lambda str_val: _is_valid_float(str_val, dtypes.float32),
+ lambda str_val: _is_valid_float(str_val, dtypes.float64),
+ lambda str_val: True,
+ ] # Corresponding list of validation functions
+
+ for i in range(len(type_list)):
+ validation_fn = type_functions[i]
+ if validation_fn(str_val) and (prev_type is None or
+ prev_type in type_list[:i + 1]):
+ return type_list[i]
+
+
+def _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header):
+ """Generator that yields rows of CSV file(s) in order."""
+ for fn in filenames:
+ with file_io.FileIO(fn, "r") as f:
+ rdr = csv.reader(
+ f,
+ delimiter=field_delim,
+ quoting=csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE)
+ if header:
+ next(rdr) # Skip header lines
+
+ for csv_row in rdr:
+ if len(csv_row) != num_cols:
+ raise ValueError(
+ "Problem inferring types: CSV row has different number of fields "
+ "than expected.")
+ yield csv_row
+
+
+def _infer_column_defaults(filenames, num_cols, field_delim, use_quote_delim,
+ na_value, header, num_rows_for_inference,
+ select_columns):
+ """Infers column types from the first N valid CSV records of files."""
+ if select_columns is None:
+ select_columns = range(num_cols)
+ inferred_types = [None] * len(select_columns)
+
+ for i, csv_row in enumerate(
+ _next_csv_row(filenames, num_cols, field_delim, use_quote_delim, header)):
+ if num_rows_for_inference is not None and i >= num_rows_for_inference:
+ break
+
+ for j, col_index in enumerate(select_columns):
+ inferred_types[j] = _infer_type(csv_row[col_index], na_value,
+ inferred_types[j])
+
+ # Replace None's with a default type
+ inferred_types = [t or dtypes.string for t in inferred_types]
+ # Default to 0 or '' for null values
+ return [
+ constant_op.constant([0 if t is not dtypes.string else ""], dtype=t)
+ for t in inferred_types
+ ]
+
+
+def _infer_column_names(filenames, field_delim, use_quote_delim):
+ """Infers column names from first rows of files."""
+ csv_kwargs = {
+ "delimiter": field_delim,
+ "quoting": csv.QUOTE_MINIMAL if use_quote_delim else csv.QUOTE_NONE
+ }
+ with file_io.FileIO(filenames[0], "r") as f:
+ try:
+ column_names = next(csv.reader(f, **csv_kwargs))
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+
+ for name in filenames[1:]:
+ with file_io.FileIO(name, "r") as f:
+ try:
+ if next(csv.reader(f, **csv_kwargs)) != column_names:
+ raise ValueError(
+ "Files have different column names in the header row.")
+ except StopIteration:
+ raise ValueError(("Received StopIteration when reading the header line "
+ "of %s. Empty file?") % filenames[0])
+ return column_names
+
+
+def _get_sorted_col_indices(select_columns, column_names):
+ """Transforms select_columns argument into sorted column indices."""
+ names_to_indices = {n: i for i, n in enumerate(column_names)}
+ num_cols = len(column_names)
+ for i, v in enumerate(select_columns):
+ if isinstance(v, int):
+ if v < 0 or v >= num_cols:
+ raise ValueError(
+ "Column index %d specified in select_columns out of valid range." %
+ v)
+ continue
+ if v not in names_to_indices:
+ raise ValueError(
+ "Value '%s' specified in select_columns not a valid column index or "
+ "name." % v)
+ select_columns[i] = names_to_indices[v]
+
+ # Sort and ensure there are no duplicates
+ result = sorted(set(select_columns))
+ if len(result) != len(select_columns):
+ raise ValueError("select_columns contains duplicate columns")
+ return result
+
+
+def _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed):
+ """Optionally shuffle and repeat dataset, as requested."""
+ if num_epochs != 1 and shuffle:
+ # Use shuffle_and_repeat for perf
+ return dataset.apply(
+ shuffle_ops.shuffle_and_repeat(shuffle_buffer_size, num_epochs,
+ shuffle_seed))
+ elif shuffle:
+ return dataset.shuffle(shuffle_buffer_size, shuffle_seed)
+ elif num_epochs != 1:
+ return dataset.repeat(num_epochs)
+ return dataset
+
+
+def make_tf_record_dataset(file_pattern,
+ batch_size,
+ parser_fn=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=None,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=None,
+ num_parallel_parser_calls=None,
+ drop_final_batch=False):
+ """Reads and optionally parses TFRecord files into a dataset.
+
+ Provides common functionality such as batching, optional parsing, shuffling,
+ and performant defaults.
+
+ Args:
+ file_pattern: List of files or patterns of TFRecord file paths.
+ See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ parser_fn: (Optional.) A function accepting string input to parse
+ and process the record contents. This function must map records
+ to components of a fixed shape, so they may be batched. By
+ default, uses the record contents unmodified.
+ num_epochs: (Optional.) An int specifying the number of times this
+ dataset is repeated. If None (the default), cycles through the
+ dataset forever.
+ shuffle: (Optional.) A bool that indicates whether the input
+ should be shuffled. Defaults to `True`.
+ shuffle_buffer_size: (Optional.) Buffer size to use for
+ shuffling. A large buffer size ensures better shuffling, but
+ increases memory usage and startup time.
+ shuffle_seed: (Optional.) Randomization seed to use for shuffling.
+ prefetch_buffer_size: (Optional.) An int specifying the number of
+ feature batches to prefetch for performance improvement.
+ Defaults to auto-tune. Set to 0 to disable prefetching.
+ num_parallel_reads: (Optional.) Number of threads used to read
+ records from files. By default or if set to a value >1, the
+ results will be interleaved.
+ num_parallel_parser_calls: (Optional.) Number of parallel
+ records to parse in parallel. Defaults to an automatic selection.
+ drop_final_batch: (Optional.) Whether the last batch should be
+ dropped in case its size is smaller than `batch_size`; the
+ default behavior is not to drop the smaller batch.
+
+ Returns:
+ A dataset, where each element matches the output of `parser_fn`
+ except it will have an additional leading `batch-size` dimension,
+ or a `batch_size`-length 1-D tensor of strings if `parser_fn` is
+ unspecified.
+ """
+ files = dataset_ops.Dataset.list_files(
+ file_pattern, shuffle=shuffle, seed=shuffle_seed)
+
+ if num_parallel_reads is None:
+ # Note: We considered auto-tuning this value, but there is a concern
+ # that this affects the mixing of records from different files, which
+ # could affect training convergence/accuracy, so we are defaulting to
+ # a constant for now.
+ num_parallel_reads = 24
+ dataset = core_readers.TFRecordDataset(
+ files, num_parallel_reads=num_parallel_reads)
+
+ if shuffle_buffer_size is None:
+ # TODO(josh11b): Auto-tune this value when not specified
+ shuffle_buffer_size = 10000
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_final_batch=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ drop_final_batch = drop_final_batch or num_epochs is None
+
+ if parser_fn is None:
+ dataset = dataset.batch(batch_size, drop_remainder=drop_final_batch)
+ else:
+ # TODO(josh11b): if num_parallel_parser_calls is None, use some function
+ # of num cores instead of map_and_batch's default behavior of one batch.
+ dataset = dataset.apply(batching.map_and_batch(
+ parser_fn, batch_size, num_parallel_calls=num_parallel_parser_calls,
+ drop_remainder=drop_final_batch))
+
+ if prefetch_buffer_size == 0:
+ return dataset
+ else:
+ return dataset.prefetch(buffer_size=prefetch_buffer_size)
+
+
+@tf_export("data.experimental.make_csv_dataset")
+def make_csv_dataset(
+ file_pattern,
+ batch_size,
+ column_names=None,
+ column_defaults=None,
+ label_name=None,
+ select_columns=None,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ header=True,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ num_parallel_reads=1,
+ sloppy=False,
+ num_rows_for_inference=100,
+ compression_type=None,
+):
+ """Reads CSV files into a dataset.
+
+ Reads CSV files into a dataset, where each element is a (features, labels)
+ tuple that corresponds to a batch of CSV rows. The features dictionary
+ maps feature column names to `Tensor`s containing the corresponding
+ feature data, and labels is a `Tensor` containing the batch's label data.
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing CSV
+ records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ column_names: An optional list of strings that corresponds to the CSV
+ columns, in order. One per column of the input record. If this is not
+ provided, infers the column names from the first row of the records.
+ These names will be the keys of the features dict of each dataset element.
+ column_defaults: A optional list of default values for the CSV fields. One
+ item per selected column of the input record. Each item in the list is
+ either a valid CSV dtype (float32, float64, int32, int64, or string), or a
+ `Tensor` with one of the aforementioned types. The tensor can either be
+ a scalar default value (if the column is optional), or an empty tensor (if
+ the column is required). If a dtype is provided instead of a tensor, the
+ column is also treated as required. If this list is not provided, tries
+ to infer types based on reading the first num_rows_for_inference rows of
+ files specified, and assumes all columns are optional, defaulting to `0`
+ for numeric values and `""` for string values. If both this and
+ `select_columns` are specified, these must have the same lengths, and
+ `column_defaults` is assumed to be sorted in order of increasing column
+ index.
+ label_name: A optional string corresponding to the label column. If
+ provided, the data for this column is returned as a separate `Tensor` from
+ the features dictionary, so that the dataset complies with the format
+ expected by a `tf.Estimator.train` or `tf.Estimator.evaluate` input
+ function.
+ select_columns: An optional list of integer indices or string column
+ names, that specifies a subset of columns of CSV data to select. If
+ column names are provided, these must correspond to names provided in
+ `column_names` or inferred from the file header lines. When this argument
+ is specified, only a subset of CSV columns will be parsed and returned,
+ corresponding to the columns specified. Using this results in faster
+ parsing and lower memory usage. If both this and `column_defaults` are
+ specified, these must have the same lengths, and `column_defaults` is
+ assumed to be sorted in order of increasing column index.
+ field_delim: An optional `string`. Defaults to `","`. Char delimiter to
+ separate fields in a record.
+ use_quote_delim: An optional bool. Defaults to `True`. If false, treats
+ double quotation marks as regular characters inside of the string fields.
+ na_value: Additional string to recognize as NA/NaN.
+ header: A bool that indicates whether the first rows of provided CSV files
+ correspond to header lines with column names, and should not be included
+ in the data.
+ num_epochs: An int specifying the number of times this dataset is repeated.
+ If None, cycles through the dataset forever.
+ shuffle: A bool that indicates whether the input should be shuffled.
+ shuffle_buffer_size: Buffer size to use for shuffling. A large buffer size
+ ensures better shuffling, but increases memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: An int specifying the number of feature
+ batches to prefetch for performance improvement. Recommended value is the
+ number of batches consumed per training step. Defaults to auto-tune.
+
+ num_parallel_reads: Number of threads used to read CSV records from files.
+ If >1, the results will be interleaved.
+ sloppy: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ num_rows_for_inference: Number of rows of a file to use for type inference
+ if record_defaults is not provided. If None, reads all the rows of all
+ the files. Defaults to 100.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no compression.
+
+ Returns:
+ A dataset, where each element is a (features, labels) tuple that corresponds
+ to a batch of `batch_size` CSV rows. The features dictionary maps feature
+ column names to `Tensor`s containing the corresponding column data, and
+ labels is a `Tensor` containing the column data for the label column
+ specified by `label_name`.
+
+ Raises:
+ ValueError: If any of the arguments is malformed.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Clean arguments; figure out column names and defaults
+
+ if column_names is None:
+ if not header:
+ raise ValueError("Cannot infer column names without a header line.")
+ # If column names are not provided, infer from the header lines
+ column_names = _infer_column_names(filenames, field_delim, use_quote_delim)
+ if len(column_names) != len(set(column_names)):
+ raise ValueError("Cannot have duplicate column names.")
+
+ if select_columns is not None:
+ select_columns = _get_sorted_col_indices(select_columns, column_names)
+
+ if column_defaults is not None:
+ column_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in column_defaults
+ ]
+ else:
+ # If column defaults are not provided, infer from records at graph
+ # construction time
+ column_defaults = _infer_column_defaults(
+ filenames, len(column_names), field_delim, use_quote_delim, na_value,
+ header, num_rows_for_inference, select_columns)
+
+ if select_columns is not None and len(column_defaults) != len(select_columns):
+ raise ValueError(
+ "If specified, column_defaults and select_columns must have same "
+ "length."
+ )
+ if select_columns is not None and len(column_names) > len(select_columns):
+ # Pick the relevant subset of column names
+ column_names = [column_names[i] for i in select_columns]
+
+ if label_name is not None and label_name not in column_names:
+ raise ValueError("`label_name` provided must be one of the columns.")
+
+ def filename_to_dataset(filename):
+ return CsvDataset(
+ filename,
+ record_defaults=column_defaults,
+ field_delim=field_delim,
+ use_quote_delim=use_quote_delim,
+ na_value=na_value,
+ select_cols=select_columns,
+ header=header,
+ compression_type=compression_type,
+ )
+
+ def map_fn(*columns):
+ """Organizes columns into a features dictionary.
+
+ Args:
+ *columns: list of `Tensor`s corresponding to one csv record.
+ Returns:
+ An OrderedDict of feature names to values for that particular record. If
+ label_name is provided, extracts the label feature to be returned as the
+ second element of the tuple.
+ """
+ features = collections.OrderedDict(zip(column_names, columns))
+ if label_name is not None:
+ label = features.pop(label_name)
+ return features, label
+ return features
+
+ # Read files sequentially (if num_parallel_reads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ filename_to_dataset, cycle_length=num_parallel_reads, sloppy=sloppy))
+
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # Apply batch before map for perf, because map has high overhead relative
+ # to the size of the computation in each map.
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(batch_size=batch_size,
+ drop_remainder=num_epochs is None)
+ dataset = dataset_ops.MapDataset(
+ dataset, map_fn, use_inter_op_parallelism=False)
+ dataset = dataset.prefetch(prefetch_buffer_size)
+
+ return dataset
+
+
+_DEFAULT_READER_BUFFER_SIZE_BYTES = 4 * 1024 * 1024 # 4 MB
+
+
+@tf_export("data.experimental.CsvDataset")
+class CsvDataset(dataset_ops.DatasetSource):
+ """A Dataset comprising lines from one or more CSV files."""
+
+ def __init__(self,
+ filenames,
+ record_defaults,
+ compression_type=None,
+ buffer_size=None,
+ header=False,
+ field_delim=",",
+ use_quote_delim=True,
+ na_value="",
+ select_cols=None):
+ """Creates a `CsvDataset` by reading and decoding CSV files.
+
+ The elements of this dataset correspond to records from the file(s).
+ RFC 4180 format is expected for CSV files
+ (https://tools.ietf.org/html/rfc4180)
+ Note that we allow leading and trailing spaces with int or float field.
+
+
+ For example, suppose we have a file 'my_file0.csv' with four CSV columns of
+ different data types:
+ ```
+ abcdefg,4.28E10,5.55E6,12
+ hijklmn,-5.3E14,,2
+ ```
+
+ We can construct a CsvDataset from it as follows:
+ ```python
+ dataset = tf.data.experimental.CsvDataset(
+ "my_file*.csv",
+ [tf.float32, # Required field, use dtype or empty tensor
+ tf.constant([0.0], dtype=tf.float32), # Optional field, default to 0.0
+ tf.int32, # Required field, use dtype or empty tensor
+ ],
+ select_cols=[1,2,3] # Only parse last three columns
+ )
+ ```
+
+ The expected output of its iterations is:
+ ```python
+ next_element = dataset.make_one_shot_iterator().get_next()
+ with tf.Session() as sess:
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+
+ >> (4.28e10, 5.55e6, 12)
+ >> (-5.3e14, 0.0, 2)
+ ```
+
+ Args:
+ filenames: A `tf.string` tensor containing one or more filenames.
+ record_defaults: A list of default values for the CSV fields. Each item in
+ the list is either a valid CSV `DType` (float32, float64, int32, int64,
+ string), or a `Tensor` object with one of the above types. One per
+ column of CSV data, with either a scalar `Tensor` default value for the
+ column if it is optional, or `DType` or empty `Tensor` if required. If
+ both this and `select_columns` are specified, these must have the same
+ lengths, and `column_defaults` is assumed to be sorted in order of
+ increasing column index.
+ compression_type: (Optional.) A `tf.string` scalar evaluating to one of
+ `""` (no compression), `"ZLIB"`, or `"GZIP"`. Defaults to no
+ compression.
+ buffer_size: (Optional.) A `tf.int64` scalar denoting the number of bytes
+ to buffer while reading files. Defaults to 4MB.
+ header: (Optional.) A `tf.bool` scalar indicating whether the CSV file(s)
+ have header line(s) that should be skipped when parsing. Defaults to
+ `False`.
+ field_delim: (Optional.) A `tf.string` scalar containing the delimiter
+ character that separates fields in a record. Defaults to `","`.
+ use_quote_delim: (Optional.) A `tf.bool` scalar. If `False`, treats
+ double quotation marks as regular characters inside of string fields
+ (ignoring RFC 4180, Section 2, Bullet 5). Defaults to `True`.
+ na_value: (Optional.) A `tf.string` scalar indicating a value that will
+ be treated as NA/NaN.
+ select_cols: (Optional.) A sorted list of column indices to select from
+ the input data. If specified, only this subset of columns will be
+ parsed. Defaults to parsing all columns.
+ """
+ super(CsvDataset, self).__init__()
+ self._filenames = ops.convert_to_tensor(
+ filenames, dtype=dtypes.string, name="filenames")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+ record_defaults = [
+ constant_op.constant([], dtype=x) if x in _ACCEPTABLE_CSV_TYPES else x
+ for x in record_defaults
+ ]
+ self._record_defaults = ops.convert_n_to_tensor(
+ record_defaults, name="record_defaults")
+ self._buffer_size = convert.optional_param_to_tensor(
+ "buffer_size", buffer_size, _DEFAULT_READER_BUFFER_SIZE_BYTES)
+ self._header = ops.convert_to_tensor(
+ header, dtype=dtypes.bool, name="header")
+ self._field_delim = ops.convert_to_tensor(
+ field_delim, dtype=dtypes.string, name="field_delim")
+ self._use_quote_delim = ops.convert_to_tensor(
+ use_quote_delim, dtype=dtypes.bool, name="use_quote_delim")
+ self._na_value = ops.convert_to_tensor(
+ na_value, dtype=dtypes.string, name="na_value")
+ self._select_cols = convert.optional_param_to_tensor(
+ "select_cols",
+ select_cols,
+ argument_default=[],
+ argument_dtype=dtypes.int64,
+ )
+ self._output_shapes = tuple(
+ tensor_shape.scalar() for _ in range(len(record_defaults)))
+ self._output_types = tuple(d.dtype for d in self._record_defaults)
+ self._output_classes = tuple(
+ ops.Tensor for _ in range(len(record_defaults)))
+
+ def _as_variant_tensor(self):
+ # Constructs graph node for the dataset op.
+ return gen_experimental_dataset_ops.experimental_csv_dataset(
+ filenames=self._filenames,
+ record_defaults=self._record_defaults,
+ buffer_size=self._buffer_size,
+ header=self._header,
+ output_shapes=self._output_shapes,
+ field_delim=self._field_delim,
+ use_quote_delim=self._use_quote_delim,
+ na_value=self._na_value,
+ select_cols=self._select_cols,
+ compression_type=self._compression_type,
+ )
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+
+@tf_export("data.experimental.make_batched_features_dataset")
+def make_batched_features_dataset(file_pattern,
+ batch_size,
+ features,
+ reader=core_readers.TFRecordDataset,
+ label_key=None,
+ reader_args=None,
+ num_epochs=None,
+ shuffle=True,
+ shuffle_buffer_size=10000,
+ shuffle_seed=None,
+ prefetch_buffer_size=optimization.AUTOTUNE,
+ reader_num_threads=1,
+ parser_num_threads=2,
+ sloppy_ordering=False,
+ drop_final_batch=False):
+ """Returns a `Dataset` of feature dictionaries from `Example` protos.
+
+ If label_key argument is provided, returns a `Dataset` of tuple
+ comprising of feature dictionaries and label.
+
+ Example:
+
+ ```
+ serialized_examples = [
+ features {
+ feature { key: "age" value { int64_list { value: [ 0 ] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "code", "art" ] } } }
+ },
+ features {
+ feature { key: "age" value { int64_list { value: [] } } }
+ feature { key: "gender" value { bytes_list { value: [ "f" ] } } }
+ feature { key: "kws" value { bytes_list { value: [ "sports" ] } } }
+ }
+ ]
+ ```
+
+ We can use arguments:
+
+ ```
+ features: {
+ "age": FixedLenFeature([], dtype=tf.int64, default_value=-1),
+ "gender": FixedLenFeature([], dtype=tf.string),
+ "kws": VarLenFeature(dtype=tf.string),
+ }
+ ```
+
+ And the expected output is:
+
+ ```python
+ {
+ "age": [[0], [-1]],
+ "gender": [["f"], ["f"]],
+ "kws": SparseTensor(
+ indices=[[0, 0], [0, 1], [1, 0]],
+ values=["code", "art", "sports"]
+ dense_shape=[2, 2]),
+ }
+ ```
+
+ Args:
+ file_pattern: List of files or patterns of file paths containing
+ `Example` records. See `tf.gfile.Glob` for pattern rules.
+ batch_size: An int representing the number of records to combine
+ in a single batch.
+ features: A `dict` mapping feature keys to `FixedLenFeature` or
+ `VarLenFeature` values. See `tf.parse_example`.
+ reader: A function or class that can be
+ called with a `filenames` tensor and (optional) `reader_args` and returns
+ a `Dataset` of `Example` tensors. Defaults to `tf.data.TFRecordDataset`.
+ label_key: (Optional) A string corresponding to the key labels are stored in
+ `tf.Examples`. If provided, it must be one of the `features` key,
+ otherwise results in `ValueError`.
+ reader_args: Additional arguments to pass to the reader class.
+ num_epochs: Integer specifying the number of times to read through the
+ dataset. If None, cycles through the dataset forever. Defaults to `None`.
+ shuffle: A boolean, indicates whether the input should be shuffled. Defaults
+ to `True`.
+ shuffle_buffer_size: Buffer size of the ShuffleDataset. A large capacity
+ ensures better shuffling but would increase memory usage and startup time.
+ shuffle_seed: Randomization seed to use for shuffling.
+ prefetch_buffer_size: Number of feature batches to prefetch in order to
+ improve performance. Recommended value is the number of batches consumed
+ per training step. Defaults to auto-tune.
+ reader_num_threads: Number of threads used to read `Example` records. If >1,
+ the results will be interleaved.
+ parser_num_threads: Number of threads to use for parsing `Example` tensors
+ into a dictionary of `Feature` tensors.
+ sloppy_ordering: If `True`, reading performance will be improved at
+ the cost of non-deterministic ordering. If `False`, the order of elements
+ produced is deterministic prior to shuffling (elements are still
+ randomized if `shuffle=True`. Note that if the seed is set, then order
+ of elements after shuffling is deterministic). Defaults to `False`.
+ drop_final_batch: If `True`, and the batch size does not evenly divide the
+ input dataset size, the final smaller batch will be dropped. Defaults to
+ `False`.
+
+ Returns:
+ A dataset of `dict` elements, (or a tuple of `dict` elements and label).
+ Each `dict` maps feature keys to `Tensor` or `SparseTensor` objects.
+
+ Raises:
+ ValueError: If `label_key` is not one of the `features` keys.
+ """
+ # Create dataset of all matching filenames
+ filenames = _get_file_names(file_pattern, False)
+ dataset = dataset_ops.Dataset.from_tensor_slices(filenames)
+ if shuffle:
+ dataset = dataset.shuffle(len(filenames), shuffle_seed)
+
+ # Read `Example` records from files as tensor objects.
+ if reader_args is None:
+ reader_args = []
+
+ # Read files sequentially (if reader_num_threads=1) or in parallel
+ dataset = dataset.apply(
+ interleave_ops.parallel_interleave(
+ lambda filename: reader(filename, *reader_args),
+ cycle_length=reader_num_threads,
+ sloppy=sloppy_ordering))
+
+ # Extract values if the `Example` tensors are stored as key-value tuples.
+ if dataset.output_types == (dtypes.string, dtypes.string):
+ dataset = dataset_ops.MapDataset(
+ dataset, lambda _, v: v, use_inter_op_parallelism=False)
+
+ # Apply dataset repeat and shuffle transformations.
+ dataset = _maybe_shuffle_and_repeat(
+ dataset, num_epochs, shuffle, shuffle_buffer_size, shuffle_seed)
+
+ # NOTE(mrry): We set `drop_remainder=True` when `num_epochs is None` to
+ # improve the shape inference, because it makes the batch dimension static.
+ # It is safe to do this because in that case we are repeating the input
+ # indefinitely, and all batches will be full-sized.
+ dataset = dataset.batch(
+ batch_size, drop_remainder=drop_final_batch or num_epochs is None)
+
+ # Parse `Example` tensors to a dictionary of `Feature` tensors.
+ dataset = dataset.apply(
+ parsing_ops.parse_example_dataset(
+ features, num_parallel_calls=parser_num_threads))
+
+ if label_key:
+ if label_key not in features:
+ raise ValueError(
+ "The `label_key` provided (%r) must be one of the `features` keys." %
+ label_key)
+ dataset = dataset.map(lambda x: (x, x.pop(label_key)))
+
+ dataset = dataset.prefetch(prefetch_buffer_size)
+ return dataset
+
+
+def _get_file_names(file_pattern, shuffle):
+ """Parse list of file names from pattern, optionally shuffled.
+
+ Args:
+ file_pattern: File glob pattern, or list of glob patterns.
+ shuffle: Whether to shuffle the order of file names.
+
+ Returns:
+ List of file names matching `file_pattern`.
+
+ Raises:
+ ValueError: If `file_pattern` is empty, or pattern matches no files.
+ """
+ if isinstance(file_pattern, list):
+ if not file_pattern:
+ raise ValueError("File pattern is empty.")
+ file_names = []
+ for entry in file_pattern:
+ file_names.extend(gfile.Glob(entry))
+ else:
+ file_names = list(gfile.Glob(file_pattern))
+
+ if not file_names:
+ raise ValueError("No files match %s." % file_pattern)
+
+ # Sort files so it will be deterministic for unit tests.
+ if not shuffle:
+ file_names = sorted(file_names)
+ return file_names
+
+
+@tf_export("data.experimental.SqlDataset")
+class SqlDataset(dataset_ops.DatasetSource):
+ """A `Dataset` consisting of the results from a SQL query."""
+
+ def __init__(self, driver_name, data_source_name, query, output_types):
+ """Creates a `SqlDataset`.
+
+ `SqlDataset` allows a user to read data from the result set of a SQL query.
+ For example:
+
+ ```python
+ dataset = tf.data.experimental.SqlDataset("sqlite", "/foo/bar.sqlite3",
+ "SELECT name, age FROM people",
+ (tf.string, tf.int32))
+ iterator = dataset.make_one_shot_iterator()
+ next_element = iterator.get_next()
+ # Prints the rows of the result set of the above query.
+ while True:
+ try:
+ print(sess.run(next_element))
+ except tf.errors.OutOfRangeError:
+ break
+ ```
+
+ Args:
+ driver_name: A 0-D `tf.string` tensor containing the database type.
+ Currently, the only supported value is 'sqlite'.
+ data_source_name: A 0-D `tf.string` tensor containing a connection string
+ to connect to the database.
+ query: A 0-D `tf.string` tensor containing the SQL query to execute.
+ output_types: A tuple of `tf.DType` objects representing the types of the
+ columns returned by `query`.
+ """
+ super(SqlDataset, self).__init__()
+ self._driver_name = ops.convert_to_tensor(
+ driver_name, dtype=dtypes.string, name="driver_name")
+ self._data_source_name = ops.convert_to_tensor(
+ data_source_name, dtype=dtypes.string, name="data_source_name")
+ self._query = ops.convert_to_tensor(
+ query, dtype=dtypes.string, name="query")
+ self._output_types = output_types
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.sql_dataset(self._driver_name,
+ self._data_source_name, self._query,
+ nest.flatten(self.output_types),
+ nest.flatten(self.output_shapes))
+
+ @property
+ def output_classes(self):
+ return nest.map_structure(lambda _: ops.Tensor, self._output_types)
+
+ @property
+ def output_shapes(self):
+ return nest.map_structure(lambda _: tensor_shape.TensorShape([]),
+ self._output_types)
+
+ @property
+ def output_types(self):
+ return self._output_types
diff --git a/tensorflow/python/data/experimental/ops/resampling.py b/tensorflow/python/data/experimental/ops/resampling.py
new file mode 100644
index 0000000000..3a3040ae9a
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/resampling.py
@@ -0,0 +1,296 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Resampling dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.data.experimental.ops import batching
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import logging_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.rejection_resample")
+def rejection_resample(class_func, target_dist, initial_dist=None, seed=None):
+ """A transformation that resamples a dataset to achieve a target distribution.
+
+ **NOTE** Resampling is performed via rejection sampling; some fraction
+ of the input values will be dropped.
+
+ Args:
+ class_func: A function mapping an element of the input dataset to a scalar
+ `tf.int32` tensor. Values should be in `[0, num_classes)`.
+ target_dist: A floating point type tensor, shaped `[num_classes]`.
+ initial_dist: (Optional.) A floating point type tensor, shaped
+ `[num_classes]`. If not provided, the true class distribution is
+ estimated live in a streaming fashion.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ """Function from `Dataset` to `Dataset` that applies the transformation."""
+ target_dist_t = ops.convert_to_tensor(target_dist, name="target_dist")
+ class_values_ds = dataset.map(class_func)
+
+ # Get initial distribution.
+ if initial_dist is not None:
+ initial_dist_t = ops.convert_to_tensor(initial_dist, name="initial_dist")
+ acceptance_dist, prob_of_original = (
+ _calculate_acceptance_probs_with_mixing(initial_dist_t,
+ target_dist_t))
+ initial_dist_ds = dataset_ops.Dataset.from_tensors(
+ initial_dist_t).repeat()
+ acceptance_dist_ds = dataset_ops.Dataset.from_tensors(
+ acceptance_dist).repeat()
+ prob_of_original_ds = dataset_ops.Dataset.from_tensors(
+ prob_of_original).repeat()
+ else:
+ initial_dist_ds = _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds)
+ acceptance_and_original_prob_ds = initial_dist_ds.map(
+ lambda initial: _calculate_acceptance_probs_with_mixing( # pylint: disable=g-long-lambda
+ initial, target_dist_t))
+ acceptance_dist_ds = acceptance_and_original_prob_ds.map(
+ lambda accept_prob, _: accept_prob)
+ prob_of_original_ds = acceptance_and_original_prob_ds.map(
+ lambda _, prob_original: prob_original)
+ filtered_ds = _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds,
+ class_values_ds, seed)
+ # Prefetch filtered dataset for speed.
+ filtered_ds = filtered_ds.prefetch(3)
+
+ prob_original_static = _get_prob_original_static(
+ initial_dist_t, target_dist_t) if initial_dist is not None else None
+ if prob_original_static == 1:
+ return dataset_ops.Dataset.zip((class_values_ds, dataset))
+ elif prob_original_static == 0:
+ return filtered_ds
+ else:
+ return interleave_ops.sample_from_datasets(
+ [dataset_ops.Dataset.zip((class_values_ds, dataset)), filtered_ds],
+ weights=prob_of_original_ds.map(lambda prob: [(prob, 1.0 - prob)]),
+ seed=seed)
+
+ return _apply_fn
+
+
+def _get_prob_original_static(initial_dist_t, target_dist_t):
+ """Returns the static probability of sampling from the original.
+
+ `tensor_util.constant_value(prob_of_original)` returns `None` if it encounters
+ an Op that it isn't defined for. We have some custom logic to avoid this.
+
+ Args:
+ initial_dist_t: A tensor of the initial distribution.
+ target_dist_t: A tensor of the target distribution.
+
+ Returns:
+ The probability of sampling from the original distribution as a constant,
+ if it is a constant, or `None`.
+ """
+ init_static = tensor_util.constant_value(initial_dist_t)
+ target_static = tensor_util.constant_value(target_dist_t)
+
+ if init_static is None or target_static is None:
+ return None
+ else:
+ return np.min(target_static / init_static)
+
+
+def _filter_ds(dataset, acceptance_dist_ds, initial_dist_ds, class_values_ds,
+ seed):
+ """Filters a dataset based on per-class acceptance probabilities.
+
+ Args:
+ dataset: The dataset to be filtered.
+ acceptance_dist_ds: A dataset of acceptance probabilities.
+ initial_dist_ds: A dataset of the initial probability distribution, given or
+ estimated.
+ class_values_ds: A dataset of the corresponding classes.
+ seed: (Optional.) Python integer seed for the resampler.
+
+ Returns:
+ A dataset of (class value, data) after filtering.
+ """
+ def maybe_warn_on_large_rejection(accept_dist, initial_dist):
+ proportion_rejected = math_ops.reduce_sum((1 - accept_dist) * initial_dist)
+ return control_flow_ops.cond(
+ math_ops.less(proportion_rejected, .5),
+ lambda: accept_dist,
+ lambda: logging_ops.Print( # pylint: disable=g-long-lambda
+ accept_dist, [proportion_rejected, initial_dist, accept_dist],
+ message="Proportion of examples rejected by sampler is high: ",
+ summarize=100,
+ first_n=10))
+
+ acceptance_dist_ds = (dataset_ops.Dataset.zip((acceptance_dist_ds,
+ initial_dist_ds))
+ .map(maybe_warn_on_large_rejection))
+
+ def _gather_and_copy(class_val, acceptance_prob, data):
+ return class_val, array_ops.gather(acceptance_prob, class_val), data
+
+ current_probabilities_and_class_and_data_ds = dataset_ops.Dataset.zip(
+ (class_values_ds, acceptance_dist_ds, dataset)).map(_gather_and_copy)
+ filtered_ds = (
+ current_probabilities_and_class_and_data_ds
+ .filter(lambda _1, p, _2: random_ops.random_uniform([], seed=seed) < p))
+ return filtered_ds.map(lambda class_value, _, data: (class_value, data))
+
+
+def _estimate_initial_dist_ds(
+ target_dist_t, class_values_ds, dist_estimation_batch_size=32,
+ smoothing_constant=10):
+ num_classes = (target_dist_t.shape[0].value or
+ array_ops.shape(target_dist_t)[0])
+ initial_examples_per_class_seen = array_ops.fill(
+ [num_classes], np.int64(smoothing_constant))
+
+ def update_estimate_and_tile(num_examples_per_class_seen, c):
+ updated_examples_per_class_seen, dist = _estimate_data_distribution(
+ c, num_examples_per_class_seen)
+ tiled_dist = array_ops.tile(
+ array_ops.expand_dims(dist, 0), [dist_estimation_batch_size, 1])
+ return updated_examples_per_class_seen, tiled_dist
+
+ initial_dist_ds = (class_values_ds.batch(dist_estimation_batch_size)
+ .apply(scan_ops.scan(initial_examples_per_class_seen,
+ update_estimate_and_tile))
+ .apply(batching.unbatch()))
+
+ return initial_dist_ds
+
+
+def _get_target_to_initial_ratio(initial_probs, target_probs):
+ # Add tiny to initial_probs to avoid divide by zero.
+ denom = (initial_probs + np.finfo(initial_probs.dtype.as_numpy_dtype).tiny)
+ return target_probs / denom
+
+
+def _estimate_data_distribution(c, num_examples_per_class_seen):
+ """Estimate data distribution as labels are seen.
+
+ Args:
+ c: The class labels. Type `int32`, shape `[batch_size]`.
+ num_examples_per_class_seen: Type `int64`, shape `[num_classes]`,
+ containing counts.
+
+ Returns:
+ num_examples_per_lass_seen: Updated counts. Type `int64`, shape
+ `[num_classes]`.
+ dist: The updated distribution. Type `float32`, shape `[num_classes]`.
+ """
+ num_classes = num_examples_per_class_seen.get_shape()[0].value
+ # Update the class-count based on what labels are seen in batch.
+ num_examples_per_class_seen = math_ops.add(
+ num_examples_per_class_seen, math_ops.reduce_sum(
+ array_ops.one_hot(c, num_classes, dtype=dtypes.int64), 0))
+ init_prob_estimate = math_ops.truediv(
+ num_examples_per_class_seen,
+ math_ops.reduce_sum(num_examples_per_class_seen))
+ dist = math_ops.cast(init_prob_estimate, dtypes.float32)
+ return num_examples_per_class_seen, dist
+
+
+def _calculate_acceptance_probs_with_mixing(initial_probs, target_probs):
+ """Calculates the acceptance probabilities and mixing ratio.
+
+ In this case, we assume that we can *either* sample from the original data
+ distribution with probability `m`, or sample from a reshaped distribution
+ that comes from rejection sampling on the original distribution. This
+ rejection sampling is done on a per-class basis, with `a_i` representing the
+ probability of accepting data from class `i`.
+
+ This method is based on solving the following analysis for the reshaped
+ distribution:
+
+ Let F be the probability of a rejection (on any example).
+ Let p_i be the proportion of examples in the data in class i (init_probs)
+ Let a_i is the rate the rejection sampler should *accept* class i
+ Let t_i is the target proportion in the minibatches for class i (target_probs)
+
+ ```
+ F = sum_i(p_i * (1-a_i))
+ = 1 - sum_i(p_i * a_i) using sum_i(p_i) = 1
+ ```
+
+ An example with class `i` will be accepted if `k` rejections occur, then an
+ example with class `i` is seen by the rejector, and it is accepted. This can
+ be written as follows:
+
+ ```
+ t_i = sum_k=0^inf(F^k * p_i * a_i)
+ = p_i * a_j / (1 - F) using geometric series identity, since 0 <= F < 1
+ = p_i * a_i / sum_j(p_j * a_j) using F from above
+ ```
+
+ Note that the following constraints hold:
+ ```
+ 0 <= p_i <= 1, sum_i(p_i) = 1
+ 0 <= a_i <= 1
+ 0 <= t_i <= 1, sum_i(t_i) = 1
+ ```
+
+ A solution for a_i in terms of the other variables is the following:
+ ```a_i = (t_i / p_i) / max_i[t_i / p_i]```
+
+ If we try to minimize the amount of data rejected, we get the following:
+
+ M_max = max_i [ t_i / p_i ]
+ M_min = min_i [ t_i / p_i ]
+
+ The desired probability of accepting data if it comes from class `i`:
+
+ a_i = (t_i/p_i - m) / (M_max - m)
+
+ The desired probability of pulling a data element from the original dataset,
+ rather than the filtered one:
+
+ m = M_min
+
+ Args:
+ initial_probs: A Tensor of the initial probability distribution, given or
+ estimated.
+ target_probs: A Tensor of the corresponding classes.
+
+ Returns:
+ (A 1D Tensor with the per-class acceptance probabilities, the desired
+ probability of pull from the original distribution.)
+ """
+ ratio_l = _get_target_to_initial_ratio(initial_probs, target_probs)
+ max_ratio = math_ops.reduce_max(ratio_l)
+ min_ratio = math_ops.reduce_min(ratio_l)
+
+ # Target prob to sample from original distribution.
+ m = min_ratio
+
+ # TODO(joelshor): Simplify fraction, if possible.
+ a_i = (ratio_l - m) / (max_ratio - m)
+ return a_i, m
diff --git a/tensorflow/python/data/experimental/ops/scan_ops.py b/tensorflow/python/data/experimental/ops/scan_ops.py
new file mode 100644
index 0000000000..e05e7c5a18
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/scan_ops.py
@@ -0,0 +1,177 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Scan dataset transformation."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ScanDataset(dataset_ops.UnaryDataset):
+ """A dataset that scans a function across its input."""
+
+ def __init__(self, input_dataset, initial_state, scan_func):
+ """See `scan()` for details."""
+ super(_ScanDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ with ops.name_scope("initial_state"):
+ # Convert any `SparseTensorValue`s to `SparseTensor`s and all other
+ # values to tensors.
+ self._initial_state = nest.pack_sequence_as(initial_state, [
+ sparse_tensor.SparseTensor.from_value(t)
+ if sparse_tensor.is_sparse(t) else ops.convert_to_tensor(
+ t, name="component_%d" % i)
+ for i, t in enumerate(nest.flatten(initial_state))
+ ])
+
+ # Compute initial values for the state classes, shapes and types based on
+ # the initial state. The shapes may be refined by running `tf_scan_func` one
+ # or more times below.
+ self._state_classes = sparse.get_classes(self._initial_state)
+ self._state_shapes = nest.pack_sequence_as(
+ self._initial_state,
+ [t.get_shape() for t in nest.flatten(self._initial_state)])
+ self._state_types = nest.pack_sequence_as(
+ self._initial_state,
+ [t.dtype for t in nest.flatten(self._initial_state)])
+
+ # Will be populated by calling `tf_scan_func`.
+ self._output_classes = None
+ self._output_shapes = None
+ self._output_types = None
+
+ # Iteratively rerun the scan function until reaching a fixed point on
+ # `self._state_shapes`.
+ need_to_rerun = True
+ while need_to_rerun:
+
+ wrapped_func = dataset_ops.StructuredFunctionWrapper(
+ scan_func,
+ "tf.data.experimental.scan()",
+ input_classes=(self._state_classes, input_dataset.output_classes),
+ input_shapes=(self._state_shapes, input_dataset.output_shapes),
+ input_types=(self._state_types, input_dataset.output_types),
+ add_to_graph=False)
+ if not (
+ isinstance(wrapped_func.output_types, collections.Sequence) and
+ len(wrapped_func.output_types) == 2):
+ raise TypeError("The scan function must return a pair comprising the "
+ "new state and the output value.")
+
+ new_state_classes, self._output_classes = wrapped_func.output_classes
+
+ # Extract and validate class information from the returned values.
+ for new_state_class, state_class in zip(
+ nest.flatten(new_state_classes),
+ nest.flatten(self._state_classes)):
+ if not issubclass(new_state_class, state_class):
+ raise TypeError(
+ "The element classes for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_classes, new_state_classes))
+
+ # Extract and validate type information from the returned values.
+ new_state_types, self._output_types = wrapped_func.output_types
+ for new_state_type, state_type in zip(
+ nest.flatten(new_state_types), nest.flatten(self._state_types)):
+ if new_state_type != state_type:
+ raise TypeError(
+ "The element types for the new state must match the initial "
+ "state. Expected %s; got %s." %
+ (self._state_types, new_state_types))
+
+ # Extract shape information from the returned values.
+ new_state_shapes, self._output_shapes = wrapped_func.output_shapes
+
+ flat_state_shapes = nest.flatten(self._state_shapes)
+ flat_new_state_shapes = nest.flatten(new_state_shapes)
+ weakened_state_shapes = [
+ original.most_specific_compatible_shape(new)
+ for original, new in zip(flat_state_shapes, flat_new_state_shapes)
+ ]
+
+ need_to_rerun = False
+ for original_shape, weakened_shape in zip(flat_state_shapes,
+ weakened_state_shapes):
+ if original_shape.ndims is not None and (
+ weakened_shape.ndims is None or
+ original_shape.as_list() != weakened_shape.as_list()):
+ need_to_rerun = True
+ break
+
+ if need_to_rerun:
+ self._state_shapes = nest.pack_sequence_as(self._state_shapes,
+ weakened_state_shapes)
+
+ self._scan_func = wrapped_func.function
+ self._scan_func.add_to_graph(ops.get_default_graph())
+
+ def _as_variant_tensor(self):
+ input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+ return gen_dataset_ops.scan_dataset(
+ input_t,
+ nest.flatten(sparse.serialize_sparse_tensors(self._initial_state)),
+ self._scan_func.captured_inputs,
+ f=self._scan_func,
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._output_classes
+
+ @property
+ def output_shapes(self):
+ return self._output_shapes
+
+ @property
+ def output_types(self):
+ return self._output_types
+
+
+@tf_export("data.experimental.scan")
+def scan(initial_state, scan_func):
+ """A transformation that scans a function across an input dataset.
+
+ This transformation is a stateful relative of `tf.data.Dataset.map`.
+ In addition to mapping `scan_func` across the elements of the input dataset,
+ `scan()` accumulates one or more state tensors, whose initial values are
+ `initial_state`.
+
+ Args:
+ initial_state: A nested structure of tensors, representing the initial state
+ of the accumulator.
+ scan_func: A function that maps `(old_state, input_element)` to
+ `(new_state, output_element). It must take two arguments and return a
+ pair of nested structures of tensors. The `new_state` must match the
+ structure of `initial_state`.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+ def _apply_fn(dataset):
+ return _ScanDataset(dataset, initial_state, scan_func)
+
+ return _apply_fn
diff --git a/tensorflow/python/data/experimental/ops/shuffle_ops.py b/tensorflow/python/data/experimental/ops/shuffle_ops.py
new file mode 100644
index 0000000000..a4307212da
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/shuffle_ops.py
@@ -0,0 +1,102 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental shuffle ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import random_seed
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+class _ShuffleAndRepeatDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that fuses `shuffle` and `repeat`."""
+
+ def __init__(self, input_dataset, buffer_size, count=None, seed=None):
+ super(_ShuffleAndRepeatDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._buffer_size = ops.convert_to_tensor(
+ buffer_size, dtype=dtypes.int64, name="buffer_size")
+ if count is None:
+ self._count = constant_op.constant(-1, dtype=dtypes.int64, name="count")
+ else:
+ self._count = ops.convert_to_tensor(
+ count, dtype=dtypes.int64, name="count")
+ self._seed, self._seed2 = random_seed.get_seed(seed)
+
+ def _as_variant_tensor(self):
+ # pylint: disable=protected-access
+ input_resource = self._input_dataset._as_variant_tensor()
+ return gen_dataset_ops.shuffle_and_repeat_dataset(
+ input_resource,
+ buffer_size=self._buffer_size,
+ count=self._count,
+ seed=self._seed,
+ seed2=self._seed2,
+ **dataset_ops.flat_structure(self))
+ # pylint: enable=protected-access
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+@tf_export("data.experimental.shuffle_and_repeat")
+def shuffle_and_repeat(buffer_size, count=None, seed=None):
+ """Shuffles and repeats a Dataset returning a new permutation for each epoch.
+
+ `dataset.apply(tf.data.experimental.shuffle_and_repeat(buffer_size, count))`
+
+ is equivalent to
+
+ `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
+
+ The difference is that the latter dataset is not serializable. So,
+ if you need to checkpoint an input pipeline with reshuffling you must use
+ this implementation.
+
+ Args:
+ buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ maximum number elements that will be buffered when prefetching.
+ count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ number of times the dataset should be repeated. The default behavior
+ (if `count` is `None` or `-1`) is for the dataset be repeated
+ indefinitely.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ `tf.set_random_seed` for behavior.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset): # pylint: disable=missing-docstring
+ return _ShuffleAndRepeatDataset(dataset, buffer_size, count, seed)
+
+ return _apply_fn
diff --git a/tensorflow/contrib/data/python/ops/stats_ops.py b/tensorflow/python/data/experimental/ops/stats_ops.py
index bc47c5989d..c918d223e8 100644
--- a/tensorflow/contrib/data/python/ops/stats_ops.py
+++ b/tensorflow/python/data/experimental/ops/stats_ops.py
@@ -21,8 +21,10 @@ from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+@tf_export("data.experimental.StatsAggregator")
class StatsAggregator(object):
"""A stateful resource that aggregates statistics from one or more iterators.
@@ -34,7 +36,7 @@ class StatsAggregator(object):
```python
dataset = ...
- dataset = dataset.apply(stats_ops.latency_stats("total_bytes"))
+ dataset = dataset.apply(tf.data.experimental.latency_stats("total_bytes"))
```
To associate a `StatsAggregator` with a `tf.data.Dataset` object, use
@@ -46,7 +48,7 @@ class StatsAggregator(object):
# Apply `set_stats_aggregator` to associate `dataset` with `stats_aggregator`.
dataset = dataset.apply(
- tf.contrib.data.set_stats_aggregator(stats_aggregator))
+ tf.data.experimental.set_stats_aggregator(stats_aggregator))
iterator = dataset.make_one_shot_iterator()
```
@@ -111,11 +113,12 @@ class _SetStatsAggregatorDataset(dataset_ops.UnaryDataset):
return self._input_dataset.output_classes
+@tf_export("data.experimental.set_stats_aggregator")
def set_stats_aggregator(stats_aggregator):
"""Set the given `stats_aggregator` for aggregating the input dataset stats.
Args:
- stats_aggregator: A `tf.contrib.data.StatsAggregator` object.
+ stats_aggregator: A `tf.data.experimental.StatsAggregator` object.
Returns:
A `Dataset` transformation function, which can be passed to
@@ -128,8 +131,8 @@ def set_stats_aggregator(stats_aggregator):
return _apply_fn
-# TODO(b/38416882): Properly export in the `tf.contrib.data` API when stable
-# or make private / remove.
+# TODO(b/38416882): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
def bytes_produced_stats(tag):
"""Records the number of bytes produced by each element of the input dataset.
@@ -152,6 +155,7 @@ def bytes_produced_stats(tag):
return _apply_fn
+@tf_export("data.experimental.latency_stats")
def latency_stats(tag):
"""Records the latency of producing each element of the input dataset.
diff --git a/tensorflow/python/data/experimental/ops/threadpool.py b/tensorflow/python/data/experimental/ops/threadpool.py
new file mode 100644
index 0000000000..3ea017c6e8
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/threadpool.py
@@ -0,0 +1,104 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental API for controlling threading in `tf.data` pipelines."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import threading
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
+from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
+from tensorflow.python.ops import resource_variable_ops
+
+_uid_counter = 0
+_uid_lock = threading.Lock()
+
+
+def _generate_shared_name(prefix):
+ with _uid_lock:
+ global _uid_counter
+ uid = _uid_counter
+ _uid_counter += 1
+ return "{}{}".format(prefix, uid)
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+class PrivateThreadPool(object):
+ """A stateful resource that represents a private thread pool."""
+
+ def __init__(self, num_threads, display_name=None,
+ max_intra_op_parallelism=1):
+ """Creates a `PrivateThreadPool` with the given number of threads."""
+ if context.executing_eagerly():
+ shared_name = _generate_shared_name("privatethreadpool")
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name,
+ shared_name=shared_name)
+ self._resource_deleter = resource_variable_ops.EagerResourceDeleter(
+ handle=self._resource, handle_device=context.context().device_name)
+ else:
+ self._resource = ged_ops.experimental_thread_pool_handle(
+ num_threads=num_threads,
+ max_intra_op_parallelism=max_intra_op_parallelism,
+ display_name=display_name)
+
+
+class _ThreadPoolDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` that acts as an identity, and sets a custom threadpool."""
+
+ def __init__(self, input_dataset, thread_pool):
+ super(_ThreadPoolDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._thread_pool = thread_pool
+
+ def _as_variant_tensor(self):
+ return ged_ops.experimental_thread_pool_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._thread_pool._resource, # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+
+# TODO(b/73383364): Properly export in the `tf.data.experimental` API when
+# stable or make private / remove.
+def override_threadpool(dataset, thread_pool):
+ """Returns a new dataset that uses the given thread pool for its operations.
+
+ Args:
+ dataset: A `tf.data.Dataset` object.
+ thread_pool: A `PrivateThreadPool` object.
+
+ Returns:
+ A dataset containing the same values as `dataset`, but which uses
+ `thread_pool` to compute any of its parallel operations (such as
+ `tf.data.Dataset.map`).
+ """
+ return _ThreadPoolDataset(dataset, thread_pool)
diff --git a/tensorflow/python/data/experimental/ops/unique.py b/tensorflow/python/data/experimental/ops/unique.py
new file mode 100644
index 0000000000..2a7775c456
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/unique.py
@@ -0,0 +1,79 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Unique element dataset transformations."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import gen_experimental_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.unique")
+def unique():
+ """Creates a `Dataset` from another `Dataset`, discarding duplicates.
+
+ Use this transformation to produce a dataset that contains one instance of
+ each unique element in the input. For example:
+
+ ```python
+ dataset = tf.data.Dataset.from_tensor_slices([1, 37, 2, 37, 2, 1])
+
+ # Using `unique()` will drop the duplicate elements.
+ dataset = dataset.apply(tf.data.experimental.unique()) # ==> { 1, 37, 2 }
+ ```
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ `tf.data.Dataset.apply`.
+ """
+
+ def _apply_fn(dataset):
+ return _UniqueDataset(dataset)
+
+ return _apply_fn
+
+
+class _UniqueDataset(dataset_ops.UnaryDataset):
+ """A `Dataset` contains the unique elements from its input."""
+
+ def __init__(self, input_dataset):
+ """See `unique()` for details."""
+ super(_UniqueDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if input_dataset.output_types not in (dtypes.int32, dtypes.int64,
+ dtypes.string):
+ raise TypeError(
+ "`tf.data.experimental.unique()` only supports inputs with a single "
+ "`tf.int32`, `tf.int64`, or `tf.string` component.")
+
+ def _as_variant_tensor(self):
+ return gen_experimental_dataset_ops.experimental_unique_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **dataset_ops.flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/experimental/ops/writers.py b/tensorflow/python/data/experimental/ops/writers.py
new file mode 100644
index 0000000000..994447cb4d
--- /dev/null
+++ b/tensorflow/python/data/experimental/ops/writers.py
@@ -0,0 +1,60 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python wrappers for tf.data writers."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import convert
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+from tensorflow.python.util.tf_export import tf_export
+
+
+@tf_export("data.experimental.TFRecordWriter")
+class TFRecordWriter(object):
+ """Writes data to a TFRecord file."""
+
+ def __init__(self, filename, compression_type=None):
+ self._filename = ops.convert_to_tensor(
+ filename, dtypes.string, name="filename")
+ self._compression_type = convert.optional_param_to_tensor(
+ "compression_type",
+ compression_type,
+ argument_default="",
+ argument_dtype=dtypes.string)
+
+ def write(self, dataset):
+ """Returns a `tf.Operation` to write a dataset to a file.
+
+ Args:
+ dataset: a `tf.data.Dataset` whose elements are to be written to a file
+
+ Returns:
+ A `tf.Operation` that, when run, writes contents of `dataset` to a file.
+ """
+ if not isinstance(dataset, dataset_ops.Dataset):
+ raise TypeError("`dataset` must be a `tf.data.Dataset` object.")
+ if (dataset.output_types != dtypes.string or
+ dataset.output_shapes != tensor_shape.scalar()):
+ raise TypeError(
+ "`dataset` must produce scalar `DT_STRING` tensors whereas it "
+ "produces shape {0} and types {1}".format(dataset.output_shapes,
+ dataset.output_types))
+ return gen_dataset_ops.dataset_to_tf_record(
+ dataset._as_variant_tensor(), self._filename, self._compression_type) # pylint: disable=protected-access
diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD
index 5f9818566f..bf76860aa4 100644
--- a/tensorflow/python/data/kernel_tests/BUILD
+++ b/tensorflow/python/data/kernel_tests/BUILD
@@ -115,8 +115,10 @@ tf_py_test(
srcs = ["dataset_ops_test.py"],
additional_deps = [
":test_base",
- "//tensorflow/core:protos_all_py",
+ "@absl_py//absl/testing:parameterized",
+ "//third_party/py/numpy",
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:sparse_tensor",
"//tensorflow/python/data/ops:dataset_ops",
],
)
@@ -173,20 +175,6 @@ tf_py_test(
)
tf_py_test(
- name = "inputs_test",
- size = "small",
- srcs = ["inputs_test.py"],
- additional_deps = [
- ":test_base",
- "@absl_py//absl/testing:parameterized",
- "//third_party/py/numpy",
- "//tensorflow/python:client_testlib",
- "//tensorflow/python:sparse_tensor",
- "//tensorflow/python/data/ops:dataset_ops",
- ],
-)
-
-tf_py_test(
name = "interleave_dataset_op_test",
size = "small",
srcs = ["interleave_dataset_op_test.py"],
@@ -471,6 +459,9 @@ py_library(
srcs = ["test_base.py"],
deps = [
"//tensorflow/python:client_testlib",
+ "//tensorflow/python:errors",
+ "//tensorflow/python:sparse_tensor",
+ "//tensorflow/python/data/util:nest",
],
)
diff --git a/tensorflow/python/data/kernel_tests/dataset_ops_test.py b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
index f115f9d9c7..b9f8875b9f 100644
--- a/tensorflow/python/data/kernel_tests/dataset_ops_test.py
+++ b/tensorflow/python/data/kernel_tests/dataset_ops_test.py
@@ -18,13 +18,20 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from absl.testing import parameterized
+import numpy as np
+
from tensorflow.core.framework import graph_pb2
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.ops import readers
+from tensorflow.python.data.util import nest
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
-class DatasetOpsTest(test_base.DatasetTestBase):
+class DatasetOpsTest(test_base.DatasetTestBase, parameterized.TestCase):
def testAsSerializedGraph(self):
dataset = dataset_ops.Dataset.range(10)
@@ -33,6 +40,155 @@ class DatasetOpsTest(test_base.DatasetTestBase):
sess.run(dataset._as_serialized_graph()))
self.assertTrue(any([node.op != "RangeDataset" for node in graph.node]))
+ @staticmethod
+ def make_apply_fn(dataset):
+
+ def apply_fn(dataset):
+
+ def _apply_fn(dataset):
+ return dataset.cache()
+
+ return dataset.apply(_apply_fn)
+
+ return apply_fn
+
+ @staticmethod
+ def make_gen():
+
+ def gen():
+ yield 42
+
+ return gen
+
+ @staticmethod
+ def make_interleave_fn(dataset, num_parallel_calls=None):
+
+ def interleave_fn(dataset):
+ return dataset.interleave(
+ lambda x: dataset_ops.Dataset.range(0),
+ cycle_length=2,
+ num_parallel_calls=num_parallel_calls)
+
+ return interleave_fn
+
+ @parameterized.named_parameters(
+ ("FixedLengthRecord", readers.FixedLengthRecordDataset("", 42)),
+ ("FromGenerator",
+ dataset_ops.Dataset.from_generator(make_gen.__func__(), dtypes.int32),
+ 1),
+ ("FromSparseTensorSlices",
+ dataset_ops.Dataset.from_sparse_tensor_slices(
+ sparse_tensor.SparseTensor(
+ indices=np.array([[0, 0], [1, 0], [2, 0]]),
+ values=np.array([0, 0, 0]),
+ dense_shape=np.array([3, 1])))),
+ ("FromTensors", dataset_ops.Dataset.from_tensors([42])),
+ ("FromTensorSlices", dataset_ops.Dataset.from_tensors([42])),
+ ("Range", dataset_ops.Dataset.range(10)),
+ ("TextLine", readers.TextLineDataset("")),
+ ("TFRecord", readers.TFRecordDataset(""), 1),
+ )
+ def testDatasetSourceInputs(self, dataset, num_inputs=0):
+ self.assertEqual(num_inputs, len(dataset._inputs()))
+
+ @parameterized.named_parameters(
+ ("Apply", make_apply_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Batch", lambda x: x.batch(10), dataset_ops.Dataset.range(0)),
+ ("Cache", lambda x: x.cache(), dataset_ops.Dataset.range(0)),
+ ("Filter", lambda x: x.filter(lambda x: True),
+ dataset_ops.Dataset.range(0)),
+ ("FlatMap", lambda x: x.flat_map(lambda x: dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Interleave", make_interleave_fn.__func__(dataset_ops.Dataset.range(0)),
+ dataset_ops.Dataset.range(0)),
+ ("Map", lambda x: x.map(lambda x: x), dataset_ops.Dataset.range(0)),
+ ("PaddedBatch", lambda x: x.padded_batch(10, []),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelInterleave",
+ make_interleave_fn.__func__(dataset_ops.Dataset.range(0), 2),
+ dataset_ops.Dataset.range(0)),
+ ("ParallelMap", lambda x: x.map(lambda x: x, num_parallel_calls=2),
+ dataset_ops.Dataset.range(0)),
+ ("Repeat", lambda x: x.repeat(), dataset_ops.Dataset.range(0)),
+ ("Shuffle", lambda x: x.shuffle(10), dataset_ops.Dataset.range(0)),
+ ("Skip", lambda x: x.skip(1), dataset_ops.Dataset.range(0)),
+ ("Take", lambda x: x.take(1), dataset_ops.Dataset.range(0)),
+ ("Window", lambda x: x.window(10), dataset_ops.Dataset.range(0)),
+ )
+ def testUnaryTransformationInputs(self, dataset_fn, input_dataset):
+ self.assertEqual([input_dataset], dataset_fn(input_dataset)._inputs())
+
+ @parameterized.named_parameters(
+ ("Concatenate", lambda x, y: x.concatenate(y),
+ dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1)))
+ def testBinaryTransformationInputs(self, dataset_fn, input1, input2):
+ self.assertEqual([input1, input2], dataset_fn(input1, input2)._inputs())
+
+ @parameterized.named_parameters(
+ ("ZipOne", dataset_ops.Dataset.zip, (dataset_ops.Dataset.range(0))),
+ ("ZipNest", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0),
+ (dataset_ops.Dataset.range(1), dataset_ops.Dataset.range(2)))),
+ ("ZipTuple", dataset_ops.Dataset.zip,
+ (dataset_ops.Dataset.range(0), dataset_ops.Dataset.range(1))))
+ def testVariadicTransformationInputs(self, dataset_fn, input_datasets):
+ self.assertEqual(
+ nest.flatten(input_datasets),
+ dataset_fn(input_datasets)._inputs())
+
+ def testCollectInputs(self):
+ ds1 = dataset_ops.Dataset.range(0)
+ ds2 = ds1.concatenate(ds1)
+ ds3 = dataset_ops.Dataset.zip((ds2, ds1, ds2))
+
+ inputs = []
+ queue = [ds3]
+ while queue:
+ ds = queue[0]
+ queue = queue[1:]
+ queue.extend(ds._inputs())
+ inputs.append(ds)
+
+ self.assertEqual(5, inputs.count(ds1))
+ self.assertEqual(2, inputs.count(ds2))
+ self.assertEqual(1, inputs.count(ds3))
+
+ def testOptionsDefault(self):
+ ds = dataset_ops.Dataset.range(0)
+ self.assertEqual(dataset_ops.Options(), ds.options())
+
+ def testOptionsOnce(self):
+ options = dataset_ops.Options()
+ ds = dataset_ops.Dataset.range(0).with_options(options).cache()
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceSame(self):
+ options = dataset_ops.Options()
+ options.experimental_autotune = True
+ ds = dataset_ops.Dataset.range(0).with_options(options).with_options(
+ options)
+ self.assertEqual(options, ds.options())
+
+ def testOptionsTwiceDifferent(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_filter_fusion = False
+ ds = dataset_ops.Dataset.range(0).with_options(options1).with_options(
+ options2)
+ self.assertTrue(ds.options().experimental_autotune)
+ self.assertFalse(ds.options().experimental_filter_fusion)
+
+ def testOptionsTwiceDifferentError(self):
+ options1 = dataset_ops.Options()
+ options1.experimental_autotune = True
+ options2 = dataset_ops.Options()
+ options2.experimental_autotune = False
+ with self.assertRaisesRegexp(ValueError,
+ "Cannot merge incompatible values of option"):
+ dataset_ops.Dataset.range(0).with_options(options1).with_options(options2)
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/data/kernel_tests/test_base.py b/tensorflow/python/data/kernel_tests/test_base.py
index b4f64115b7..b730e10949 100644
--- a/tensorflow/python/data/kernel_tests/test_base.py
+++ b/tensorflow/python/data/kernel_tests/test_base.py
@@ -17,6 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import re
+
+from tensorflow.python.data.util import nest
+from tensorflow.python.eager import context
+from tensorflow.python.framework import errors
+from tensorflow.python.framework import sparse_tensor
from tensorflow.python.platform import test
@@ -24,6 +30,80 @@ class DatasetTestBase(test.TestCase):
"""Base class for dataset tests."""
def assertSparseValuesEqual(self, a, b):
+ """Asserts that two SparseTensors/SparseTensorValues are equal."""
self.assertAllEqual(a.indices, b.indices)
self.assertAllEqual(a.values, b.values)
self.assertAllEqual(a.dense_shape, b.dense_shape)
+
+ def getNext(self, dataset):
+ """Returns a callable that returns the next element of the dataset.
+
+ Example use:
+ ```python
+ # In both graph and eager modes
+ dataset = ...
+ nxt = self.getNext(dataset)
+ result = self.evaluate(nxt())
+ ```
+
+ Args:
+ dataset: A dataset whose next element is returned
+
+ Returns:
+ A callable that returns the next element of `dataset`
+ """
+ it = dataset.make_one_shot_iterator()
+ if context.executing_eagerly():
+ return it.get_next
+ else:
+ nxt = it.get_next()
+ return lambda: nxt
+
+ def assertDatasetsEqual(self, dataset1, dataset2):
+ """Checks that datasets are equal. Supports both graph and eager mode."""
+ self.assertEqual(dataset1.output_types, dataset2.output_types)
+ self.assertEqual(dataset1.output_classes, dataset2.output_classes)
+
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ while True:
+ try:
+ op1 = self.evaluate(next1())
+ except errors.OutOfRangeError:
+ with self.assertRaises(errors.OutOfRangeError):
+ self.evaluate(next2())
+ break
+ op2 = self.evaluate(next2())
+
+ op1 = nest.flatten(op1)
+ op2 = nest.flatten(op2)
+ assert len(op1) == len(op2)
+ for i in range(len(op1)):
+ if isinstance(
+ op1[i],
+ (sparse_tensor.SparseTensor, sparse_tensor.SparseTensorValue)):
+ self.assertSparseValuesEqual(op1[i], op2[i])
+ else:
+ self.assertAllEqual(op1[i], op2[i])
+
+ def assertDatasetsRaiseSameError(self,
+ dataset1,
+ dataset2,
+ exception_class,
+ replacements=None):
+ """Checks that datasets raise the same error on the first get_next call."""
+ next1 = self.getNext(dataset1)
+ next2 = self.getNext(dataset2)
+ try:
+ self.evaluate(next1())
+ raise ValueError(
+ 'Expected dataset to raise an error of type %s, but it did not.' %
+ repr(exception_class))
+ except exception_class as e:
+ expected_message = e.message
+ for old, new, count in replacements:
+ expected_message = expected_message.replace(old, new, count)
+ # Check that the first segment of the error messages are the same.
+ with self.assertRaisesRegexp(exception_class,
+ re.escape(expected_message)):
+ self.evaluate(next2())
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index 6bba72a8e9..46ce191f7b 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -86,6 +86,18 @@ class Dataset(object):
raise NotImplementedError("Dataset._inputs")
+ def options(self):
+ """Returns the options for this dataset.
+
+ Returns:
+ A `tf.data.Options` object representing the dataset options.
+ """
+ for input_dataset in self._inputs():
+ options = input_dataset.options()
+ if options is not None:
+ return options
+ return Options()
+
def make_initializable_iterator(self, shared_name=None):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -114,6 +126,13 @@ class Dataset(object):
raise RuntimeError(
"dataset.make_initializable_iterator is not supported when eager "
"execution is enabled.")
+ dataset = self
+ options = self.options()
+ static_optimizations = options._static_optimizations() # pylint: disable=protected-access
+ if static_optimizations:
+ dataset = _OptimizeDataset(dataset, static_optimizations)
+ if options.experimental_autotune:
+ dataset = _ModelDataset(dataset)
if shared_name is None:
shared_name = ""
if compat.forward_compatible(2018, 8, 3):
@@ -123,11 +142,12 @@ class Dataset(object):
iterator_resource = gen_dataset_ops.iterator(
container="", shared_name=shared_name, **flat_structure(self))
with ops.colocate_with(iterator_resource):
- initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
- iterator_resource)
+ initializer = gen_dataset_ops.make_iterator(
+ dataset._as_variant_tensor(), # pylint: disable=protected-access
+ iterator_resource)
return iterator_ops.Iterator(iterator_resource, initializer,
- self.output_types, self.output_shapes,
- self.output_classes)
+ dataset.output_types, dataset.output_shapes,
+ dataset.output_classes)
def __iter__(self):
"""Creates an `Iterator` for enumerating the elements of this dataset.
@@ -162,7 +182,14 @@ class Dataset(object):
# a 0-argument function.
@function.Defun(capture_by_value=True)
def _make_dataset():
- return self._as_variant_tensor() # pylint: disable=protected-access
+ dataset = self
+ options = self.options()
+ static_optimizations = options._static_optimizations() # pylint: disable=protected-access
+ if static_optimizations:
+ dataset = _OptimizeDataset(dataset, static_optimizations)
+ if options.experimental_autotune:
+ dataset = _ModelDataset(dataset)
+ return dataset._as_variant_tensor() # pylint: disable=protected-access
try:
_make_dataset.add_to_graph(ops.get_default_graph())
@@ -889,8 +916,8 @@ class Dataset(object):
will be padded out to the maximum length of all elements in that
dimension.
- See also `tf.contrib.data.dense_to_sparse_batch`, which combines elements
- that may have different shapes into a `tf.SparseTensor`.
+ See also `tf.data.experimental.dense_to_sparse_batch`, which combines
+ elements that may have different shapes into a `tf.SparseTensor`.
Args:
batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
@@ -1325,6 +1352,146 @@ class Dataset(object):
output_shapes,
output_classes)
+ def with_options(self, options):
+ """Returns a new `tf.data.Dataset` with the given options set.
+
+ The options are "global" in the sense they apply to the entire input
+ pipeline in which the `with_options` transformation is used. If options are
+ set multiple times, they are merged if possible (see
+ `tf.data.Options.merge()` for details).
+
+ Args:
+ options: A `tf.data.Options` that identifies the options the use.
+
+ Returns:
+ Dataset: A `Dataset` with the given options.
+
+ Raises:
+ ValueError: if options are set more than once
+ """
+ return _OptionsDataset(self, options)
+
+
+@tf_export("data.Options")
+class Options(object):
+ """Represents options for tf.data.Dataset.
+
+ An `Options` object can be for instance used to control which static
+ optimizations to apply or whether to use performance modeling to dynamically
+ tune the parallelism of operations such as `tf.data.Dataset.map` or
+ `tf.data.Dataset.interleave`.
+ """
+ for _name, _ty, _docstring in [
+ ("experimental_autotune", bool,
+ "Whether to dynamically adjust the values of tunable parameters (e.g. "
+ "degrees of parallelism)."),
+ ("experimental_filter_fusion", bool,
+ "Whether to fuse filter transformations."),
+ ("experimental_hoist_random_uniform", bool,
+ "Whether to hoist `tf.random_uniform()` ops out of map transformations."
+ ),
+ ("experimental_latency_all_edges", bool,
+ "Whether to add latency measurements on all edges."),
+ ("experimental_map_and_batch_fusion", bool,
+ "Whether to fuse map and batch transformations."),
+ ("experimental_map_and_filter_fusion", bool,
+ "Whether to fuse map and filter transformations."),
+ ("experimental_map_fusion", bool, "Whether to fuse map transformations."),
+ ("experimental_map_parallelization", bool,
+ "Whether to parallelize stateless map transformations."),
+ ("experimental_map_vectorization", bool,
+ "Whether to vectorize map transformations."),
+ ("experimental_noop_elimination", bool,
+ "Whether to eliminate no-op transformations."),
+ ("experimental_shuffle_and_repeat_fusion", bool,
+ "Whether to fuse shuffle and repeat transformations."),
+ ]:
+
+ def _make_getter(name): # pylint: disable=no-self-argument
+
+ def getter(self):
+ return getattr(self, "_" + name)
+
+ return getter
+
+ def _make_setter(name, ty): # pylint: disable=no-self-argument
+
+ def setter(self, value):
+ if not isinstance(value, ty):
+ raise TypeError(
+ "Attempting to set the option %s to incompatible value: %r" %
+ (name, value))
+ setattr(self, "_" + name, value)
+
+ return setter
+
+ vars()["_" + _name] = None
+ vars()[_name] = property(
+ _make_getter(_name), _make_setter(_name, _ty), None, _docstring)
+
+ def __init__(self):
+ pass
+
+ def __eq__(self, other):
+ if isinstance(other, self.__class__):
+ return self.__dict__ == other.__dict__
+ else:
+ return False
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+ def _static_optimizations(self):
+ """Produces the list of enabled static optimizations."""
+ experimental_optimizations = [
+ "filter_fusion", "hoist_random_uniform", "latency_all_edges",
+ "map_and_batch_fusion", "map_and_filter_fusion", "map_fusion",
+ "map_parallelization", "map_vectorization", "noop_elimination",
+ "shuffle_and_repeat_fusion"
+ ]
+ result = []
+ for exp_opt in experimental_optimizations:
+ if getattr(self, "experimental_" + exp_opt):
+ result.append(exp_opt)
+ return result
+
+ def merge(self, options):
+ """Merges itself with the given `tf.data.Options`.
+
+ The given `tf.data.Options` can be merged as long as there does not exist an
+ attribute that is set to different values in `self` and `options`.
+
+ Args:
+ options: a `tf.data.Options` to merge with
+
+ Raises:
+ ValueError: if the given `tf.data.Options` cannot be merged
+
+ Returns:
+ New `tf.data.Options()` object which is the result of merging self with
+ the input `tf.data.Options`.
+ """
+ result = Options()
+ for other in [self, options]:
+ for name in [
+ "experimental_autotune", "experimental_filter_fusion",
+ "experimental_hoist_random_uniform", "experimental_latency_all_edges",
+ "experimental_map_and_batch_fusion",
+ "experimental_map_and_filter_fusion", "experimental_map_fusion",
+ "experimental_map_parallelization", "experimental_map_vectorization",
+ "experimental_noop_elimination",
+ "experimental_shuffle_and_repeat_fusion"
+ ]:
+ this = getattr(result, name)
+ that = getattr(other, name)
+ if that is not None:
+ if this is None:
+ setattr(result, name, that)
+ elif this != that:
+ raise ValueError(
+ "Cannot merge incompatible values of option: %s" % (name))
+ return result
+
class DatasetSource(Dataset):
"""Abstract class representing a dataset with no inputs."""
@@ -1664,6 +1831,9 @@ class StructuredFunctionWrapper(object):
flat_classes.append(component)
flat_shapes.append(component)
flat_types.append(component)
+ if t.options() is not None: # pylint: disable=protected-access
+ warnings.warn("Encountered a nested dataset with options. These "
+ "options will not be applied to the outer dataset.")
else:
try:
t = ops.convert_to_tensor(t)
@@ -2703,3 +2873,91 @@ class WindowDataset(UnaryDataset):
@property
def output_types(self):
return self._output_types
+
+
+class _OptionsDataset(UnaryDataset):
+ """An identity `Dataset` that stores options."""
+
+ def __init__(self, input_dataset, options):
+ super(_OptionsDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ self._options = input_dataset.options()
+ if self._options:
+ self._options = self._options.merge(options)
+ else:
+ self._options = options
+
+ def _as_variant_tensor(self):
+ return self._input_dataset._as_variant_tensor() # pylint: disable=protected-access
+
+ def options(self):
+ return self._options
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _ModelDataset(UnaryDataset):
+ """A `Dataset` that acts as an identity, and models performance."""
+
+ def __init__(self, input_dataset):
+ """See `optimize()` for details."""
+ super(_ModelDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.model_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
+
+
+class _OptimizeDataset(UnaryDataset):
+ """A `Dataset` that acts as an identity, and applies optimizations."""
+
+ def __init__(self, input_dataset, optimizations):
+ """See `optimize()` for details."""
+ super(_OptimizeDataset, self).__init__(input_dataset)
+ self._input_dataset = input_dataset
+ if optimizations is None:
+ optimizations = []
+ self._optimizations = ops.convert_to_tensor(
+ optimizations, dtype=dtypes.string, name="optimizations")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.optimize_dataset(
+ self._input_dataset._as_variant_tensor(), # pylint: disable=protected-access
+ self._optimizations,
+ **flat_structure(self))
+
+ @property
+ def output_classes(self):
+ return self._input_dataset.output_classes
+
+ @property
+ def output_shapes(self):
+ return self._input_dataset.output_shapes
+
+ @property
+ def output_types(self):
+ return self._input_dataset.output_types
diff --git a/tensorflow/python/data/ops/optional_ops.py b/tensorflow/python/data/ops/optional_ops.py
index 3bbebd7878..aca989e03a 100644
--- a/tensorflow/python/data/ops/optional_ops.py
+++ b/tensorflow/python/data/ops/optional_ops.py
@@ -31,7 +31,7 @@ class Optional(object):
An `Optional` can represent the result of an operation that may fail as a
value, rather than raising an exception and halting execution. For example,
- `tf.contrib.data.get_next_as_optional` returns an `Optional` that either
+ `tf.data.experimental.get_next_as_optional` returns an `Optional` that either
contains the next value from a `tf.data.Iterator` if one exists, or a "none"
value that indicates the end of the sequence has been reached.
"""
@@ -111,7 +111,7 @@ class Optional(object):
class _OptionalImpl(Optional):
- """Concrete implementation of `tf.contrib.data.Optional`.
+ """Concrete implementation of `tf.data.experimental.Optional`.
NOTE(mrry): This implementation is kept private, to avoid defining
`Optional.__init__()` in the public API.
diff --git a/tensorflow/python/data/ops/readers.py b/tensorflow/python/data/ops/readers.py
index b0f26631f9..d08da6704c 100644
--- a/tensorflow/python/data/ops/readers.py
+++ b/tensorflow/python/data/ops/readers.py
@@ -129,7 +129,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
def __init__(self, input_dataset, map_func, cycle_length, block_length,
sloppy, buffer_output_elements, prefetch_input_elements):
- """See `tf.contrib.data.parallel_interleave()` for details."""
+ """See `tf.data.experimental.parallel_interleave()` for details."""
super(ParallelInterleaveDataset, self).__init__(input_dataset, map_func,
cycle_length, block_length)
self._sloppy = ops.convert_to_tensor(
@@ -158,7 +158,7 @@ class ParallelInterleaveDataset(dataset_ops.InterleaveDataset):
# pylint: enable=protected-access
def _transformation_name(self):
- return "tf.contrib.data.parallel_interleave()"
+ return "tf.data.experimental.parallel_interleave()"
@tf_export("data.TFRecordDataset")
diff --git a/tensorflow/python/debug/examples/debug_tflearn_iris.py b/tensorflow/python/debug/examples/debug_tflearn_iris.py
index 019f13c450..f9bb3148fb 100644
--- a/tensorflow/python/debug/examples/debug_tflearn_iris.py
+++ b/tensorflow/python/debug/examples/debug_tflearn_iris.py
@@ -94,13 +94,15 @@ def main(_):
"sepal_length", "sepal_width", "petal_length", "petal_width", "label"]
batch_size = 32
def training_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [training_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([training_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
def test_input_fn():
- return tf.contrib.data.make_csv_dataset(
- [test_data_path], batch_size,
- column_names=column_names, label_name="label")
+ return tf.data.experimental.make_csv_dataset([test_data_path],
+ batch_size,
+ column_names=column_names,
+ label_name="label")
feature_columns = [tf.feature_column.numeric_column(feature)
for feature in column_names[:-1]]
diff --git a/tensorflow/python/debug/examples/examples_test.sh b/tensorflow/python/debug/examples/examples_test.sh
index f7d597c8c0..89dc918616 100755
--- a/tensorflow/python/debug/examples/examples_test.sh
+++ b/tensorflow/python/debug/examples/examples_test.sh
@@ -115,7 +115,7 @@ OUTPUT=$(${OFFLINE_ANALYZER_BIN} 2>&1)
set -e
EXPECTED_OUTPUT="ERROR: dump_dir flag is empty."
-if [[ "${OUTPUT}" != "${EXPECTED_OUTPUT}" ]]; then
+if ! echo "${OUTPUT}" | grep -q "${EXPECTED_OUTPUT}"; then
echo "ERROR: offline_analyzer output didn't match expectation: ${OUTPUT}" 1>&2
echo "Expected output: ${EXPECTED_OUTPUT}"
exit 1
diff --git a/tensorflow/python/distribute/distribute_coordinator.py b/tensorflow/python/distribute/distribute_coordinator.py
index bd3562f1ff..b9b77d4a5b 100644
--- a/tensorflow/python/distribute/distribute_coordinator.py
+++ b/tensorflow/python/distribute/distribute_coordinator.py
@@ -126,7 +126,7 @@ class _WorkerContext(object):
replicated training.
task_id: an integer indicating id of the corresponding task. It can be
None if it is local training or in-graph replicated training.
- session_config: an optional @{tf.ConfigProto} object.
+ session_config: an optional `tf.ConfigProto` object.
rpc_layer: optional string specifying the RPC protocol for communication
with worker masters. If None or empty, hosts in the `cluster_spec` will
be used directly.
@@ -685,7 +685,7 @@ def run_distribute_coordinator(worker_fn,
in a cluster. If not set or empty, fall back to local training.
task_type: the current task type, optional if this is a client.
task_id: the current task id, optional if this is a client.
- session_config: an optional @{tf.ConfigProto} object which will be passed
+ session_config: an optional `tf.ConfigProto` object which will be passed
to `strategy`'s `configure` method and used to create a session.
rpc_layer: optional string, the protocol for RPC, e.g. "grpc".
diff --git a/tensorflow/python/distribute/estimator_training.py b/tensorflow/python/distribute/estimator_training.py
index 8daa34c885..0289689134 100644
--- a/tensorflow/python/distribute/estimator_training.py
+++ b/tensorflow/python/distribute/estimator_training.py
@@ -62,7 +62,7 @@ def _get_global_id(cluster_spec, task_type, task_id, chief_task_type):
# Sort task names in cluster by "chief"/"master", "evaluator", "worker"
# and "ps". More details can be found at the documentation of
- # @{tf.estimator.RunConfig.global_id_in_cluster}.
+ # `tf.estimator.RunConfig.global_id_in_cluster`.
task_type_ordered_list = []
if chief_task_type in cluster_spec.jobs:
task_type_ordered_list = [chief_task_type]
diff --git a/tensorflow/python/eager/BUILD b/tensorflow/python/eager/BUILD
index d3d997e6df..d0c1a93118 100644
--- a/tensorflow/python/eager/BUILD
+++ b/tensorflow/python/eager/BUILD
@@ -37,6 +37,7 @@ cc_library(
"//tensorflow/python:safe_ptr",
"//third_party/py/numpy:headers",
"//third_party/python_runtime:headers",
+ "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:variant",
],
)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 78f3198011..deac29111f 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -619,7 +619,7 @@ pywrap_tensorflow.TFE_Py_RegisterVSpace(_default_vspace)
def _handle_or_self(x):
"""If x is ResourceVariable, return its handle, else x."""
- if isinstance(x, resource_variable_ops.ResourceVariable):
+ if resource_variable_ops.is_resource_variable(x):
x = x.handle
return x
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index dd3e1a3723..f261d92d64 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import collections
import functools
+import re
import sys
import threading
import weakref
@@ -61,9 +62,15 @@ cond_v2_impl._function = sys.modules[__name__] # pylint: disable=protected-acce
# This is to avoid a circular dependency with gradients_impl
gradients_impl._function = sys.modules[__name__] # pylint: disable=protected-access
+FORWARD_FUNCTION_ATTRIBUTE_NAME = "forward_function_name"
+BACKWARD_FUNCTION_ATTRIBUTE_NAME = "backward_function_name"
# TODO(scottzhu): Update this to allow arbitrary attribute names in future.
-WHITELIST_FUNCTION_ATTRIBUTE_PREFIX = "experimental_"
+WHITELIST_FUNCTION_ATTRIBUTE_REGEX = [
+ "experimental_.*",
+ FORWARD_FUNCTION_ATTRIBUTE_NAME,
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME
+]
def _create_substitute_placeholder(value, name=None, dtype=None):
@@ -140,10 +147,11 @@ def _parse_func_attrs(attributes):
"""
attrs = {}
for key, value in attributes.items():
- if not key.startswith(WHITELIST_FUNCTION_ATTRIBUTE_PREFIX):
+ if not any([re.match(reg, key)
+ for reg in WHITELIST_FUNCTION_ATTRIBUTE_REGEX]):
raise ValueError("Attribute name is not whitelisted. "
"Whitelisted: prefix %s, got: %s" %
- (WHITELIST_FUNCTION_ATTRIBUTE_PREFIX, key))
+ (WHITELIST_FUNCTION_ATTRIBUTE_REGEX, key))
if isinstance(value, attr_value_pb2.AttrValue):
attrs[key] = value
@@ -154,7 +162,7 @@ def _parse_func_attrs(attributes):
attrs[key] = attr_value_pb2.AttrValue(i=value)
elif isinstance(value, float):
attrs[key] = attr_value_pb2.AttrValue(f=value)
- elif isinstance(value, str):
+ elif isinstance(value, (str, bytes)):
attrs[key] = attr_value_pb2.AttrValue(s=compat.as_bytes(value))
else:
raise ValueError("Unsupported attribute type for %s with type %s" %
@@ -261,6 +269,15 @@ class FuncGraph(ops.Graph):
def variables(self, var_list):
self._weak_variables = [weakref.ref(v) for v in var_list]
+ def control_dependencies(self, control_inputs):
+ # Drop control dependencies to outside of the graph. TODO(b/117109273)
+ # unclear how to capture an op, not a tensor.
+ if not control_inputs:
+ return super(FuncGraph, self).control_dependencies(control_inputs)
+ return super(FuncGraph, self).control_dependencies(
+ [c for c in control_inputs
+ if getattr(c, "graph", None) is self])
+
def create_op(
self,
op_type,
@@ -705,6 +722,7 @@ class Function(object):
def _construct_backprop_function(self):
"""Constructs the backprop function object for this function."""
backwards_graph = FuncGraph(_backward_name(self._func_graph.name))
+ forward_function_name = _forward_name(self._func_graph.name)
with backwards_graph.as_default():
gradients_wrt_outputs = [
graph_placeholder(x.dtype, x.shape) for x in self._func_graph.outputs
@@ -715,11 +733,11 @@ class Function(object):
grad_ys=gradients_wrt_outputs,
src_graph=self._func_graph)
- self._forward_function = _EagerDefinedFunction(
- _forward_name(
- self._func_graph.name), self._func_graph, self._func_graph.inputs,
- self._func_graph.outputs + list(backwards_graph.captures.keys()),
- self._attrs)
+ backwards_graph_captures = list(backwards_graph.captures.keys())
+
+ backward_function_attr = _parse_func_attrs(
+ {FORWARD_FUNCTION_ATTRIBUTE_NAME: forward_function_name})
+ backward_function_attr.update(self._attrs)
# The ordering of `backwards_graph.inputs` is important: inputs of
# `self._backward_graph_function` correspond to outputs of
@@ -732,7 +750,17 @@ class Function(object):
grad for grad in _flatten(gradients_wrt_inputs) if grad is not None)
backwards_graph.structured_outputs = gradients_wrt_inputs
self._backward_graph_function = Function(
- backwards_graph, attrs=self._attrs)
+ backwards_graph, attrs=backward_function_attr)
+
+ forward_function_attr = _parse_func_attrs({
+ BACKWARD_FUNCTION_ATTRIBUTE_NAME:
+ self._backward_graph_function._inference_function.name}) # pylint: disable=protected-access
+ forward_function_attr.update(self._attrs)
+
+ self._forward_function = _EagerDefinedFunction(
+ forward_function_name, self._func_graph, self._func_graph.inputs,
+ self._func_graph.outputs + backwards_graph_captures,
+ forward_function_attr)
def _backprop_call(self, args):
"""Calls the forward function and records the result on a tape.
@@ -986,52 +1014,8 @@ def func_graph_from_py_func(name,
return func_graph
-_TensorType = collections.namedtuple("_TensorType", ["dtype", "shape"])
-
-
-def _encode_arg(arg):
- """A canonical representation for this argument, for use in a cache key."""
-
- # `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
- # are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
- # are used for both performance reasons, as much TensorFlow code specializes
- # on known shapes to produce slimmer graphs, and correctness, as some
- # high-level APIs require shapes to be fully-known.
- #
- # TODO(akshayka): Add support for sparse tensors.
- #
- # pylint: disable=protected-access
- if isinstance(arg, ops.Tensor):
- return _TensorType(arg.dtype, arg._shape_tuple())
- elif isinstance(arg, ops.IndexedSlices):
- if arg.dense_shape is not None:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- _TensorType(arg.dense_shape.dtype, arg.dense_shape._shape_tuple()),
- ])
- else:
- return tuple([
- _TensorType(arg.values.dtype, arg.values._shape_tuple()),
- _TensorType(arg.indices.dtype, arg.indices._shape_tuple()),
- ])
- # pylint: enable=protected-access
- elif isinstance(arg, (list, tuple)):
- return tuple([_encode_arg(elem) for elem in arg])
- elif isinstance(arg, dict):
- return tuple(
- (_encode_arg(key), _encode_arg(arg[key])) for key in sorted(arg))
- else:
- try:
- # If possible, keep only a weak reference to Python objects. Weak
- # references hash to the same value as the original object.
- # TODO(allenl): Clean up dead functions and their cache keys if the cache
- # gets large. Right now creating objects with a defunned method, calling
- # the method, and losing a reference to the object in a loop will leak
- # memory here.
- return weakref.ref(arg)
- except TypeError:
- return arg
+pywrap_tensorflow.RegisterType("Tensor", ops.Tensor)
+pywrap_tensorflow.RegisterType("IndexedSlices", ops.IndexedSlices)
def _deterministic_dict_values(dictionary):
@@ -1101,6 +1085,8 @@ class PolymorphicFunction(object):
offset + index: default
for index, default in enumerate(fullargspec.defaults or [])
}
+ self._default_values = fullargspec.defaults
+ self._default_values_start_index = offset
if input_signature is None:
self._input_signature = None
else:
@@ -1161,7 +1147,7 @@ class PolymorphicFunction(object):
"""Computes the cache key given inputs and execution context."""
if self._input_signature is None:
inputs = (args, kwargs) if kwargs else args
- cache_key = tuple(_encode_arg(arg) for arg in inputs)
+ cache_key = pywrap_tensorflow.TFE_Py_EncodeArg(inputs)
else:
del args, kwargs
cache_key = self._flat_input_signature
@@ -1184,7 +1170,7 @@ class PolymorphicFunction(object):
colocation_stack = (() if executing_eagerly else
tuple(default_graph._colocation_stack.peek_objs())) # pylint: disable=protected-access
- return cache_key + (execution_context, device_functions, colocation_stack)
+ return (cache_key, execution_context, device_functions, colocation_stack)
def _canonicalize_function_inputs(self, *args, **kwargs):
"""Canonicalizes `args` and `kwargs`.
@@ -1212,26 +1198,32 @@ class PolymorphicFunction(object):
# Maps from index of arg to its corresponding value, according to `args`
# and `kwargs`; seeded with the default values for the named args that
# aren't in `args`.
- arg_indices_to_values = {
- index: default
- for index, default in six.iteritems(self._arg_indices_to_default_values)
- if index >= len(args)
- }
- consumed_args = []
- for arg, value in six.iteritems(kwargs):
- index = self._args_to_indices.get(arg, None)
- if index is not None:
- arg_indices_to_values[index] = value
- consumed_args.append(arg)
- elif self._input_signature is not None:
- raise ValueError("Cannot define a TensorFlow function from a Python "
- "function with keyword arguments when "
- "input_signature is provided.")
- for arg in consumed_args:
- # After this loop, `kwargs` will only contain true keyword arguments, as
- # opposed to named arguments called in a keyword-like fashion.
- kwargs.pop(arg)
- inputs = args + _deterministic_dict_values(arg_indices_to_values)
+ if not kwargs:
+ if self._default_values:
+ inputs = args + self._default_values[len(args) -
+ self._default_values_start_index:]
+ else:
+ inputs = args
+ else:
+ arg_indices_to_values = {
+ index: default for index, default in six.iteritems(
+ self._arg_indices_to_default_values) if index >= len(args)
+ }
+ consumed_args = []
+ for arg, value in six.iteritems(kwargs):
+ index = self._args_to_indices.get(arg, None)
+ if index is not None:
+ arg_indices_to_values[index] = value
+ consumed_args.append(arg)
+ elif self._input_signature is not None:
+ raise ValueError("Cannot define a TensorFlow function from a Python "
+ "function with keyword arguments when "
+ "input_signature is provided.")
+ for arg in consumed_args:
+ # After this loop, `kwargs` will only contain true keyword arguments, as
+ # opposed to named arguments called in a keyword-like fashion.
+ kwargs.pop(arg)
+ inputs = args + _deterministic_dict_values(arg_indices_to_values)
flat_inputs = nest.flatten(inputs)
# Check for NumPy arrays in arguments and convert them to Tensors.
diff --git a/tensorflow/python/eager/function_test.py b/tensorflow/python/eager/function_test.py
index 34a2648e26..9ce367a837 100644
--- a/tensorflow/python/eager/function_test.py
+++ b/tensorflow/python/eager/function_test.py
@@ -1237,6 +1237,24 @@ class FunctionTest(test.TestCase):
x = constant_op.constant([1.0, 2.0])
self.assertAllEqual([2., 4.], self.evaluate(defined(x)))
+ def testCacheObjectHashCollisions(self):
+
+ class Foo(object):
+
+ def __hash__(self):
+ return 42
+
+ def func(foo):
+ del foo
+ return
+
+ defined = function.defun(func)
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 1)
+
+ defined(Foo())
+ self.assertEqual(len(defined._function_cache), 2)
+
def testPythonFunctionWithDefaultArgs(self):
def func(foo, bar=1, baz=2):
@@ -1250,20 +1268,20 @@ class FunctionTest(test.TestCase):
def cache_keys():
"""Sanitizes cache keys of non-input metadata."""
- return tuple(key[:3] for key in defined._function_cache)
+ return tuple(key[0] for key in defined._function_cache)
# `True` corresponds to the fact that we're executing eagerly
- self.assertIn((0, 1, 20), cache_keys())
+ self.assertIn(('tRRR', (0, 1, 20)), cache_keys())
defined(1) # bar=1, baz=2
- self.assertIn((1, 1, 2), cache_keys())
+ self.assertIn(('tRRR', (1, 1, 2)), cache_keys())
# This matches the previous call.
defined(foo=1)
self.assertEqual(len(defined._function_cache), 2)
defined(1, 2, 3)
- self.assertIn((1, 2, 3), cache_keys())
+ self.assertIn(('tRRR', (1, 2, 3)), cache_keys())
# This matches the previous call.
defined(1, bar=2, baz=3)
@@ -1687,6 +1705,21 @@ class FunctionTest(test.TestCase):
self.assertRegexpMatches(captured_function_names[i],
expected_func_name_regex[i])
+ # Check the forward and backward function has the correct attributes.
+ self.assertEquals(
+ functions[1].definition.attr['backward_function_name'].s,
+ functions[2].name)
+ self.assertEquals(
+ functions[2].definition.attr['forward_function_name'].s,
+ functions[1].name)
+
+ self.assertEquals(
+ functions[4].definition.attr['backward_function_name'].s,
+ functions[5].name)
+ self.assertEquals(
+ functions[5].definition.attr['forward_function_name'].s,
+ functions[4].name)
+
sq = defun_matmul(t, t)
double = add(t, t)
self.assertAllEqual(sq.eval().reshape(-1), [7, 10, 15, 22])
diff --git a/tensorflow/python/eager/pywrap_tfe.h b/tensorflow/python/eager/pywrap_tfe.h
index f1b4042ec9..decd635b58 100755
--- a/tensorflow/python/eager/pywrap_tfe.h
+++ b/tensorflow/python/eager/pywrap_tfe.h
@@ -224,4 +224,8 @@ PyObject* TFE_Py_TensorShapeSlice(PyObject* tensors, int slice_dim);
// The shape is represented as a Python tuple of integers.
PyObject* TFE_Py_TensorShapeOnDevice(PyObject* tensor);
+// Encodes the object as a tuple that is meant to be used as part of the key
+// for the defun function cache.
+PyObject* TFE_Py_EncodeArg(PyObject*);
+
#endif // TENSORFLOW_PYTHON_EAGER_PYWRAP_TFE_H_
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index 196e20e4d7..ae1e12f9c3 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/python/eager/pywrap_tfe.h"
+#include "absl/strings/str_cat.h"
#include "absl/types/variant.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
@@ -567,11 +568,8 @@ bool SetOpAttrScalar(
return false;
}
}
- TFE_Op* func = TFE_NewOp(
- ctx, string(func_name.data(), func_name.size()).c_str(), status);
- if (TF_GetCode(status) != TF_OK) return false;
- TFE_OpSetAttrFunction(op, key, func);
- TFE_DeleteOp(func);
+ TF_SetStatus(status, TF_OK, "");
+ TFE_OpSetAttrFunctionName(op, key, func_name.data(), func_name.size());
} else {
TF_SetStatus(
status, TF_UNIMPLEMENTED,
@@ -1569,9 +1567,8 @@ void TapeSetRecordOperation(
}
for (TFE_Py_Tape* tape : SafeTapeSet()) {
- auto* function = backward_function_getter();
tape->tape->RecordOperation(op_type_str, output_info, input_ids,
- input_dtypes, function,
+ input_dtypes, backward_function_getter,
backward_function_killer);
}
}
@@ -2748,3 +2745,218 @@ PyObject* TFE_Py_RecordGradient(PyObject* op_name, PyObject* inputs,
return RecordGradient(op_name, inputs, attrs, results, name);
}
+
+namespace {
+
+tensorflow::int64 GetPyNoneHash() {
+ tensorflow::int64 py_none_hash = PyObject_Hash(Py_None);
+ return py_none_hash;
+}
+
+struct EncodeResult {
+ string str;
+ std::vector<PyObject*> objects;
+
+ PyObject* ToPyTuple() {
+ PyObject* result = PyTuple_New(2);
+
+ PyTuple_SET_ITEM(result, 0, GetPythonObjectFromString(str.c_str()));
+
+ if (objects.empty()) {
+ Py_INCREF(Py_None);
+ PyTuple_SET_ITEM(result, 1, Py_None);
+ } else {
+ PyObject* objects_tuple = PyTuple_New(objects.size());
+
+ for (int i = 0; i < objects.size(); i++) {
+ PyTuple_SET_ITEM(objects_tuple, i, objects[i]);
+ }
+
+ PyTuple_SET_ITEM(result, 1, objects_tuple);
+ }
+
+ return result;
+ }
+};
+
+tensorflow::Status TFE_Py_EncodeTensor(PyObject* arg, EncodeResult* result) {
+ if (EagerTensor_CheckExact(arg)) {
+ TFE_TensorHandle* t = EagerTensor_Handle(arg);
+ tensorflow::TensorShape tensor_shape;
+ TF_RETURN_IF_ERROR(t->handle->Shape(&tensor_shape));
+ absl::StrAppend(&result->str, t->handle->dtype);
+
+ for (tensorflow::int64 dim_size : tensor_shape.dim_sizes()) {
+ absl::StrAppend(&result->str, dim_size);
+ }
+
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_object(
+ PyObject_GetAttrString(arg, "dtype"));
+
+ if (dtype_object == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have dtype() attr.");
+ }
+
+ tensorflow::Safe_PyObjectPtr dtype_enum(
+ PyObject_GetAttrString(dtype_object.get(), "_type_enum"));
+
+ if (dtype_enum == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor's dtype object doesn't have _type_enum() attr.");
+ }
+
+ tensorflow::DataType dtype =
+ static_cast<tensorflow::DataType>(MakeInt(dtype_enum.get()));
+
+ absl::StrAppend(&result->str, dtype);
+ static char _shape_tuple[] = "_shape_tuple";
+ tensorflow::Safe_PyObjectPtr shape_tuple(
+ PyObject_CallMethod(arg, _shape_tuple, nullptr));
+
+ if (shape_tuple == nullptr) {
+ return tensorflow::errors::InvalidArgument(
+ "ops.Tensor object doesn't have _shape_tuple() method.");
+ }
+
+ if (shape_tuple.get() == Py_None) {
+ // Unknown shape, encode that directly.
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ return tensorflow::Status::OK();
+ }
+
+ tensorflow::Safe_PyObjectPtr shape_seq(PySequence_Fast(
+ shape_tuple.get(), "shape_tuple didn't return a sequence"));
+
+ int len = PySequence_Fast_GET_SIZE(shape_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(shape_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ } else {
+ absl::StrAppend(&result->str, MakeInt(item));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+const char kTensor[] = "T";
+const char kIndexedSlices[] = "I";
+const char kList[] = "L";
+const char kTuple[] = "t";
+const char kDict[] = "D";
+const char kRaw[] = "R";
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result);
+
+// This function doesn't set the type of sequence before
+tensorflow::Status TFE_Py_EncodeSequence(PyObject* arg, const char* type,
+ EncodeResult* result) {
+ tensorflow::Safe_PyObjectPtr arg_seq(
+ PySequence_Fast(arg, "unable to create seq from list/tuple"));
+
+ absl::StrAppend(&result->str, type);
+ int len = PySequence_Fast_GET_SIZE(arg_seq.get());
+ for (int i = 0; i < len; ++i) {
+ PyObject* item = PySequence_Fast_GET_ITEM(arg_seq.get(), i);
+ if (item == Py_None) {
+ absl::StrAppend(&result->str, GetPyNoneHash());
+ } else {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(item, result));
+ }
+ }
+
+ return tensorflow::Status::OK();
+}
+
+tensorflow::Status TFE_Py_EncodeArgHelper(PyObject* arg, EncodeResult* result) {
+ if (tensorflow::swig::IsTensor(arg)) {
+ absl::StrAppend(&result->str, kTensor);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(arg, result));
+ } else if (tensorflow::swig::IsIndexedSlices(arg)) {
+ absl::StrAppend(&result->str, kIndexedSlices);
+ tensorflow::Safe_PyObjectPtr values(PyObject_GetAttrString(arg, "values"));
+ if (values == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a values attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(values.get(), result));
+
+ tensorflow::Safe_PyObjectPtr indices(
+ PyObject_GetAttrString(arg, "indices"));
+ if (indices == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a indices attr");
+ }
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(indices.get(), result));
+
+ tensorflow::Safe_PyObjectPtr dense_shape(
+ PyObject_GetAttrString(arg, "dense_shape"));
+ if (dense_shape == nullptr) {
+ PyErr_Clear();
+ return tensorflow::errors::InvalidArgument(
+ "IndexedSlices does not have a dense_shape attr");
+ }
+ if (dense_shape.get() != Py_None) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeTensor(dense_shape.get(), result));
+ }
+ } else if (PyList_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kList, result));
+ } else if (PyTuple_Check(arg)) {
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeSequence(arg, kTuple, result));
+ } else if (PyDict_Check(arg)) {
+ tensorflow::Safe_PyObjectPtr keys(PyDict_Keys(arg));
+ if (PyList_Sort(keys.get()) == -1) {
+ return tensorflow::errors::Internal("Unable to sort keys");
+ }
+
+ absl::StrAppend(&result->str, kDict);
+ int len = PyList_Size(keys.get());
+
+ for (int i = 0; i < len; i++) {
+ PyObject* key = PyList_GetItem(keys.get(), i);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(key, result));
+ PyObject* value = PyDict_GetItem(arg, key);
+ TF_RETURN_IF_ERROR(TFE_Py_EncodeArgHelper(value, result));
+ }
+ } else {
+ PyObject* object = PyWeakref_NewRef(arg, nullptr);
+
+ if (object == nullptr) {
+ PyErr_Clear();
+
+ object = arg;
+ Py_INCREF(object);
+ }
+
+ absl::StrAppend(&result->str, kRaw);
+ result->objects.push_back(object);
+ }
+
+ return tensorflow::Status::OK();
+}
+
+} // namespace
+
+// `defun` uses dtypes and shapes instead of `Tensors` as cache keys. Dtypes
+// are used because TensorFlow graphs are not parametric w.r.t. dtypes. Shapes
+// are used for both performance reasons, as much TensorFlow code specializes
+// on known shapes to produce slimmer graphs, and correctness, as some
+// high-level APIs require shapes to be fully-known.
+//
+// TODO(nareshmodi): Add support for sparse tensors.
+PyObject* TFE_Py_EncodeArg(PyObject* arg) {
+ EncodeResult result;
+ const auto status = TFE_Py_EncodeArgHelper(arg, &result);
+ if (MaybeRaiseExceptionFromStatus(status, nullptr)) {
+ return nullptr;
+ }
+
+ return result.ToPyTuple();
+}
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index ba1b7ec2b5..1c4c5951df 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -344,6 +344,7 @@ py_test(
":pandas_io",
":prediction_keys",
"//tensorflow:tensorflow_py_no_contrib",
+ "@absl_py//absl/testing:parameterized",
"@six_archive//:six",
],
)
diff --git a/tensorflow/python/estimator/canned/dnn.py b/tensorflow/python/estimator/canned/dnn.py
index 97971f9561..a6c2aaa7d9 100644
--- a/tensorflow/python/estimator/canned/dnn.py
+++ b/tensorflow/python/estimator/canned/dnn.py
@@ -131,9 +131,7 @@ class _DNNModel(training.Model):
name=None,
**kwargs):
super(_DNNModel, self).__init__(name=name, **kwargs)
- self._is_v2 = False
if feature_column_v2.is_feature_column_v2(feature_columns):
- self._is_v2 = True
self._input_layer = feature_column_v2.FeatureLayer(
feature_columns=feature_columns,
name='input_layer',
@@ -190,7 +188,6 @@ class _DNNModel(training.Model):
_scope=logits_scope)
self._add_layer(self._logits_layer, logits_scope.name)
self._logits_scope_name = logits_scope.name
- self._logits_layer._use_resource_variables = False # pylint: disable=protected-access
self._input_layer_partitioner = input_layer_partitioner
def call(self, features, mode):
diff --git a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
index d16318659b..ae968e717a 100644
--- a/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
+++ b/tensorflow/python/estimator/canned/dnn_linear_combined_test.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import shutil
import tempfile
+from absl.testing import parameterized
import numpy as np
import six
@@ -35,6 +36,7 @@ from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import nn
@@ -119,7 +121,16 @@ class LinearOnlyRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorEvaluationTest(
@@ -128,7 +139,16 @@ class LinearOnlyRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorPredictTest(
@@ -137,7 +157,16 @@ class LinearOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorIntegrationTest(
@@ -146,7 +175,16 @@ class LinearOnlyRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearOnlyRegressorTrainingTest(
@@ -155,7 +193,16 @@ class LinearOnlyRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearOnlyRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
def _linear_classifier_fn(feature_columns,
@@ -185,7 +232,18 @@ class LinearOnlyClassifierTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierClassesEvaluationTest(
@@ -194,7 +252,18 @@ class LinearOnlyClassifierClassesEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierClassesEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierPredictTest(
@@ -203,7 +272,18 @@ class LinearOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearOnlyClassifierIntegrationTest(
@@ -212,9 +292,21 @@ class LinearOnlyClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearOnlyClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
def setUp(self):
@@ -225,13 +317,15 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- label_dimension, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, label_dimension, batch_size,
+ fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=linear_feature_columns,
@@ -257,14 +351,14 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
label_dimension = 2
batch_size = 10
@@ -293,9 +387,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -326,9 +421,10 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
label_dimension = 2
batch_size = 10
@@ -376,7 +472,8 @@ class DNNLinearCombinedRegressorIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=label_dimension,
label_dimension=label_dimension,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
# A function to mimic dnn-classifier init reuse same tests.
@@ -407,7 +504,16 @@ class DNNOnlyClassifierEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierEvaluateV2Test(
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierEvaluateTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierPredictTest(
@@ -416,7 +522,16 @@ class DNNOnlyClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierPredictV2Test(
+ dnn_testing_utils.BaseDNNClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierPredictTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
class DNNOnlyClassifierTrainTest(
@@ -425,7 +540,16 @@ class DNNOnlyClassifierTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
- self, _dnn_classifier_fn)
+ self, _dnn_classifier_fn, fc_impl=feature_column)
+
+
+class DNNOnlyClassifierTrainV2Test(dnn_testing_utils.BaseDNNClassifierTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNClassifierTrainTest.__init__(
+ self, _dnn_classifier_fn, fc_impl=feature_column_v2)
# A function to mimic dnn-regressor init reuse same tests.
@@ -454,7 +578,16 @@ class DNNOnlyRegressorEvaluateTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorEvaluateV2Test(
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorEvaluateTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorPredictTest(
@@ -463,7 +596,16 @@ class DNNOnlyRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+
+
+class DNNOnlyRegressorPredictV2Test(
+ dnn_testing_utils.BaseDNNRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorPredictTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
class DNNOnlyRegressorTrainTest(
@@ -472,9 +614,19 @@ class DNNOnlyRegressorTrainTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
- self, _dnn_regressor_fn)
+ self, _dnn_regressor_fn, fc_impl=feature_column)
+class DNNOnlyRegressorTrainV2Test(dnn_testing_utils.BaseDNNRegressorTrainTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ dnn_testing_utils.BaseDNNRegressorTrainTest.__init__(
+ self, _dnn_regressor_fn, fc_impl=feature_column_v2)
+
+
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def setUp(self):
@@ -488,13 +640,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
def _as_label(self, data_in_float):
return np.rint(data_in_float).astype(np.int64)
- def _test_complete_flow(
- self, train_input_fn, eval_input_fn, predict_input_fn, input_dimension,
- n_classes, batch_size):
+ def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
+ input_dimension, n_classes, batch_size, fc_impl):
linear_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
dnn_feature_columns = [
- feature_column.numeric_column('x', shape=(input_dimension,))]
+ fc_impl.numeric_column('x', shape=(input_dimension,))
+ ]
feature_columns = linear_feature_columns + dnn_feature_columns
est = dnn_linear_combined.DNNLinearCombinedClassifier(
linear_feature_columns=linear_feature_columns,
@@ -520,14 +673,14 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
self.assertAllEqual((batch_size, n_classes), predicted_proba.shape)
# EXPORT
- feature_spec = feature_column.make_parse_example_spec(feature_columns)
+ feature_spec = fc_impl.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
serving_input_receiver_fn)
self.assertTrue(gfile.Exists(export_dir))
- def test_numpy_input_fn(self):
+ def test_numpy_input_fn(self, fc_impl):
"""Tests complete flow with numpy_input_fn."""
n_classes = 3
input_dimension = 2
@@ -559,9 +712,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_pandas_input_fn(self):
+ def test_pandas_input_fn(self, fc_impl):
"""Tests complete flow with pandas_input_fn."""
if not HAS_PANDAS:
return
@@ -593,9 +747,10 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
- def test_input_fn_from_parse_example(self):
+ def test_input_fn_from_parse_example(self, fc_impl):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dimension = 2
n_classes = 3
@@ -647,9 +802,11 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase):
predict_input_fn=_predict_input_fn,
input_dimension=input_dimension,
n_classes=n_classes,
- batch_size=batch_size)
+ batch_size=batch_size,
+ fc_impl=fc_impl)
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedTests(test.TestCase):
def setUp(self):
@@ -681,9 +838,9 @@ class DNNLinearCombinedTests(test.TestCase):
return optimizer_mock
- def test_train_op_calls_both_dnn_and_linear(self):
+ def test_train_op_calls_both_dnn_and_linear(self, fc_impl):
opt = gradient_descent.GradientDescentOptimizer(1.)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
input_fn = numpy_io.numpy_input_fn(
x={'x': np.array([[0.], [1.]])},
y=np.array([[0.], [1.]]),
@@ -708,7 +865,7 @@ class DNNLinearCombinedTests(test.TestCase):
checkpoint_utils.load_variable(
self._model_dir, 'dnn_called'))
- def test_dnn_and_linear_logits_are_added(self):
+ def test_dnn_and_linear_logits_are_added(self, fc_impl):
with ops.Graph().as_default():
variables_lib.Variable([[1.0]], name='linear/linear_model/x/weights')
variables_lib.Variable([2.0], name='linear/linear_model/bias_weights')
@@ -719,7 +876,7 @@ class DNNLinearCombinedTests(test.TestCase):
variables_lib.Variable(1, name='global_step', dtype=dtypes.int64)
linear_testing_utils.save_variables_to_ckpt(self._model_dir)
- x_column = feature_column.numeric_column('x')
+ x_column = fc_impl.numeric_column('x')
est = dnn_linear_combined.DNNLinearCombinedRegressor(
linear_feature_columns=[x_column],
dnn_hidden_units=[1],
@@ -737,6 +894,7 @@ class DNNLinearCombinedTests(test.TestCase):
next(est.predict(input_fn=input_fn)))
+@parameterized.parameters((feature_column,), (feature_column_v2,))
class DNNLinearCombinedWarmStartingTest(test.TestCase):
def setUp(self):
@@ -758,11 +916,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._ckpt_and_vocab_dir)
- def test_classifier_basic_warm_starting(self):
+ def test_classifier_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedClassifier default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -798,11 +956,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_classifier.get_variable_value(variable_name),
warm_started_dnn_lc_classifier.get_variable_value(variable_name))
- def test_regressor_basic_warm_starting(self):
+ def test_regressor_basic_warm_starting(self, fc_impl):
"""Tests correctness of DNNLinearCombinedRegressor default warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
@@ -836,11 +994,11 @@ class DNNLinearCombinedWarmStartingTest(test.TestCase):
dnn_lc_regressor.get_variable_value(variable_name),
warm_started_dnn_lc_regressor.get_variable_value(variable_name))
- def test_warm_starting_selective_variables(self):
+ def test_warm_starting_selective_variables(self, fc_impl):
"""Tests selecting variables to warm-start."""
- age = feature_column.numeric_column('age')
- city = feature_column.embedding_column(
- feature_column.categorical_column_with_vocabulary_list(
+ age = fc_impl.numeric_column('age')
+ city = fc_impl.embedding_column(
+ fc_impl.categorical_column_with_vocabulary_list(
'city', vocabulary_list=['Mountain View', 'Palo Alto']),
dimension=5)
diff --git a/tensorflow/python/estimator/canned/linear.py b/tensorflow/python/estimator/canned/linear.py
index 115dd18518..8b96284bd3 100644
--- a/tensorflow/python/estimator/canned/linear.py
+++ b/tensorflow/python/estimator/canned/linear.py
@@ -25,14 +25,18 @@ import six
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import optimizers
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import partitioned_variables
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables as variable_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import ftrl
+from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import estimator_export
@@ -46,23 +50,42 @@ def _get_default_optimizer(feature_columns):
return ftrl.FtrlOptimizer(learning_rate=learning_rate)
-def _compute_fraction_of_zero(cols_to_vars):
- """Given a linear cols_to_vars dict, compute the fraction of zero weights.
+def _get_expanded_variable_list(var_list):
+ """Given a list of variables, expands them if they are partitioned.
Args:
- cols_to_vars: A dictionary mapping FeatureColumns to lists of tf.Variables
- like one returned from feature_column_lib.linear_model.
+ var_list: A list of variables.
+
+ Returns:
+ A list of variables where each partitioned variable is expanded to its
+ components.
+ """
+ returned_list = []
+ for variable in var_list:
+ if (isinstance(variable, variable_ops.Variable) or
+ resource_variable_ops.is_resource_variable(variable)):
+ returned_list.append(variable) # Single variable case.
+ else: # Must be a PartitionedVariable, so convert into a list.
+ returned_list.extend(list(variable))
+ return returned_list
+
+
+# TODO(rohanj): Consider making this a public utility method.
+def _compute_fraction_of_zero(variables):
+ """Given a linear variables list, compute the fraction of zero weights.
+
+ Args:
+ variables: A list or list of list of variables
Returns:
The fraction of zeros (sparsity) in the linear model.
"""
all_weight_vars = []
- for var_or_var_list in cols_to_vars.values():
+ for var_or_var_list in variables:
+ var_list = nest.flatten(var_or_var_list)
# Skip empty-lists associated with columns that created no Variables.
- if var_or_var_list:
- all_weight_vars += [
- array_ops.reshape(var, [-1]) for var in var_or_var_list
- ]
+ if var_list:
+ all_weight_vars += [array_ops.reshape(var, [-1]) for var in var_list]
return nn.zero_fraction(array_ops.concat(all_weight_vars, axis=0))
@@ -92,14 +115,36 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
Returns:
A `Tensor` representing the logits.
"""
- cols_to_vars = {}
- logits = feature_column_lib.linear_model(
- features=features,
- feature_columns=feature_columns,
- units=units,
- sparse_combiner=sparse_combiner,
- cols_to_vars=cols_to_vars)
- bias = cols_to_vars.pop('bias')
+ if feature_column_v2.is_feature_column_v2(feature_columns):
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ linear_model = feature_column_v2.LinearModel(
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ shared_state_manager=shared_state_manager)
+ logits = linear_model(features)
+ bias = linear_model.bias_variable
+
+ # We'd like to get all the non-bias variables associated with this
+ # LinearModel. This includes the shared embedding variables as well.
+ variables = linear_model.variables
+ variables.remove(bias)
+ variables.extend(shared_state_manager.variables)
+
+ # Expand (potential) Partitioned variables
+ bias = _get_expanded_variable_list([bias])
+ variables = _get_expanded_variable_list(variables)
+ else:
+ linear_model = feature_column._LinearModel( # pylint: disable=protected-access
+ feature_columns=feature_columns,
+ units=units,
+ sparse_combiner=sparse_combiner,
+ name='linear_model')
+ logits = linear_model(features)
+ cols_to_vars = linear_model.cols_to_vars()
+ bias = cols_to_vars.pop('bias')
+ variables = cols_to_vars.values()
+
if units > 1:
summary.histogram('bias', bias)
else:
@@ -107,7 +152,7 @@ def _linear_logit_fn_builder(units, feature_columns, sparse_combiner='sum'):
# so we should provide a scalar summary.
summary.scalar('bias', bias[0][0])
summary.scalar('fraction_of_zero_weights',
- _compute_fraction_of_zero(cols_to_vars))
+ _compute_fraction_of_zero(variables))
return logits
return linear_logit_fn
diff --git a/tensorflow/python/estimator/canned/linear_test.py b/tensorflow/python/estimator/canned/linear_test.py
index 59a230417d..3e6da5de22 100644
--- a/tensorflow/python/estimator/canned/linear_test.py
+++ b/tensorflow/python/estimator/canned/linear_test.py
@@ -20,6 +20,8 @@ from __future__ import print_function
from tensorflow.python.estimator.canned import linear
from tensorflow.python.estimator.canned import linear_testing_utils
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.platform import test
@@ -40,7 +42,16 @@ class LinearRegressorPartitionerTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPartitionerV2Test(
+ linear_testing_utils.BaseLinearRegressorPartitionerTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPartitionerTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorEvaluationTest(
@@ -49,7 +60,16 @@ class LinearRegressorEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorEvaluationV2Test(
+ linear_testing_utils.BaseLinearRegressorEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorEvaluationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorPredictTest(
@@ -58,7 +78,16 @@ class LinearRegressorPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorPredictV2Test(
+ linear_testing_utils.BaseLinearRegressorPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorPredictTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorIntegrationTest(
@@ -67,7 +96,16 @@ class LinearRegressorIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
+
+
+class LinearRegressorIntegrationV2Test(
+ linear_testing_utils.BaseLinearRegressorIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorIntegrationTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
class LinearRegressorTrainingTest(
@@ -76,19 +114,37 @@ class LinearRegressorTrainingTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
- self, _linear_regressor_fn)
+ self, _linear_regressor_fn, fc_lib=feature_column)
-# Tests for Linear Classifier.
+class LinearRegressorTrainingV2Test(
+ linear_testing_utils.BaseLinearRegressorTrainingTest, test.TestCase):
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearRegressorTrainingTest.__init__(
+ self, _linear_regressor_fn, fc_lib=feature_column_v2)
+
+# Tests for Linear Classifier.
class LinearClassifierTrainingTest(
linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierTrainingV2Test(
+ linear_testing_utils.BaseLinearClassifierTrainingTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierTrainingTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierEvaluationTest(
@@ -97,7 +153,18 @@ class LinearClassifierEvaluationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierEvaluationV2Test(
+ linear_testing_utils.BaseLinearClassifierEvaluationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierEvaluationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierPredictTest(
@@ -106,7 +173,18 @@ class LinearClassifierPredictTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierPredictV2Test(
+ linear_testing_utils.BaseLinearClassifierPredictTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierPredictTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
class LinearClassifierIntegrationTest(
@@ -115,7 +193,18 @@ class LinearClassifierIntegrationTest(
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
- self, linear_classifier_fn=_linear_classifier_fn)
+ self, linear_classifier_fn=_linear_classifier_fn, fc_lib=feature_column)
+
+
+class LinearClassifierIntegrationV2Test(
+ linear_testing_utils.BaseLinearClassifierIntegrationTest, test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearClassifierIntegrationTest.__init__(
+ self,
+ linear_classifier_fn=_linear_classifier_fn,
+ fc_lib=feature_column_v2)
# Tests for Linear logit_fn.
@@ -124,7 +213,17 @@ class LinearLogitFnTest(linear_testing_utils.BaseLinearLogitFnTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
- linear_testing_utils.BaseLinearLogitFnTest.__init__(self)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column)
+
+
+class LinearLogitFnV2Test(linear_testing_utils.BaseLinearLogitFnTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearLogitFnTest.__init__(
+ self, fc_lib=feature_column_v2)
# Tests for warm-starting with Linear logit_fn.
@@ -134,7 +233,22 @@ class LinearWarmStartingTest(linear_testing_utils.BaseLinearWarmStartingTest,
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
test.TestCase.__init__(self, methodName)
linear_testing_utils.BaseLinearWarmStartingTest.__init__(
- self, _linear_classifier_fn, _linear_regressor_fn)
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column)
+
+
+class LinearWarmStartingV2Test(linear_testing_utils.BaseLinearWarmStartingTest,
+ test.TestCase):
+
+ def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
+ test.TestCase.__init__(self, methodName)
+ linear_testing_utils.BaseLinearWarmStartingTest.__init__(
+ self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column_v2)
if __name__ == '__main__':
diff --git a/tensorflow/python/estimator/canned/linear_testing_utils.py b/tensorflow/python/estimator/canned/linear_testing_utils.py
index 65cdd50061..827352a70b 100644
--- a/tensorflow/python/estimator/canned/linear_testing_utils.py
+++ b/tensorflow/python/estimator/canned/linear_testing_utils.py
@@ -37,7 +37,8 @@ from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.export import export
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.estimator.inputs import pandas_io
-from tensorflow.python.feature_column import feature_column as feature_column_lib
+from tensorflow.python.feature_column import feature_column
+from tensorflow.python.feature_column import feature_column_v2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
@@ -152,8 +153,9 @@ class CheckPartitionerVarHook(session_run_hook.SessionRunHook):
class BaseLinearRegressorPartitionerTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -173,7 +175,7 @@ class BaseLinearRegressorPartitionerTest(object):
return [partitions, 1] if shape[0] == x_dim else [1]
regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.categorical_column_with_hash_bucket(
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
'language', hash_bucket_size=x_dim),),
partitioner=_partitioner,
model_dir=self._model_dir)
@@ -209,9 +211,8 @@ class BaseLinearRegressorPartitionerTest(object):
'_get_replica_device_setter',
return_value=lambda _: '/cpu:0'):
linear_regressor = self._linear_regressor_fn(
- feature_columns=(
- feature_column_lib.categorical_column_with_hash_bucket(
- 'language', hash_bucket_size=x_dim),),
+ feature_columns=(self._fc_lib.categorical_column_with_hash_bucket(
+ 'language', hash_bucket_size=x_dim),),
config=FakeRunConfig(),
model_dir=self._model_dir)
@@ -232,8 +233,9 @@ class BaseLinearRegressorPartitionerTest(object):
# TODO(b/36813849): Add tests with dynamic shape inputs using placeholders.
class BaseLinearRegressorEvaluationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -252,7 +254,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,),)}, ((10.,),)), steps=1)
@@ -276,7 +278,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(
input_fn=lambda: ({'age': ((1,), (1,))}, ((10.,), (10.,))), steps=1)
@@ -308,7 +310,7 @@ class BaseLinearRegressorEvaluationTest(object):
return features, labels
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='weights',
model_dir=self._model_dir)
eval_metrics = linear_regressor.evaluate(input_fn=_input_fn, steps=1)
@@ -336,8 +338,7 @@ class BaseLinearRegressorEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column(
- 'age', shape=(x_dim,)),),
+ feature_columns=(self._fc_lib.numeric_column('age', shape=(x_dim,)),),
label_dimension=label_dim,
model_dir=self._model_dir)
input_fn = numpy_io.numpy_input_fn(
@@ -374,8 +375,8 @@ class BaseLinearRegressorEvaluationTest(object):
batch_size = 2
feature_columns = [
- feature_column_lib.numeric_column('age'),
- feature_column_lib.numeric_column('height')
+ self._fc_lib.numeric_column('age'),
+ self._fc_lib.numeric_column('height')
]
input_fn = numpy_io.numpy_input_fn(
x={'age': np.array([20, 40]),
@@ -402,8 +403,9 @@ class BaseLinearRegressorEvaluationTest(object):
class BaseLinearRegressorPredictTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -422,7 +424,7 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x'),),
+ feature_columns=(self._fc_lib.numeric_column('x'),),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -441,7 +443,7 @@ class BaseLinearRegressorPredictTest(object):
batch_size = 2
label_dimension = 3
x_dim = 4
- feature_columns = (feature_column_lib.numeric_column('x', shape=(x_dim,)),)
+ feature_columns = (self._fc_lib.numeric_column('x', shape=(x_dim,)),)
with ops.Graph().as_default():
variables_lib.Variable( # shape=[x_dim, label_dimension]
[[1., 2., 3.], [2., 3., 4.], [3., 4., 5.], [4., 5., 6.]],
@@ -479,8 +481,8 @@ class BaseLinearRegressorPredictTest(object):
save_variables_to_ckpt(self._model_dir)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('x0'),
- feature_column_lib.numeric_column('x1')),
+ feature_columns=(self._fc_lib.numeric_column('x0'),
+ self._fc_lib.numeric_column('x1')),
model_dir=self._model_dir)
predict_input_fn = numpy_io.numpy_input_fn(
@@ -515,9 +517,8 @@ class BaseLinearRegressorPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -561,8 +562,9 @@ class BaseLinearRegressorPredictTest(object):
class BaseLinearRegressorIntegrationTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -575,7 +577,7 @@ class BaseLinearRegressorIntegrationTest(object):
def _test_complete_flow(self, train_input_fn, eval_input_fn, predict_input_fn,
input_dimension, label_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
@@ -597,7 +599,7 @@ class BaseLinearRegressorIntegrationTest(object):
self.assertAllEqual((prediction_length, label_dimension), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -729,8 +731,9 @@ class BaseLinearRegressorIntegrationTest(object):
class BaseLinearRegressorTrainingTest(object):
- def __init__(self, linear_regressor_fn):
+ def __init__(self, linear_regressor_fn, fc_lib=feature_column):
self._linear_regressor_fn = linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -808,7 +811,7 @@ class BaseLinearRegressorTrainingTest(object):
label = 5.
age = 17
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir)
# Train for a few steps, and validate final checkpoint.
@@ -820,7 +823,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimLabel(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -840,7 +843,7 @@ class BaseLinearRegressorTrainingTest(object):
def testTrainWithOneDimWeight(self):
label_dimension = 1
batch_size = 20
- feature_columns = [feature_column_lib.numeric_column('age', shape=(1,))]
+ feature_columns = [self._fc_lib.numeric_column('age', shape=(1,))]
est = self._linear_regressor_fn(
feature_columns=feature_columns,
label_dimension=label_dimension,
@@ -867,7 +870,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (0 - 5.)^2 = 25.
mock_optimizer = self._mock_optimizer(expected_loss=25.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -900,7 +903,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = (logits - label)^2 = (175 - 5)^2 = 28900
mock_optimizer = self._mock_optimizer(expected_loss=28900.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -935,7 +938,7 @@ class BaseLinearRegressorTrainingTest(object):
# loss = sum(logits - label)^2 = (175 - 5)^2 + (155 - 3)^2 = 52004
mock_optimizer = self._mock_optimizer(expected_loss=52004.)
linear_regressor = self._linear_regressor_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
model_dir=self._model_dir,
optimizer=mock_optimizer)
self.assertEqual(0, mock_optimizer.minimize.call_count)
@@ -954,8 +957,9 @@ class BaseLinearRegressorTrainingTest(object):
class BaseLinearClassifierTrainingTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1031,7 +1035,7 @@ class BaseLinearClassifierTrainingTest(object):
label = 0
age = 17
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1051,7 +1055,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1078,7 +1082,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
data_rank_1 = np.array([0, 1])
@@ -1103,7 +1107,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1129,7 +1133,7 @@ class BaseLinearClassifierTrainingTest(object):
batch_size = 20
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
weight_column='w',
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1166,7 +1170,7 @@ class BaseLinearClassifierTrainingTest(object):
expected_loss=-1 * math.log(1.0/n_classes))
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1229,7 +1233,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1277,7 +1281,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=1.1132617)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1341,7 +1345,7 @@ class BaseLinearClassifierTrainingTest(object):
mock_optimizer = self._mock_optimizer(expected_loss=expected_loss)
est = linear.LinearClassifier(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
optimizer=mock_optimizer,
model_dir=self._model_dir)
@@ -1368,8 +1372,9 @@ class BaseLinearClassifierTrainingTest(object):
class BaseLinearClassifierEvaluationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1398,7 +1403,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1464,7 +1469,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
model_dir=self._model_dir)
eval_metrics = est.evaluate(
@@ -1540,7 +1545,7 @@ class BaseLinearClassifierEvaluationTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
n_classes=n_classes,
weight_column='w',
model_dir=self._model_dir)
@@ -1605,8 +1610,9 @@ class BaseLinearClassifierEvaluationTest(object):
class BaseLinearClassifierPredictTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1634,7 +1640,7 @@ class BaseLinearClassifierPredictTest(object):
save_variables_to_ckpt(self._model_dir)
est = self._linear_classifier_fn(
- feature_columns=(feature_column_lib.numeric_column('age'),),
+ feature_columns=(self._fc_lib.numeric_column('age'),),
label_vocabulary=label_vocabulary,
n_classes=n_classes,
model_dir=self._model_dir)
@@ -1730,9 +1736,8 @@ class BaseLinearClassifierPredictTest(object):
dense_shape=[2, 2]),
})
- feature_columns = (
- feature_column_lib.categorical_column_with_vocabulary_list(
- 'language', vocabulary_list=['a', 'b', 'c']),)
+ feature_columns = (self._fc_lib.categorical_column_with_vocabulary_list(
+ 'language', vocabulary_list=['a', 'b', 'c']),)
# Check prediction for each sparse_combiner.
# With sparse_combiner = 'sum', we have
@@ -1776,8 +1781,9 @@ class BaseLinearClassifierPredictTest(object):
class BaseLinearClassifierIntegrationTest(object):
- def __init__(self, linear_classifier_fn):
+ def __init__(self, linear_classifier_fn, fc_lib=feature_column):
self._linear_classifier_fn = linear_classifier_fn
+ self._fc_lib = fc_lib
def setUp(self):
self._model_dir = tempfile.mkdtemp()
@@ -1789,7 +1795,7 @@ class BaseLinearClassifierIntegrationTest(object):
def _test_complete_flow(self, n_classes, train_input_fn, eval_input_fn,
predict_input_fn, input_dimension, prediction_length):
feature_columns = [
- feature_column_lib.numeric_column('x', shape=(input_dimension,))
+ self._fc_lib.numeric_column('x', shape=(input_dimension,))
]
est = self._linear_classifier_fn(
feature_columns=feature_columns,
@@ -1811,7 +1817,7 @@ class BaseLinearClassifierIntegrationTest(object):
self.assertAllEqual((prediction_length, 1), predictions.shape)
# EXPORT
- feature_spec = feature_column_lib.make_parse_example_spec(feature_columns)
+ feature_spec = self._fc_lib.make_parse_example_spec(feature_columns)
serving_input_receiver_fn = export.build_parsing_serving_input_receiver_fn(
feature_spec)
export_dir = est.export_savedmodel(tempfile.mkdtemp(),
@@ -1961,9 +1967,12 @@ class BaseLinearClassifierIntegrationTest(object):
class BaseLinearLogitFnTest(object):
+ def __init__(self, fc_lib=feature_column):
+ self._fc_lib = fc_lib
+
def test_basic_logit_correctness(self):
"""linear_logit_fn simply wraps feature_column_lib.linear_model."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
with ops.Graph().as_default():
logit_fn = linear._linear_logit_fn_builder(units=2, feature_columns=[age])
logits = logit_fn(features={'age': [[23.], [31.]]})
@@ -1983,12 +1992,14 @@ class BaseLinearLogitFnTest(object):
def test_compute_fraction_of_zero(self):
"""Tests the calculation of sparsity."""
- age = feature_column_lib.numeric_column('age')
- occupation = feature_column_lib.categorical_column_with_hash_bucket(
+ if self._fc_lib != feature_column:
+ return
+ age = feature_column.numeric_column('age')
+ occupation = feature_column.categorical_column_with_hash_bucket(
'occupation', hash_bucket_size=5)
with ops.Graph().as_default():
cols_to_vars = {}
- feature_column_lib.linear_model(
+ feature_column.linear_model(
features={
'age': [[23.], [31.]],
'occupation': [['doctor'], ['engineer']]
@@ -1997,7 +2008,42 @@ class BaseLinearLogitFnTest(object):
units=3,
cols_to_vars=cols_to_vars)
cols_to_vars.pop('bias')
- fraction_zero = linear._compute_fraction_of_zero(cols_to_vars)
+ fraction_zero = linear._compute_fraction_of_zero(cols_to_vars.values())
+ age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
+ 'linear_model/age')[0]
+ with tf_session.Session() as sess:
+ sess.run([variables_lib.global_variables_initializer()])
+ # Upon initialization, all variables will be zero.
+ self.assertAllClose(1, fraction_zero.eval())
+
+ sess.run(age_var.assign([[2.0, 0.0, -1.0]]))
+ # 1 of the 3 age weights are zero, and all of the 15 (5 hash buckets
+ # x 3-dim output) are zero.
+ self.assertAllClose(16. / 18., fraction_zero.eval())
+
+ def test_compute_fraction_of_zero_v2(self):
+ """Tests the calculation of sparsity."""
+ if self._fc_lib != feature_column_v2:
+ return
+
+ age = feature_column_v2.numeric_column('age')
+ occupation = feature_column_v2.categorical_column_with_hash_bucket(
+ 'occupation', hash_bucket_size=5)
+ shared_state_manager = feature_column_v2.SharedEmbeddingStateManager()
+ with ops.Graph().as_default():
+ model = feature_column_v2.LinearModel(
+ feature_columns=[age, occupation],
+ units=3,
+ shared_state_manager=shared_state_manager)
+ features = {
+ 'age': [[23.], [31.]],
+ 'occupation': [['doctor'], ['engineer']]
+ }
+ model(features)
+ variables = model.variables
+ variables.remove(model.bias_variable)
+ variables.extend(shared_state_manager.variables)
+ fraction_zero = linear._compute_fraction_of_zero(variables)
age_var = ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES,
'linear_model/age')[0]
with tf_session.Session() as sess:
@@ -2013,9 +2059,13 @@ class BaseLinearLogitFnTest(object):
class BaseLinearWarmStartingTest(object):
- def __init__(self, _linear_classifier_fn, _linear_regressor_fn):
+ def __init__(self,
+ _linear_classifier_fn,
+ _linear_regressor_fn,
+ fc_lib=feature_column):
self._linear_classifier_fn = _linear_classifier_fn
self._linear_regressor_fn = _linear_regressor_fn
+ self._fc_lib = fc_lib
def setUp(self):
# Create a directory to save our old checkpoint and vocabularies to.
@@ -2039,7 +2089,7 @@ class BaseLinearWarmStartingTest(object):
def test_classifier_basic_warm_starting(self):
"""Tests correctness of LinearClassifier default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2066,7 +2116,7 @@ class BaseLinearWarmStartingTest(object):
def test_regressor_basic_warm_starting(self):
"""Tests correctness of LinearRegressor default warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearRegressor and train to save a checkpoint.
linear_regressor = self._linear_regressor_fn(
@@ -2091,7 +2141,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_selective_variables(self):
"""Tests selecting variables to warm-start."""
- age = feature_column_lib.numeric_column('age')
+ age = self._fc_lib.numeric_column('age')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2128,7 +2178,7 @@ class BaseLinearWarmStartingTest(object):
vocab_file = os.path.join(self._ckpt_and_vocab_dir, 'occupation_vocab')
with open(vocab_file, 'w') as f:
f.write('\n'.join(vocab_list))
- occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=vocab_file,
vocabulary_size=len(vocab_list))
@@ -2152,7 +2202,7 @@ class BaseLinearWarmStartingTest(object):
'new_occupation_vocab')
with open(new_vocab_file, 'w') as f:
f.write('\n'.join(new_vocab_list))
- new_occupation = feature_column_lib.categorical_column_with_vocabulary_file(
+ new_occupation = self._fc_lib.categorical_column_with_vocabulary_file(
'occupation',
vocabulary_file=new_vocab_file,
vocabulary_size=len(new_vocab_list))
@@ -2205,7 +2255,7 @@ class BaseLinearWarmStartingTest(object):
def test_warm_starting_with_naming_change(self):
"""Tests warm-starting with a Tensor name remapping."""
- age_in_years = feature_column_lib.numeric_column('age_in_years')
+ age_in_years = self._fc_lib.numeric_column('age_in_years')
# Create a LinearClassifier and train to save a checkpoint.
linear_classifier = self._linear_classifier_fn(
@@ -2219,7 +2269,7 @@ class BaseLinearWarmStartingTest(object):
# learning_rate = 0.0 optimizer to check values (use SGD so we don't have
# accumulator values that change).
warm_started_linear_classifier = self._linear_classifier_fn(
- feature_columns=[feature_column_lib.numeric_column('age')],
+ feature_columns=[self._fc_lib.numeric_column('age')],
n_classes=4,
optimizer=gradient_descent.GradientDescentOptimizer(learning_rate=0.0),
# The 'age' variable correspond to the 'age_in_years' variable in the
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 827b405e51..e6d82f0db7 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -144,7 +144,7 @@ class Estimator(object):
* `labels`: This is the second item returned from the `input_fn`
passed to `train`, `evaluate`, and `predict`. This should be a
single `tf.Tensor` or `dict` of same (for multi-head models).
- If mode is @{tf.estimator.ModeKeys.PREDICT}, `labels=None` will
+ If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
be passed. If the `model_fn`'s signature does not accept
`mode`, the `model_fn` must still be able to handle
`labels=None`.
@@ -468,6 +468,10 @@ class Estimator(object):
with ops.Graph().as_default():
if self._eval_distribution:
+ # We want to create the iterations variable outside the distribution
+ # scope as that is just stored on the host and mainly used to drive
+ # the loop and doesn't need to be a Mirrored/Device variable.
+ training.get_or_create_steps_per_run_variable()
with self._eval_distribution.scope():
return _evaluate()
else:
@@ -803,9 +807,9 @@ class Estimator(object):
those features and labels, and restores the given checkpoint
(or, lacking that, the most recent checkpoint) into the graph.
Only one of the modes is used for saving variables to the `SavedModel`
- (order of preference: @{tf.estimator.ModeKeys#TRAIN$TRAIN},
- @{tf.estimator.ModeKeys#EVAL$EVAL}, then
- @{tf.estimator.ModeKeys#PREDICT$PREDICT}), such that up to three
+ (order of preference: `tf.estimator.ModeKeys.TRAIN`,
+ `tf.estimator.ModeKeys.EVAL`, then
+ `tf.estimator.ModeKeys.PREDICT`), such that up to three
`tf.MetaGraphDefs` are saved with a single set of variables in a single
`SavedModel` directory.
@@ -1101,7 +1105,7 @@ class Estimator(object):
"""Creates the global step tensor in graph.
The global step tensor must be an integer type with name 'global_step' and
- be added to the collection @{tf.GraphKeys#GLOBAL_STEP$GLOBAL_STEP}.
+ be added to the collection `tf.GraphKeys.GLOBAL_STEP`.
Args:
graph: The graph in which to create the global step tensor.
@@ -1414,6 +1418,36 @@ class Estimator(object):
# It is expected to have one CheckpointSaverHook. If multiple, we pick
# up the first one to add listener.
saver_hooks[0]._listeners.extend(saving_listeners) # pylint: disable=protected-access
+
+ # Add summary hooks to worker 0 if we are running with a master, to ensure
+ # that summaries are written at correct intervals even with long-running
+ # evaluations.
+ save_summary_steps = self._config.save_summary_steps
+ log_step_count_steps = self._config.log_step_count_steps
+ if (self._config.cluster_spec and self._config.cluster_spec.jobs and
+ (run_config.TaskType.MASTER in self._config.cluster_spec.jobs)):
+ # Update config values to prevent the default hooks from being created on
+ # the master or other workers.
+ save_summary_steps = 0
+ log_step_count_steps = None
+
+ if (self._config.task_type == run_config.TaskType.WORKER and
+ self._config.task_id == 0):
+ if (self._config.save_summary_steps and
+ self._config.save_summary_steps > 0):
+ worker_hooks.append(
+ training.SummarySaverHook(
+ save_steps=self._config.save_summary_steps,
+ output_dir=self._config.model_dir,
+ scaffold=estimator_spec.scaffold))
+
+ if (self._config.log_step_count_steps and
+ self._config.log_step_count_steps > 0):
+ worker_hooks.append(
+ training.StepCounterHook(
+ every_n_steps=self._config.log_step_count_steps,
+ output_dir=self._config.model_dir))
+
with training.MonitoredTrainingSession(
master=self._config.master,
is_chief=self._config.is_chief,
@@ -1423,9 +1457,9 @@ class Estimator(object):
chief_only_hooks=(
tuple(chief_hooks) + tuple(estimator_spec.training_chief_hooks)),
save_checkpoint_secs=0, # Saving is handled by a hook.
- save_summaries_steps=self._config.save_summary_steps,
+ save_summaries_steps=save_summary_steps,
config=self._session_config,
- log_step_count_steps=self._config.log_step_count_steps) as mon_sess:
+ log_step_count_steps=log_step_count_steps) as mon_sess:
loss = None
while not mon_sess.should_stop():
_, loss = mon_sess.run([estimator_spec.train_op, estimator_spec.loss])
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index bc2504ca19..246dfb1a4b 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import functools
import glob
+import json
import os
import tempfile
@@ -969,6 +970,99 @@ class EstimatorTrainTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, 'train_and_evaluate'):
est.train(dummy_input_fn, steps=1)
+ def test_master_distributed_hooks(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.MASTER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_0(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 0
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertTrue(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
+ def test_master_distributed_hooks_for_worker_nonzero(self):
+ tf_config = json.dumps({
+ 'cluster': {
+ run_config.TaskType.PS: ['localhost:1234'],
+ run_config.TaskType.WORKER: ['localhost:1235', 'localhost:1237'],
+ run_config.TaskType.MASTER: ['localhost:1236']
+ },
+ 'task': {
+ 'type': run_config.TaskType.WORKER,
+ 'index': 1
+ }
+ })
+ with test.mock.patch.dict('os.environ', {'TF_CONFIG': tf_config}):
+ est = estimator.Estimator(
+ model_fn=model_fn_global_step_incrementer,
+ config=run_config.RunConfig())
+
+ with test.mock.patch.object(training,
+ 'MonitoredTrainingSession') as mock_sess:
+ est.train(dummy_input_fn, steps=1)
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.SummarySaverHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertFalse(
+ any(
+ isinstance(hook, basic_session_run_hooks.StepCounterHook)
+ for hook in mock_sess.call_args[1]['hooks']))
+ self.assertEqual(0, mock_sess.call_args[1]['save_summaries_steps'])
+ self.assertIsNone(mock_sess.call_args[1]['log_step_count_steps'])
+
def _model_fn_with_eval_metric_ops(features, labels, mode, params):
_, _ = features, labels
diff --git a/tensorflow/python/estimator/keras.py b/tensorflow/python/estimator/keras.py
index 7546771ed3..5d5ed81fbb 100644
--- a/tensorflow/python/estimator/keras.py
+++ b/tensorflow/python/estimator/keras.py
@@ -368,6 +368,44 @@ def _save_first_checkpoint(keras_model, custom_objects, config):
return latest_path
+def _get_file_from_google_storage(keras_model_path, model_dir):
+ """Get file from google storage and download to local file.
+
+ Args:
+ keras_model_path: a google storage path for compiled keras model.
+ model_dir: the directory from estimator config.
+
+ Returns:
+ The path where keras model is saved.
+
+ Raises:
+ ValueError: if storage object name does not end with .h5.
+ """
+ try:
+ from google.cloud import storage # pylint:disable=g-import-not-at-top
+ except ImportError:
+ raise TypeError('Could not save model to Google cloud storage; please '
+ 'install `google-cloud-storage` via '
+ '`pip install google-cloud-storage`.')
+ storage_client = storage.Client()
+ path, blob_name = os.path.split(keras_model_path)
+ _, bucket_name = os.path.split(path)
+ keras_model_dir = os.path.join(model_dir, 'keras')
+ if not gfile.Exists(keras_model_dir):
+ gfile.MakeDirs(keras_model_dir)
+ file_name = os.path.join(keras_model_dir, 'keras_model.h5')
+ try:
+ blob = storage_client.get_bucket(bucket_name).blob(blob_name)
+ blob.download_to_filename(file_name)
+ except:
+ raise ValueError('Failed to download keras model, please check '
+ 'environment variable GOOGLE_APPLICATION_CREDENTIALS '
+ 'and model path storage.googleapis.com/{bucket}/{object}.')
+ logging.info('Saving model to {}'.format(file_name))
+ del storage_client
+ return file_name
+
+
def model_to_estimator(keras_model=None,
keras_model_path=None,
custom_objects=None,
@@ -407,12 +445,13 @@ def model_to_estimator(keras_model=None,
'Please specity either `keras_model` or `keras_model_path`, '
'but not both.')
+ config = estimator_lib.maybe_overwrite_model_dir_and_session_config(
+ config, model_dir)
if not keras_model:
if keras_model_path.startswith(
'gs://') or 'storage.googleapis.com' in keras_model_path:
- raise ValueError(
- '%s is not a local path. Please copy the model locally first.' %
- keras_model_path)
+ keras_model_path = _get_file_from_google_storage(keras_model_path,
+ config.model_dir)
logging.info('Loading models from %s', keras_model_path)
keras_model = models.load_model(keras_model_path)
else:
@@ -425,9 +464,6 @@ def model_to_estimator(keras_model=None,
'Please compile the model with `model.compile()` '
'before calling `model_to_estimator()`.')
- config = estimator_lib.maybe_overwrite_model_dir_and_session_config(config,
- model_dir)
-
keras_model_fn = _create_keras_model_fn(keras_model, custom_objects)
if _any_weight_initialized(keras_model):
# Warn if config passed to estimator tries to update GPUOptions. If a
diff --git a/tensorflow/python/estimator/keras_test.py b/tensorflow/python/estimator/keras_test.py
index 288f9b8906..4e285fa25a 100644
--- a/tensorflow/python/estimator/keras_test.py
+++ b/tensorflow/python/estimator/keras_test.py
@@ -581,12 +581,6 @@ class TestKerasEstimator(test_util.TensorFlowTestCase):
with self.assertRaisesRegexp(ValueError, 'compiled'):
keras_lib.model_to_estimator(keras_model=keras_model)
- with self.cached_session():
- keras_model = simple_sequential_model()
- with self.assertRaisesRegexp(ValueError, 'not a local path'):
- keras_lib.model_to_estimator(
- keras_model_path='gs://bucket/object')
-
def test_invalid_ionames_error(self):
(x_train, y_train), (_, _) = testing_utils.get_test_data(
train_samples=_TRAIN_SIZE,
diff --git a/tensorflow/python/estimator/util.py b/tensorflow/python/estimator/util.py
index 31e4778e72..fb110c4b7b 100644
--- a/tensorflow/python/estimator/util.py
+++ b/tensorflow/python/estimator/util.py
@@ -22,7 +22,6 @@ from __future__ import print_function
import os
import time
-from tensorflow.core.protobuf import config_pb2
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import training
@@ -144,14 +143,11 @@ class StrategyInitFinalizeHook(training.SessionRunHook):
self._finalize_fn = finalize_fn
def begin(self):
+ # We only create the init ops, but don't run it. We rely on SessionManager
+ # to run it for us.
self._init_ops = self._initialization_fn()
self._finalize_ops = self._finalize_fn()
- def after_create_session(self, session, coord):
- logging.info('Initialize system')
- session.run(self._init_ops,
- options=config_pb2.RunOptions(timeout_in_ms=5 * 60 * 1000))
-
def end(self, session):
logging.info('Finalize system.')
session.run(self._finalize_ops)
diff --git a/tensorflow/python/feature_column/BUILD b/tensorflow/python/feature_column/BUILD
index 5800b693b4..ac53a84eef 100644
--- a/tensorflow/python/feature_column/BUILD
+++ b/tensorflow/python/feature_column/BUILD
@@ -156,7 +156,7 @@ py_test(
"//tensorflow/python:variables",
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
- "//tensorflow/python/estimator:estimator_py",
+ "//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
],
)
diff --git a/tensorflow/python/feature_column/feature_column_v2.py b/tensorflow/python/feature_column/feature_column_v2.py
index a8d5bfb437..b79373c475 100644
--- a/tensorflow/python/feature_column/feature_column_v2.py
+++ b/tensorflow/python/feature_column/feature_column_v2.py
@@ -271,6 +271,7 @@ class _StateManagerImpl(StateManager):
dtype=dtype,
initializer=initializer,
trainable=self._trainable and trainable,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -383,8 +384,8 @@ class FeatureLayer(Layer):
if isinstance(column, SharedEmbeddingColumn):
column.create_state(self._shared_state_manager)
else:
- with variable_scope.variable_scope(None, default_name=self.name):
- with variable_scope.variable_scope(None, default_name=column.name):
+ with variable_scope._pure_variable_scope(self.name): # pylint: disable=protected-access
+ with variable_scope._pure_variable_scope(column.name): # pylint: disable=protected-access
column.create_state(self._state_manager)
super(FeatureLayer, self).build(None)
@@ -414,19 +415,20 @@ class FeatureLayer(Layer):
output_tensors = []
ordered_columns = []
for column in sorted(self._feature_columns, key=lambda x: x.name):
- ordered_columns.append(column)
- if isinstance(column, SharedEmbeddingColumn):
- tensor = column.get_dense_tensor(transformation_cache,
- self._shared_state_manager)
- else:
- tensor = column.get_dense_tensor(transformation_cache,
- self._state_manager)
- num_elements = column.variable_shape.num_elements()
- batch_size = array_ops.shape(tensor)[0]
- tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
- output_tensors.append(tensor)
- if cols_to_output_tensors is not None:
- cols_to_output_tensors[column] = tensor
+ with ops.name_scope(column.name):
+ ordered_columns.append(column)
+ if isinstance(column, SharedEmbeddingColumn):
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._shared_state_manager)
+ else:
+ tensor = column.get_dense_tensor(transformation_cache,
+ self._state_manager)
+ num_elements = column.variable_shape.num_elements()
+ batch_size = array_ops.shape(tensor)[0]
+ tensor = array_ops.reshape(tensor, shape=(batch_size, num_elements))
+ output_tensors.append(tensor)
+ if cols_to_output_tensors is not None:
+ cols_to_output_tensors[column] = tensor
_verify_static_batch_size_equality(output_tensors, ordered_columns)
return array_ops.concat(output_tensors, 1)
@@ -601,6 +603,7 @@ class LinearModel(Layer):
shape=[self._units],
initializer=init_ops.zeros_initializer(),
trainable=self.trainable,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -627,36 +630,41 @@ class LinearModel(Layer):
if not isinstance(features, dict):
raise ValueError('We expected a dictionary here. Instead we got: ',
features)
- transformation_cache = FeatureTransformationCache(features)
- weighted_sums = []
- for column in self._feature_columns:
- with ops.name_scope(column.name):
- # All the weights used in the linear model are owned by the state
- # manager associated with this Linear Model.
- weight_var = self._state_manager.get_variable(column, 'weights')
-
- # The embedding weights for the SharedEmbeddingColumn are owned by
- # the shared_state_manager and so we need to pass that in while
- # creating the weighted sum. For all other columns, the state is owned
- # by the Linear Model's state manager.
- if isinstance(column, SharedEmbeddingColumn):
- state_manager = self._shared_state_manager
- else:
- state_manager = self._state_manager
- weighted_sum = _create_weighted_sum(
- column=column,
- transformation_cache=transformation_cache,
- state_manager=state_manager,
- sparse_combiner=self._sparse_combiner,
- weight_var=weight_var)
- weighted_sums.append(weighted_sum)
-
- _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
- predictions_no_bias = math_ops.add_n(
- weighted_sums, name='weighted_sum_no_bias')
- predictions = nn_ops.bias_add(
- predictions_no_bias, self._bias_variable, name='weighted_sum')
- return predictions
+ with ops.name_scope(self.name):
+ transformation_cache = FeatureTransformationCache(features)
+ weighted_sums = []
+ for column in self._feature_columns:
+ with ops.name_scope(column.name):
+ # All the weights used in the linear model are owned by the state
+ # manager associated with this Linear Model.
+ weight_var = self._state_manager.get_variable(column, 'weights')
+
+ # The embedding weights for the SharedEmbeddingColumn are owned by
+ # the shared_state_manager and so we need to pass that in while
+ # creating the weighted sum. For all other columns, the state is owned
+ # by the Linear Model's state manager.
+ if isinstance(column, SharedEmbeddingColumn):
+ state_manager = self._shared_state_manager
+ else:
+ state_manager = self._state_manager
+ weighted_sum = _create_weighted_sum(
+ column=column,
+ transformation_cache=transformation_cache,
+ state_manager=state_manager,
+ sparse_combiner=self._sparse_combiner,
+ weight_var=weight_var)
+ weighted_sums.append(weighted_sum)
+
+ _verify_static_batch_size_equality(weighted_sums, self._feature_columns)
+ predictions_no_bias = math_ops.add_n(
+ weighted_sums, name='weighted_sum_no_bias')
+ predictions = nn_ops.bias_add(
+ predictions_no_bias, self._bias_variable, name='weighted_sum')
+ return predictions
+
+ @property
+ def bias_variable(self):
+ return self._bias_variable
def _transform_features(features, feature_columns, state_manager):
@@ -2605,6 +2613,7 @@ class SharedEmbeddingStateManager(Layer):
dtype=dtype,
trainable=self.trainable and trainable,
initializer=initializer,
+ use_resource=True,
# TODO(rohanj): Get rid of this hack once we have a mechanism for
# specifying a default partitioner for an entire layer. In that case,
# the default getter for Layers should work.
@@ -3279,6 +3288,7 @@ def _safe_embedding_lookup_sparse(embedding_weights,
raise ValueError('Missing embedding_weights %s.' % embedding_weights)
dtype = sparse_weights.dtype if sparse_weights is not None else None
+ # TODO(rohanj): Look into removing this convert_to_tensor call.
embedding_weights = [
ops.convert_to_tensor(w, dtype=dtype) for w in embedding_weights
]
diff --git a/tensorflow/python/feature_column/feature_column_v2_test.py b/tensorflow/python/feature_column/feature_column_v2_test.py
index a13a5010e1..d3787146ed 100644
--- a/tensorflow/python/feature_column/feature_column_v2_test.py
+++ b/tensorflow/python/feature_column/feature_column_v2_test.py
@@ -5170,8 +5170,8 @@ class WeightedCategoricalColumnTest(test.TestCase):
key='ids', num_buckets=3),
weight_feature_key='values')
with ops.Graph().as_default():
- with self.assertRaisesRegexp(
- ValueError, r'Dimensions.*are not compatible'):
+ with self.assertRaisesRegexp(ValueError,
+ r'Dimensions.*are not compatible'):
model = fc.LinearModel((column,))
model({
'ids':
diff --git a/tensorflow/python/framework/device.py b/tensorflow/python/framework/device.py
index 06c653097a..7f6e0a75a5 100644
--- a/tensorflow/python/framework/device.py
+++ b/tensorflow/python/framework/device.py
@@ -87,6 +87,7 @@ class DeviceSpec(object):
else:
self.device_type = device_type
self.device_index = device_index
+ self._hash = hash(self.to_string())
def _clear(self):
self._job = None
@@ -234,7 +235,7 @@ class DeviceSpec(object):
return self.to_string() == other.to_string()
def __hash__(self):
- return hash(self.to_string())
+ return self._hash
def check_valid(spec):
@@ -266,6 +267,7 @@ def canonical_name(device):
# possible to compare the device function stacks belonging to different
# graphs in a meaningful way.
_cached_device_functions = {}
+_cached_device_specs = {}
_cache_lock = threading.Lock()
@@ -297,7 +299,13 @@ def merge_device(spec):
"""
with _cache_lock:
if not isinstance(spec, DeviceSpec):
- spec = DeviceSpec.from_string(spec or "")
+ cached_device_spec = _cached_device_specs.get(spec, None)
+ if cached_device_spec is None:
+ device_spec = DeviceSpec.from_string(spec or "")
+ _cached_device_specs[spec] = device_spec
+ spec = device_spec
+ else:
+ spec = cached_device_spec
cached_function = _cached_device_functions.get(spec, None)
if cached_function is not None:
return cached_function
diff --git a/tensorflow/python/framework/dtypes.py b/tensorflow/python/framework/dtypes.py
index c3f70df7d8..64d3b42d89 100644
--- a/tensorflow/python/framework/dtypes.py
+++ b/tensorflow/python/framework/dtypes.py
@@ -26,7 +26,7 @@ from tensorflow.python.util.tf_export import tf_export
_np_bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
-@tf_export("DType")
+@tf_export("dtypes.DType", "DType")
class DType(object):
"""Represents the type of the elements in a `Tensor`.
@@ -658,7 +658,7 @@ _PYTHON_TO_TF = {
}
-@tf_export("as_dtype")
+@tf_export("dtypes.as_dtype", "as_dtype")
def as_dtype(type_value):
"""Converts the given `type_value` to a `DType`.
diff --git a/tensorflow/python/framework/errors_impl.py b/tensorflow/python/framework/errors_impl.py
index 5af71f2cfb..8b303fa8a9 100644
--- a/tensorflow/python/framework/errors_impl.py
+++ b/tensorflow/python/framework/errors_impl.py
@@ -25,11 +25,13 @@ from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.framework import c_api_util
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
-@tf_export("OpError", "errors.OpError")
+@tf_export("errors.OpError", "OpError")
+@deprecation.deprecated_endpoints("OpError")
class OpError(Exception):
"""A generic error that is raised when TensorFlow execution fails.
@@ -72,7 +74,7 @@ class OpError(Exception):
or `Recv` op, there will be no corresponding
`tf.Operation`
object. In that case, this will return `None`, and you should
- instead use the `tf.OpError.node_def` to
+ instead use the `tf.errors.OpError.node_def` to
discover information about the op.
Returns:
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index f287289bd0..225208944e 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -134,7 +134,7 @@ class Defun(object):
# Func should not use kwargs and defaults.
argspec = tf_inspect.getargspec(func)
if argspec.keywords or argspec.defaults:
- raise ValueError("Functions with argument defaults or keyword "
+ raise ValueError("Functions with argument defaults or keywords "
"arguments are not supported.")
# Computes how many arguments 'func' has.
diff --git a/tensorflow/python/framework/graph_io.py b/tensorflow/python/framework/graph_io.py
index be30b16f5f..47e1344eae 100644
--- a/tensorflow/python/framework/graph_io.py
+++ b/tensorflow/python/framework/graph_io.py
@@ -27,7 +27,7 @@ from tensorflow.python.lib.io import file_io
from tensorflow.python.util.tf_export import tf_export
-@tf_export('train.write_graph')
+@tf_export('io.write_graph', 'train.write_graph')
def write_graph(graph_or_graph_def, logdir, name, as_text=True):
"""Writes a graph proto to a file.
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index e48e67c8a1..c6595918ae 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -329,7 +329,7 @@ def _SetDefaultAttrValues(node_def, op_def):
node_def.attr[key].CopyFrom(attr_def.default_value)
-@tf_export('import_graph_def')
+@tf_export('graph_util.import_graph_def', 'import_graph_def')
@deprecated_args(None, 'Please file an issue at '
'https://github.com/tensorflow/tensorflow/issues if you depend'
' on this feature.', 'op_dict')
diff --git a/tensorflow/python/framework/random_seed.py b/tensorflow/python/framework/random_seed.py
index 2f9504889a..6f9f347a99 100644
--- a/tensorflow/python/framework/random_seed.py
+++ b/tensorflow/python/framework/random_seed.py
@@ -22,6 +22,7 @@ from __future__ import print_function
from tensorflow.python.eager import context
from tensorflow.python.framework import ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -33,7 +34,8 @@ def _truncate_seed(seed):
return seed % _MAXINT32 # Truncate to fit into 32-bit integer
-@tf_export('get_seed')
+@tf_export('random.get_seed', 'get_seed')
+@deprecation.deprecated_endpoints('get_seed')
def get_seed(op_seed):
"""Returns the local seeds an operation should use given an op-specific seed.
@@ -80,7 +82,7 @@ def get_seed(op_seed):
return seeds
-@tf_export('set_random_seed')
+@tf_export('random.set_random_seed', 'set_random_seed')
def set_random_seed(seed):
"""Sets the graph-level random seed.
diff --git a/tensorflow/python/framework/sparse_tensor.py b/tensorflow/python/framework/sparse_tensor.py
index d1bdd9b80a..440e3a0968 100644
--- a/tensorflow/python/framework/sparse_tensor.py
+++ b/tensorflow/python/framework/sparse_tensor.py
@@ -33,7 +33,7 @@ _override_helper = ops._override_helper
# pylint: enable=protected-access
-@tf_export("SparseTensor")
+@tf_export("sparse.SparseTensor", "SparseTensor")
class SparseTensor(_TensorLike):
"""Represents a sparse tensor.
@@ -245,7 +245,7 @@ class SparseTensor(_TensorLike):
SparseTensorValue = collections.namedtuple(
"SparseTensorValue", ["indices", "values", "dense_shape"])
tf_export("SparseTensorValue")(SparseTensorValue)
-pywrap_tensorflow.RegisterSparseTensorValueClass(SparseTensorValue)
+pywrap_tensorflow.RegisterType("SparseTensorValue", SparseTensorValue)
@tf_export("convert_to_tensor_or_sparse_tensor")
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index cd0b03be43..6673bc5561 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -24,8 +24,8 @@ from collections import OrderedDict
import contextlib
import gc
import itertools
-import os
import math
+import os
import random
import re
import tempfile
@@ -402,11 +402,14 @@ def with_c_shapes(cls):
return cls
-def enable_cond_v2(fn):
- """Decorator for enabling CondV2 on a test.
+def enable_control_flow_v2(fn):
+ """Decorator for enabling CondV2 and WhileV2 on a test.
- Note this enables using CondV2 after running the test class's setup/teardown
- methods.
+ Note this enables using CondV2 and WhileV2 after running the test class's
+ setup/teardown methods.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
Args:
fn: the function to be wrapped
@@ -416,21 +419,56 @@ def enable_cond_v2(fn):
"""
def wrapper(*args, **kwargs):
- prev_value = control_flow_ops.ENABLE_COND_V2
+ enable_cond_v2_old = control_flow_ops.ENABLE_COND_V2
+ enable_while_v2_old = control_flow_ops.ENABLE_WHILE_V2
control_flow_ops.ENABLE_COND_V2 = True
+ control_flow_ops.ENABLE_WHILE_V2 = True
try:
fn(*args, **kwargs)
finally:
- control_flow_ops.ENABLE_COND_V2 = prev_value
+ control_flow_ops.ENABLE_COND_V2 = enable_cond_v2_old
+ control_flow_ops.ENABLE_WHILE_V2 = enable_while_v2_old
return wrapper
-def with_cond_v2(cls):
- """Adds methods that call original methods but with CondV2 enabled.
+def with_control_flow_v2(cls):
+ """Adds methods that call original methods with WhileV2 and CondV2 enabled.
- Note this enables CondV2 in new methods after running the test class's
- setup method.
+ Note this enables CondV2 and WhileV2 in new methods after running the test
+ class's setup method.
+
+ In addition to this, callers must import the while_v2 module in order to set
+ the _while_v2 module in control_flow_ops.
+
+ If a test function has _disable_control_flow_v2 attr set to True (using the
+ @disable_control_flow_v2 decorator), the v2 function is not generated for it.
+
+ Example:
+
+ @test_util.with_control_flow_v2
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ @test_util.disable_control_flow_v2("b/xyzabc")
+ def testDisabledForV2(self):
+ ...
+
+ Generated class:
+ class ControlFlowTest(test.TestCase):
+
+ def testEnabledForV2(self):
+ ...
+
+ def testEnabledForV2WithControlFlowV2(self):
+ // Enable V2 flags.
+ testEnabledForV2(self)
+ // Restore V2 flags.
+
+ def testDisabledForV2(self):
+ ...
Args:
cls: class to decorate
@@ -438,15 +476,33 @@ def with_cond_v2(cls):
Returns:
cls with new test methods added
"""
- if control_flow_ops.ENABLE_COND_V2:
+ if control_flow_ops.ENABLE_WHILE_V2 and control_flow_ops.ENABLE_COND_V2:
return cls
for name, value in cls.__dict__.copy().items():
- if callable(value) and name.startswith("test"):
- setattr(cls, name + "WithCondV2", enable_cond_v2(value))
+ if (callable(value) and name.startswith("test") and
+ not getattr(value, "_disable_control_flow_v2", False)):
+ setattr(cls, name + "WithControlFlowV2", enable_control_flow_v2(value))
return cls
+def disable_control_flow_v2(unused_msg):
+ """Decorator for a function in a with_control_flow_v2 enabled test class.
+
+ Blocks the function from being run with v2 control flow ops.
+
+ Args:
+ unused_msg: Reason for disabling.
+
+ Returns:
+ The wrapped function with _disable_control_flow_v2 attr set to True.
+ """
+ def wrapper(func):
+ func._disable_control_flow_v2 = True
+ return func
+ return wrapper
+
+
def assert_no_new_pyobjects_executing_eagerly(f):
"""Decorator for asserting that no new Python objects persist after a test.
diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py
index 4589c821e5..0d6877e4a1 100644
--- a/tensorflow/python/keras/backend.py
+++ b/tensorflow/python/keras/backend.py
@@ -1511,12 +1511,8 @@ def batch_dot(x, y, axes=None):
out = math_ops.reduce_sum(
math_ops.multiply(array_ops.transpose(x, [1, 0]), y), axes[1])
else:
- if axes is not None:
- adj_x = None if axes[0] == ndim(x) - 1 else True
- adj_y = True if axes[1] == ndim(y) - 1 else None
- else:
- adj_x = None
- adj_y = None
+ adj_x = None if axes[0] == ndim(x) - 1 else True
+ adj_y = True if axes[1] == ndim(y) - 1 else None
out = math_ops.matmul(x, y, adjoint_a=adj_x, adjoint_b=adj_y)
if diff:
if x_ndim > y_ndim:
@@ -3062,7 +3058,8 @@ def rnn(step_function,
mask=None,
constants=None,
unroll=False,
- input_length=None):
+ input_length=None,
+ time_major=False):
"""Iterates over the time dimension of a tensor.
Arguments:
@@ -3091,6 +3088,13 @@ def rnn(step_function,
constants: List of constant values passed at each step.
unroll: Whether to unroll the RNN or to use a symbolic `while_loop`.
input_length: If specified, assume time dimension is of this length.
+ time_major: Boolean. If true, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
+ default this function accepts input and emits output in batch-major
+ form.
Returns:
A tuple, `(last_output, outputs, new_states)`.
@@ -3112,15 +3116,17 @@ def rnn(step_function,
if ndim < 3:
raise ValueError('Input should be at least 3D.')
inputs_shape = inputs.shape
- axes = [1, 0] + list(range(2, ndim))
- inputs = array_ops.transpose(inputs, (axes))
+ if not time_major:
+ axes = [1, 0] + list(range(2, ndim))
+ inputs = array_ops.transpose(inputs, axes)
if mask is not None:
if mask.dtype != dtypes_module.bool:
mask = math_ops.cast(mask, dtypes_module.bool)
if len(mask.shape) == ndim - 1:
mask = expand_dims(mask)
- mask = array_ops.transpose(mask, axes)
+ if not time_major:
+ mask = array_ops.transpose(mask, axes)
if constants is None:
constants = []
@@ -3301,10 +3307,11 @@ def rnn(step_function,
outputs = output_ta.stack()
last_output = output_ta.read(last_time - 1)
- axes = [1, 0] + list(range(2, len(outputs.shape)))
- outputs = array_ops.transpose(outputs, axes)
+ if not time_major:
+ axes = [1, 0] + list(range(2, len(outputs.shape)))
+ outputs = array_ops.transpose(outputs, axes)
- # Static shape inference: (samples, time, ...)
+ # Static shape inference: (samples, time, ...) or (time, sample, ...)
outputs_shape = outputs.shape.as_list()
outputs_shape[0] = inputs_shape[0]
outputs_shape[1] = inputs_shape[1]
diff --git a/tensorflow/python/keras/engine/base_layer.py b/tensorflow/python/keras/engine/base_layer.py
index e98b131ae6..a75ce30d31 100644
--- a/tensorflow/python/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/engine/base_layer.py
@@ -20,6 +20,7 @@ from __future__ import print_function
import collections as collections_lib
import enum # pylint: disable=g-bad-import-order
+import functools
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
@@ -160,9 +161,13 @@ class Layer(checkpointable.CheckpointableBase):
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
- # When executing eagerly, _losses is a list of zero-argument lambdas which
- # return tensors. When using graph execution, _losses is a list of ops.
+ # A list of zero-argument lambdas which return Tensors, used for variable
+ # regularizers.
+ self._callable_losses = []
+ # A list of Tensors containing activity regularizers and losses manually
+ # added through `add_loss`. Empty when executing eagerly.
self._losses = []
+ self._in_call = False # Flag for error checking in add_loss
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
@@ -359,20 +364,20 @@ class Layer(checkpointable.CheckpointableBase):
def losses(self):
"""Losses which are associated with this `Layer`.
- Note that when executing eagerly, getting this property evaluates
- regularizers. When using graph execution, variable regularization ops have
- already been created and are simply returned here.
+ Variable regularization tensors are created when this property is accessed,
+ so it is eager safe: accessing `losses` under a `tf.GradientTape` will
+ propagate gradients back to the corresponding variables.
Returns:
A list of tensors.
"""
- if context.executing_eagerly():
- # _losses may only contain variable regularization losses when executing
- # eagerly, and they have been saved as lambdas to be executed when
- # requested.
- return [regularizer() for regularizer in self._losses]
- else:
- return self._losses
+ collected_losses = []
+ collected_losses.extend(self._losses)
+ for regularizer in self._callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ collected_losses.append(loss_tensor)
+ return collected_losses
@doc_controls.for_subclass_implementers
def add_loss(self, losses, inputs=None):
@@ -393,7 +398,9 @@ class Layer(checkpointable.CheckpointableBase):
from `Layer.call()`).
Arguments:
- losses: Loss tensor, or list/tuple of tensors.
+ losses: Loss tensor, or list/tuple of tensors. Rather than tensors, losses
+ may also be zero-argument callables which create a loss tensor. Only
+ callable losses are supported when executing eagerly.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
@@ -403,29 +410,45 @@ class Layer(checkpointable.CheckpointableBase):
(e.g. weight regularization losses).
Raises:
- RuntimeError: If called in Eager mode.
+ RuntimeError: If called in Eager mode with a `Tensor` rather than a
+ callable, or if `inputs` is not None.
"""
- if context.executing_eagerly():
- # TODO(fchollet): it should be possible (and highly desirable) to support
- # `add_loss` in eager mode. This allows great convenience and flexibility
- # in defining custom losses on the fly (e.g. in VAEs).
- # Simply appending the loss value to `self._losses`
- # is the correct behavior.
- # The only caveat is that we need to force the user to only call
- # `add_loss` from inside a model or Layer's `call` method
- # (otherwise the loss computation cannot be backproped through).
- raise RuntimeError('Layer.add_loss not supported in Eager mode.')
-
+ executing_eagerly = context.executing_eagerly()
+ if executing_eagerly:
+ if inputs is not None:
+ raise RuntimeError(
+ 'Activity regularization (via the "inputs" argument to '
+ 'Layer.add_loss) is not supported when executing eagerly. Consider '
+ 'returning activity regularization losses from a Model\'s call() '
+ 'method.')
+ if getattr(self, '_in_call', False):
+ # TODO(psv): Support activity regularization and a way to reset losses.
+ raise RuntimeError(
+ 'Adding losses inside a Layer\'s call() method is not currently '
+ 'supported when executing eagerly. Please file a feature request '
+ 'if you need this limitation lifted.')
losses = generic_utils.to_list(losses)
- losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
- if not tensor_util.is_tensor(loss) else loss for loss in losses]
- self._losses += losses
- if inputs is None:
- for loss in losses:
- loss._unconditional_loss = True # pylint: disable=protected-access
- else:
- for loss in losses:
- loss._unconditional_loss = False # pylint: disable=protected-access
+
+ def _tag_unconditional(loss):
+ if callable(loss):
+ loss = loss()
+ if loss is None:
+ return None # Will be filtered out when computing the .losses property
+ if not tensor_util.is_tensor(loss):
+ loss = ops.convert_to_tensor(loss, dtype=backend.floatx())
+ loss._unconditional_loss = (inputs is None) # pylint: disable=protected-access
+ return loss
+
+ for loss in losses:
+ if callable(loss):
+ self._callable_losses.append(
+ functools.partial(_tag_unconditional, loss))
+ else:
+ if executing_eagerly:
+ raise RuntimeError(
+ 'Layer.add_loss only supported for zero-argument lambdas when '
+ 'executing eagerly.')
+ self._losses.append(_tag_unconditional(loss))
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
@@ -599,56 +622,20 @@ class Layer(checkpointable.CheckpointableBase):
return variable
def _handle_weight_regularization(self, name, variable, regularizer):
- # `init_graph` should point to the graph in which variable initialization
- # will occur; it should be None if and only if initialization will take
- # place in the eager context.
- init_graph = None
- if not context.executing_eagerly():
- default_graph = ops.get_default_graph()
- if default_graph.building_function:
- with ops.init_scope():
- # Retrieve the variables from the graph into which variables
- # will be lifted; if initialization ops will be lifted into
- # the eager context, then there is nothing to retrieve, since variable
- # collections are not supported when eager execution is enabled.
- if not context.executing_eagerly():
- init_graph = ops.get_default_graph()
- else:
- # Initialization ops will not be lifted out of the default graph.
- init_graph = default_graph
-
- if init_graph is not None: # pylint: disable=protected-access
- # The variable was created and initialized in a graph.
- if regularizer:
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
- if regularization is not None:
- self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
- elif regularizer: # initialization took place in an eager context
- if isinstance(variable, tf_variables.PartitionedVariable):
- raise RuntimeError(
- 'Partitioned variable regularization is not yet '
- 'supported when executing eagerly. File a feature request'
- 'if this is important to you.')
- # Save a zero-argument lambda which runs the regularizer on the
- # variable, to be executed when `Layer.losses` is requested.
- # This makes losses responsive to variable updates when executing
- # eagerly.
- #
- # TODO(akshayka): Do the same for graphs as well, so that losses
- # collected in a while_loop can be run outside its control flow
- # context and so that losses won't be swallowed up by graph functions
- # (i.e., `.losses()` should always create regularizers).
- self._losses.append(lambda: regularizer(variable))
+ """Create lambdas which compute regularization losses."""
+
+ def _loss_for_variable(v):
+ """Creates a regularization loss `Tensor` for variable `v`."""
+ with ops.colocate_with(v):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ return regularization
+
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ self.add_loss(functools.partial(_loss_for_variable, v))
+ else:
+ self.add_loss(functools.partial(_loss_for_variable, variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
@@ -766,7 +753,9 @@ class Layer(checkpointable.CheckpointableBase):
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
+ self._in_call = True
outputs = self.call(inputs, *args, **kwargs)
+ self._in_call = False
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
diff --git a/tensorflow/python/keras/engine/distributed_training_utils.py b/tensorflow/python/keras/engine/distributed_training_utils.py
index 39341a931b..050602868a 100644
--- a/tensorflow/python/keras/engine/distributed_training_utils.py
+++ b/tensorflow/python/keras/engine/distributed_training_utils.py
@@ -17,12 +17,18 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import numpy as np
+
from tensorflow.python.client import session as session_module
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
from tensorflow.python.util import nest
@@ -304,23 +310,19 @@ def validate_inputs(x, y, distribution_strategy):
compiled.
Raises:
- ValueError: if input is not a Dataset or a numpy array.
+ ValueError: if input is not a Dataset or a numpy array(when we use
+ MirroredStrategy).
"""
- if isinstance(x, list) or isinstance(y, list):
- raise ValueError('DistributionStrategy does not support lists of numpy'
- 'arrays. You must pass a Dataset object or a numpy array '
- 'as input.')
-
if isinstance(x, dict) or isinstance(y, dict):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'dict. You must pass a Dataset object or a numpy array as '
- 'input.')
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'dict. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
- if isinstance(x, iterator_ops.Iterator) or \
- isinstance(y, iterator_ops.Iterator):
- raise ValueError('DistributionStrategy does not support inputs of type '
- 'Iterator. You must pass a Dataset object or a numpy '
- 'array as input.')
+ if (isinstance(x, iterator_ops.Iterator) or
+ isinstance(y, iterator_ops.Iterator)):
+ raise ValueError('`DistributionStrategy` does not support inputs of type '
+ 'Iterator. You must pass a `tf.data.Dataset` object or a '
+ 'numpy array as input.')
if distribution_strategy.__class__.__name__ == 'TPUStrategy':
for i in [x, y]:
@@ -334,14 +336,14 @@ def validate_inputs(x, y, distribution_strategy):
'Found unknown shape {} in input {}.'.format(s, i))
-def get_input_batch_params(first_x_value, batch_size, current_strategy):
+def get_input_batch_params(first_x_value, batch_size, distribution_strategy):
"""Calculate the number of batches and steps/steps_per_epoch.
Args:
first_x_value: This is the first input numpy array that is passed in as the
model input.
batch_size: The specified batch_size or the default batch_size of 32.
- current_strategy: The current DistributionStrategy used to compile the
+ distribution_strategy: The current DistributionStrategy used to compile the
model.
Returns:
@@ -359,14 +361,14 @@ def get_input_batch_params(first_x_value, batch_size, current_strategy):
# TODO(anjalisridhar): TPU currently supports using the num_towers property.
# We might want to look into implementing worker_devices. In multi worker
# strategy, perhaps num_towers works better?
- steps = num_batches // current_strategy.num_towers
+ steps = num_batches // distribution_strategy.num_towers
if not steps:
# TODO(anjalisridhar): Number of towers in the error message may not convey
# what we want to the user. Is there another terminology that we can use
# that is consistent across different strategies.
raise ValueError('The number of batches %d is smaller than the number '
'of towers %d used for DistributionStrategy. ' %
- num_batches, current_strategy.num_towers)
+ (num_batches, distribution_strategy.num_towers))
return steps
@@ -376,3 +378,99 @@ def get_batch_dimension(iterator):
# all.
dims = shapes[0].dims
return dims[0] if dims else None
+
+
+def get_cpu_device(distribution_strategy):
+ """Returns the CPU device of the TPU host or the default CPU device string.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+
+ Returns:
+ A device string which is the TPU host's CPU device in case of
+ TPUDistributionStrategy or the default CPU device string in all other
+ cases.
+
+ Raises:
+ NotImplementedError: We currently don't support copying numpy data to
+ multiple hosts in the case of Cloud TPU pods.
+ """
+ if distribution_strategy.__class__.__name__ == 'TPUStrategy':
+ if distribution_strategy.num_hosts > 1:
+ raise NotImplementedError('TPUDistributionStrategy does not '
+ 'support numpy inputs when running on Cloud'
+ 'TPU pods.')
+ return distribution_strategy.get_host_cpu_device(0)
+ else:
+ # For all strategies except TPUDistributionStrategy
+ # TODO(anjalisridhar): We may need to modify this when we add support for
+ # multi-worker strategy.
+ return '/CPU:0'
+
+
+def get_var_for_numpy(distribution_strategy, x):
+ if isinstance(x, list):
+ var_x = tuple([_get_var_for_numpy(distribution_strategy, single_input)
+ for single_input in x])
+ else:
+ var_x = _get_var_for_numpy(distribution_strategy, x)
+ return var_x
+
+
+def _get_var_for_numpy(distribution_strategy, input_array):
+ """Creates a variable and assigns the value of the numpy array to it.
+
+ Args:
+ distribution_strategy: The DistributionStrategy used to compile the model.
+ input_array: The input numpy array whose value will be assigned to the
+ variable we create.
+
+ Returns:
+ The variable to which we will copy the value of the input numpy array.
+
+ """
+ with ops.device(get_cpu_device(distribution_strategy)):
+ # Create and initialize a variable on the CPU device. This is the CPU
+ # device of the host in the case of TPUDistributionStrategy.
+ input_var = variables.VariableV1(array_ops.zeros(input_array.shape,
+ input_array.dtype),
+ trainable=False, use_resource=True)
+ K.get_session().run(input_var.initializer)
+
+ # Create a placeholder for the numpy array input slices. We copy the value
+ # of the input numpy array to the variable in slices of size 64 MB to avoid
+ # running into memory issues or RPC message limits.
+ start_placeholder = array_ops.placeholder(dtypes.int64, ())
+ end_placeholder = array_ops.placeholder(dtypes.int64, ())
+ slice_placeholder = array_ops.placeholder(input_var.dtype)
+ assign_slice_op = input_var[start_placeholder:end_placeholder].assign(
+ slice_placeholder)
+
+ # If each batch element is > 64 MB, then we copy each batch element
+ # individually. Otherwise, the slices will be < 128 MB. There might be padding
+ # which might mean that the slices are 128 MB even if the size of the
+ # tensor allocated is less than 128 MB.
+ # This formula gives slices with size:
+ # ceil(64 MB / byte size per batch element) bytes.
+ # Using ceil() guarantees we get a number >= 1.
+
+ # Calculate the size of each batch element.
+ byte_size_per_batch_element = np.prod(input_array.shape[1:]) * \
+ input_var.dtype.size
+
+ # Calculate number of elements we want to copy per slice.
+ batch_size_per_slice = np.ceil((64 << 20) / byte_size_per_batch_element)
+
+ # Copy slices of the above size starting at 0, except the last slice will be
+ # smaller.
+ start = 0
+ limit = input_array.shape[0]
+ while start < limit:
+ end = min(start + batch_size_per_slice, limit)
+ K.get_session().run(assign_slice_op, feed_dict={
+ start_placeholder: start,
+ end_placeholder: end,
+ slice_placeholder: input_array[start:end]})
+ start = end
+
+ return input_var
diff --git a/tensorflow/python/keras/engine/training.py b/tensorflow/python/keras/engine/training.py
index 5091cac836..c842b8192e 100644
--- a/tensorflow/python/keras/engine/training.py
+++ b/tensorflow/python/keras/engine/training.py
@@ -20,11 +20,9 @@ from __future__ import print_function
import weakref
import numpy as np
-import six
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.ops.dataset_ops import Dataset
from tensorflow.python.eager import context
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
@@ -814,19 +812,21 @@ class Model(Network):
first_x_value = nest.flatten(x)[0]
if isinstance(first_x_value, np.ndarray):
x_shape = first_x_value.shape
- x_dtype = first_x_value.dtype
if batch_size is None:
batch_size = x_shape[0] // steps
if y is not None:
- first_y_value = nest.flatten(y)[0]
- x = Dataset.from_generator(lambda x=x, y=y: six.moves.zip(x, y),
- output_types=(x_dtype, first_y_value.dtype),
- output_shapes=(x_shape[1:],
- first_y_value.shape[1:]))
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ var_y = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, y)
+
+ x = dataset_ops.Dataset.from_tensor_slices((var_x, var_y))
# TODO(anjalisridhar): What should the buffer size be?
x = x.shuffle(10000)
x = x.repeat()
- x = x.batch(batch_size)
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ x = x.batch(batch_size, drop_remainder=True)
y = None
else:
# This case is for the predict call where the dataset only contains
@@ -834,11 +834,13 @@ class Model(Network):
# TODO(anjalisridhar): Raise an error if we are not able to process
# all the predict samples. This can happen if the number of batches is
# not evenly divisible by the number of worker devices.
- x = Dataset.from_generator(lambda x=x: x,
- output_types=x_dtype,
- output_shapes=x_shape[1:])
+ var_x = distributed_training_utils.get_var_for_numpy(
+ self._distribution_strategy, x)
+ x = dataset_ops.Dataset.from_tensor_slices(var_x)
x = x.repeat()
- x = x.batch(batch_size)
+ # We need to use the drop_remainder argument to allow for a static
+ # input shape which is required for TPUs.
+ x = x.batch(batch_size, drop_remainder=True)
# TODO(anjalisridhar): Can we use the iterator and getnext op cache?
# We require users to pass Datasets since we distribute the dataset across
@@ -978,16 +980,18 @@ class Model(Network):
'Make sure that your dataset can generate '
'required number of samples.')
- if (not isinstance(next_element, (list, tuple)) or
- len(next_element) not in [2, 3]):
- raise ValueError(
- 'Please provide model inputs as a list or tuple of 2 or 3'
- 'elements: (input, target) or (input, target, sample_weights)'
- 'Received %s' % next_element)
- if len(next_element) == 2:
- x, y = next_element
+ if isinstance(next_element, (list, tuple)):
+ if len(next_element) not in [2, 3]:
+ raise ValueError(
+ 'Please provide model inputs as a list or tuple of 2 or 3'
+ 'elements: (input, target) or (input, target, sample_weights)'
+ 'Received %s' % next_element)
+ if len(next_element) == 2:
+ x, y = next_element
+ else:
+ x, y, sample_weight = next_element
else:
- x, y, sample_weight = next_element
+ x = next_element
x, y, sample_weights = self._standardize_weights(x, y, sample_weight,
class_weight, batch_size)
return x, y, sample_weights
diff --git a/tensorflow/python/keras/engine/training_distributed.py b/tensorflow/python/keras/engine/training_distributed.py
index a6470458d2..04e8d079c0 100644
--- a/tensorflow/python/keras/engine/training_distributed.py
+++ b/tensorflow/python/keras/engine/training_distributed.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import distribute as distribute_lib
+from tensorflow.python.util import nest
# TODO(priyag, sourabhbajaj): Refactor this file to address code duplication.
@@ -296,15 +297,16 @@ def _experimental_fit_loop(
initial_loop_values[name] = array_ops.zeros(tensor.shape, tensor.dtype)
if steps_per_epoch is None:
- raise ValueError('steps_per_epoch should be specified in the fit call.')
- steps_per_run_var = K.variable(
+ raise ValueError('`steps_per_epoch` should be specified when calling '
+ '`fit` on the model.')
+ steps_per_run = K.variable(
value=min(steps_per_epoch, current_strategy.steps_per_run),
dtype='int32',
- name='steps_per_run_var')
+ name='steps_per_run')
with current_strategy.scope():
ctx = current_strategy.run_steps_on_dataset(
- step_fn, iterator, iterations=steps_per_run_var,
+ step_fn, iterator, iterations=steps_per_run,
initial_loop_values=initial_loop_values)
train_op = ctx.run_op
@@ -344,7 +346,7 @@ def _experimental_fit_loop(
batch_logs = {'batch': step_index, 'size': 1, 'num_steps': step_count}
callbacks.on_batch_begin(step_index, batch_logs)
if prev_step_count is None or step_count != prev_step_count:
- steps_per_run_var.load(step_count, K.get_session())
+ steps_per_run.load(step_count, K.get_session())
prev_step_count = step_count
try:
_, outputs = K.get_session().run([train_op, output_tensors])
@@ -720,13 +722,9 @@ def _experimental_predict_loop(model, iterator, verbose=0, steps=None):
model.predict_function.updates_op,
model.predict_function.session_kwargs)
- def step_fn(ctx, inputs, targets):
+ def step_fn(ctx, *inputs):
"""Clones the model and calls make_predict_function."""
- # TODO(anjalisridhar): Support predict input correctly as it will not
- # contain targets, only inputs.
- del targets
-
# TODO(priyag, sourabhbajaj): The model gets cloned every time
# fit/test/predict is called. We should look into caching this keyed on
# input shapes.
@@ -824,9 +822,10 @@ def _clone_and_build_model(model, inputs=None, targets=None):
# TODO(priyag): Is there a cleaner way to do this? The API doc suggests a
# single tensor should be OK but it throws an error in that case.
- if (targets is not None and not isinstance(targets, list) and
- not isinstance(targets, dict)):
+ if targets is not None and not isinstance(targets, (list, dict, tuple)):
targets = [targets]
+ if isinstance(targets, tuple):
+ targets = nest.flatten(targets)
cloned_model.compile(
optimizer,
model.loss,
@@ -891,11 +890,12 @@ def _get_input_from_iterator(iterator, model):
"""Get elements from the iterator and verify the input shape and type."""
next_element = iterator.get_next()
- if isinstance(next_element, tuple):
- x, y = next_element
- else:
+ if len(nest.flatten(next_element)) == len(model.inputs):
x = next_element
y = None
+ else:
+ x, y = next_element
+
# Validate that all the elements in x and y are of the same type and shape.
# We can then pass the first element of x and y to `_standardize_weights`
# below and be confident of the output.
diff --git a/tensorflow/python/keras/engine/training_eager_test.py b/tensorflow/python/keras/engine/training_eager_test.py
index db7ccb181f..1f5176c4d7 100644
--- a/tensorflow/python/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/engine/training_eager_test.py
@@ -192,6 +192,20 @@ class CorrectnessTest(test.TestCase):
history = model.fit(iterator, epochs=1, steps_per_epoch=10)
self.assertEqual(np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ def test_no_loss_in_call(self):
+
+ class HasLoss(keras.layers.Layer):
+
+ def call(self, x):
+ self.add_loss(x)
+ return x
+
+ layer = HasLoss()
+ with self.assertRaises(RuntimeError):
+ layer(1.)
+
+ with ops.Graph().as_default():
+ layer(1.)
if __name__ == '__main__':
ops.enable_eager_execution()
diff --git a/tensorflow/python/keras/engine/training_test.py b/tensorflow/python/keras/engine/training_test.py
index 30be4131a4..54ad74c08b 100644
--- a/tensorflow/python/keras/engine/training_test.py
+++ b/tensorflow/python/keras/engine/training_test.py
@@ -27,6 +27,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
+from tensorflow.python.eager import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import test_util as tf_test_util
@@ -2427,6 +2428,17 @@ class TestTrainingWithMetrics(test.TestCase):
scores = model.train_on_batch(x, y, sample_weight=w)
self.assertArrayNear(scores, [0.2, 0.8, 0.8], 0.1)
+ def test_losses_in_defun(self):
+ with context.eager_mode():
+ layer = keras.layers.Dense(1, kernel_regularizer='l1')
+ layer(array_ops.ones([1, 10]))
+
+ @function.defun
+ def get_losses():
+ return layer.losses
+
+ self.assertAllEqual(self.evaluate(layer.losses),
+ self.evaluate(get_losses()))
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent.py b/tensorflow/python/keras/layers/cudnn_recurrent.py
index cf2b0c476c..29a09a3d71 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent.py
@@ -47,6 +47,9 @@ class _CuDNNRNN(RNN):
stateful: Boolean (default False). If True, the last state
for each sample at index i in a batch will be used as initial
state for the sample of index i in the following batch.
+ time_major: Boolean (default False). If true, the inputs and outputs will be
+ in shape `(timesteps, batch, ...)`, whereas in the False case, it will
+ be `(batch, timesteps, ...)`.
"""
def __init__(self,
@@ -54,6 +57,7 @@ class _CuDNNRNN(RNN):
return_state=False,
go_backwards=False,
stateful=False,
+ time_major=False,
**kwargs):
# We invoke the base layer's initializer directly here because we do not
# want to create RNN cell instance.
@@ -62,6 +66,7 @@ class _CuDNNRNN(RNN):
self.return_state = return_state
self.go_backwards = go_backwards
self.stateful = stateful
+ self.time_major = time_major
self.supports_masking = False
self.input_spec = [InputSpec(ndim=3)]
if hasattr(self.cell.state_size, '__len__'):
@@ -124,7 +129,8 @@ class _CuDNNRNN(RNN):
'return_sequences': self.return_sequences,
'return_state': self.return_state,
'go_backwards': self.go_backwards,
- 'stateful': self.stateful
+ 'stateful': self.stateful,
+ 'time_major': self.time_major,
}
base_config = super( # pylint: disable=bad-super-call
RNN, self).get_config()
@@ -267,7 +273,8 @@ class CuDNNGRU(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -301,7 +308,10 @@ class CuDNNGRU(_CuDNNRNN):
if self.stateful or self.return_state:
h = h[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h]
@@ -456,7 +466,8 @@ class CuDNNLSTM(_CuDNNRNN):
self.built = True
def _process_batch(self, inputs, initial_state):
- inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
+ if not self.time_major:
+ inputs = array_ops.transpose(inputs, perm=(1, 0, 2))
input_h = initial_state[0]
input_c = initial_state[1]
input_h = array_ops.expand_dims(input_h, axis=0)
@@ -496,7 +507,10 @@ class CuDNNLSTM(_CuDNNRNN):
h = h[0]
c = c[0]
if self.return_sequences:
- output = array_ops.transpose(outputs, perm=(1, 0, 2))
+ if self.time_major:
+ output = outputs
+ else:
+ output = array_ops.transpose(outputs, perm=(1, 0, 2))
else:
output = outputs[-1]
return output, [h, c]
diff --git a/tensorflow/python/keras/layers/cudnn_recurrent_test.py b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
index 2ed0aa8f26..7becbfede1 100644
--- a/tensorflow/python/keras/layers/cudnn_recurrent_test.py
+++ b/tensorflow/python/keras/layers/cudnn_recurrent_test.py
@@ -26,6 +26,7 @@ import numpy as np
from tensorflow.python import keras
from tensorflow.python.framework import test_util
from tensorflow.python.keras import testing_utils
+from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
from tensorflow.python.training.rmsprop import RMSPropOptimizer
@@ -142,6 +143,32 @@ class CuDNNTest(test.TestCase, parameterized.TestCase):
('cudnngru', keras.layers.CuDNNGRU),
('cudnnlstm', keras.layers.CuDNNLSTM),
)
+ def test_time_major_input(self, layer_class):
+ if test.is_gpu_available(cuda_only=True):
+ with self.test_session(use_gpu=True):
+ input_size = 10
+ timesteps = 6
+ units = 2
+ num_samples = 32
+
+ model = keras.models.Sequential()
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ layer = layer_class(units, time_major=True, return_sequences=True)
+ model.add(layer)
+ model.add(
+ keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2])))
+ model.compile(loss='categorical_crossentropy', optimizer='adam')
+ model.fit(
+ np.ones((num_samples, timesteps, input_size)),
+ np.ones((num_samples, timesteps, units)))
+ out = model.predict(np.ones((num_samples, timesteps, input_size)))
+ self.assertEqual(out.shape, (num_samples, timesteps, units))
+
+ @parameterized.named_parameters(
+ ('cudnngru', keras.layers.CuDNNGRU),
+ ('cudnnlstm', keras.layers.CuDNNLSTM),
+ )
def test_specify_initial_state_keras_tensor(self, layer_class):
if test.is_gpu_available(cuda_only=True):
with self.test_session(use_gpu=True):
diff --git a/tensorflow/python/keras/layers/embeddings.py b/tensorflow/python/keras/layers/embeddings.py
index c6df5f2e26..824a0b069e 100644
--- a/tensorflow/python/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/layers/embeddings.py
@@ -159,13 +159,15 @@ class Embedding(Layer):
else:
in_lens = [self.input_length]
if len(in_lens) != len(input_shape) - 1:
- ValueError('"input_length" is %s, but received input has shape %s' %
- (str(self.input_length), str(input_shape)))
+ raise ValueError('"input_length" is %s, '
+ 'but received input has shape %s' % (str(
+ self.input_length), str(input_shape)))
else:
for i, (s1, s2) in enumerate(zip(in_lens, input_shape[1:])):
if s1 is not None and s2 is not None and s1 != s2:
- ValueError('"input_length" is %s, but received input has shape %s' %
- (str(self.input_length), str(input_shape)))
+ raise ValueError('"input_length" is %s, '
+ 'but received input has shape %s' % (str(
+ self.input_length), str(input_shape)))
elif s1 is None:
in_lens[i] = s2
return (input_shape[0],) + tuple(in_lens) + (self.output_dim,)
diff --git a/tensorflow/python/keras/layers/recurrent.py b/tensorflow/python/keras/layers/recurrent.py
index ba7498e7e6..b07ec71178 100644
--- a/tensorflow/python/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/layers/recurrent.py
@@ -336,9 +336,18 @@ class RNN(Layer):
in your model, you would need to specify the input length
at the level of the first layer
(e.g. via the `input_shape` argument)
+ time_major: The shape format of the `inputs` and `outputs` tensors.
+ If True, the inputs and outputs will be in shape
+ `(timesteps, batch, ...)`, whereas in the False case, it will be
+ `(batch, timesteps, ...)`. Using `time_major = True` is a bit more
+ efficient because it avoids transposes at the beginning and end of the
+ RNN calculation. However, most TensorFlow data is batch-major, so by
+ default this function accepts input and emits output in batch-major
+ form.
Input shape:
- N-D tensor with shape `(batch_size, timesteps, ...)`.
+ N-D tensor with shape `(batch_size, timesteps, ...)` or
+ `(timesteps, batch_size, ...)` when time_major is True.
Output shape:
- if `return_state`: a list of tensors. The first tensor is
@@ -347,7 +356,8 @@ class RNN(Layer):
be a high dimension tensor shape.
- if `return_sequences`: N-D tensor with shape
`(batch_size, timesteps, output_size)`, where `output_size` could
- be a high dimension tensor shape.
+ be a high dimension tensor shape, or
+ `(timesteps, batch_size, output_size)` when `time_major` is True.
- else, N-D tensor with shape `(batch_size, output_size)`, where
`output_size` could be a high dimension tensor shape.
@@ -448,6 +458,7 @@ class RNN(Layer):
go_backwards=False,
stateful=False,
unroll=False,
+ time_major=False,
**kwargs):
if isinstance(cell, (list, tuple)):
cell = StackedRNNCells(cell)
@@ -468,6 +479,7 @@ class RNN(Layer):
self.go_backwards = go_backwards
self.stateful = stateful
self.unroll = unroll
+ self.time_major = time_major
self.supports_masking = True
self.input_spec = [None] # The input shape is unknown yet, at least rank 3.
@@ -503,14 +515,21 @@ class RNN(Layer):
# Note that state_size[0] could be a tensor_shape or int.
output_dim = tensor_shape.as_shape(state_size[0]).as_list()
+ batch = input_shape[0]
+ time_step = input_shape[1]
+ if self.time_major:
+ batch, time_step = time_step, batch
if self.return_sequences:
- output_shape = tuple([input_shape[0], input_shape[1]] + output_dim)
+ if self.time_major:
+ output_shape = tuple([time_step, batch] + output_dim)
+ else:
+ output_shape = tuple([batch, time_step] + output_dim)
else:
- output_shape = tuple([input_shape[0]] + output_dim)
+ output_shape = tuple([batch] + output_dim)
if self.return_state:
state_shape = [
- tuple([input_shape[0]] + tensor_shape.as_shape(dim).as_list())
+ tuple([batch] + tensor_shape.as_shape(dim).as_list())
for dim in state_size
]
return [output_shape] + state_shape
@@ -539,13 +558,18 @@ class RNN(Layer):
if isinstance(input_shape, list):
input_shape = input_shape[0]
- batch_size = input_shape[0] if self.stateful else None
- input_dim = input_shape[2:]
- self.input_spec[0] = InputSpec(shape=(batch_size, None) + input_dim)
+ input_spec_shape = list(input_shape)
+ batch_index, time_step_index = (1, 0) if self.time_major else (0, 1)
+ if not self.stateful:
+ input_spec_shape[batch_index] = None
+ input_spec_shape[time_step_index] = None
+ self.input_spec[0] = InputSpec(shape=tuple(input_spec_shape))
+ batch = input_shape[batch_index]
+ input_dim = input_shape[2:]
+ step_input_shape = (batch,) + input_dim
# allow cell (if layer) to build before we set or validate state_spec
if isinstance(self.cell, Layer):
- step_input_shape = (input_shape[0],) + input_dim
if constants_shape is not None:
self.cell.build([step_input_shape] + constants_shape)
else:
@@ -598,12 +622,16 @@ class RNN(Layer):
def get_initial_state(self, inputs):
get_initial_state_fn = getattr(self.cell, 'get_initial_state', None)
+
+ input_shape = array_ops.shape(inputs)
+ batch_size = input_shape[1] if self.time_major else input_shape[0]
+ dtype = inputs.dtype
if get_initial_state_fn:
init_state = get_initial_state_fn(
- inputs=inputs, batch_size=None, dtype=None)
+ inputs=None, batch_size=batch_size, dtype=dtype)
else:
- init_state = _generate_zero_filled_state(
- array_ops.shape(inputs)[0], self.cell.state_size, inputs.dtype)
+ init_state = _generate_zero_filled_state(batch_size, self.cell.state_size,
+ dtype)
# Keras RNN expect the states in a list, even if it's a single state tensor.
if not nest.is_sequence(init_state):
init_state = [init_state]
@@ -696,7 +724,7 @@ class RNN(Layer):
'Layer has ' + str(len(self.states)) + ' states but was passed ' +
str(len(initial_state)) + ' initial states.')
input_shape = K.int_shape(inputs)
- timesteps = input_shape[1]
+ timesteps = input_shape[0] if self.time_major else input_shape[1]
if self.unroll and timesteps in [None, 1]:
raise ValueError('Cannot unroll a RNN if the '
'time dimension is undefined or equal to 1. \n'
@@ -747,7 +775,8 @@ class RNN(Layer):
go_backwards=self.go_backwards,
mask=mask,
unroll=self.unroll,
- input_length=timesteps)
+ input_length=timesteps,
+ time_major=self.time_major)
if self.stateful:
updates = []
for i in range(len(states)):
@@ -777,7 +806,10 @@ class RNN(Layer):
def reset_states(self, states=None):
if not self.stateful:
raise AttributeError('Layer must be stateful.')
- batch_size = self.input_spec[0].shape[0]
+ if self.time_major:
+ batch_size = self.input_spec[0].shape[1]
+ else:
+ batch_size = self.input_spec[0].shape[0]
if not batch_size:
raise ValueError('If a RNN is stateful, it needs to know '
'its batch size. Specify the batch size '
@@ -839,7 +871,8 @@ class RNN(Layer):
'return_state': self.return_state,
'go_backwards': self.go_backwards,
'stateful': self.stateful,
- 'unroll': self.unroll
+ 'unroll': self.unroll,
+ 'time_major': self.time_major
}
if self._num_constants is not None:
config['num_constants'] = self._num_constants
diff --git a/tensorflow/python/keras/layers/recurrent_test.py b/tensorflow/python/keras/layers/recurrent_test.py
index b9e90095e4..d246be6b45 100644
--- a/tensorflow/python/keras/layers/recurrent_test.py
+++ b/tensorflow/python/keras/layers/recurrent_test.py
@@ -186,6 +186,96 @@ class RNNTest(test.TestCase):
y_np_2 = model.predict(x_np)
self.assertAllClose(y_np, y_np_2, atol=1e-4)
+ def test_rnn_with_time_major(self):
+ batch = 10
+ time_step = 5
+ embedding_dim = 4
+ units = 3
+
+ with self.cached_session():
+ # Test basic case.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ layer = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ self.assertEqual(
+ layer.compute_output_shape((time_step, None,
+ embedding_dim)).as_list(),
+ [time_step, None, units])
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, units))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test stacking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major_x = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ cell_units = [10, 8, 6]
+ cells = [keras.layers.SimpleRNNCell(cell_units[i]) for i in range(3)]
+ layer = keras.layers.RNN(cells, time_major=True, return_sequences=True)
+ y = layer(time_major_x)
+ self.assertEqual(layer.output_shape, (time_step, None, cell_units[-1]))
+
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(y)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, cell_units[-1])))
+
+ with self.cached_session():
+ # Test masking.
+ x = keras.Input((time_step, embedding_dim))
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ mask = keras.layers.Masking()(time_major)
+ rnn = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)(mask)
+ y = keras.layers.Lambda(lambda t: array_ops.transpose(t, [1, 0, 2]))(rnn)
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ with self.cached_session():
+ # Test layer output
+ x = keras.Input((time_step, embedding_dim))
+ rnn_1 = keras.layers.SimpleRNN(units, return_sequences=True)
+ y = rnn_1(x)
+
+ model = keras.models.Model(x, y)
+ model.compile(optimizer='rmsprop', loss='mse')
+ model.train_on_batch(
+ np.zeros((batch, time_step, embedding_dim)),
+ np.zeros((batch, time_step, units)))
+
+ x_np = np.random.random((batch, time_step, embedding_dim))
+ y_np_1 = model.predict(x_np)
+
+ time_major = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(x)
+ rnn_2 = keras.layers.SimpleRNN(
+ units, time_major=True, return_sequences=True)
+ y_2 = rnn_2(time_major)
+ y_2 = keras.layers.Lambda(
+ lambda t: array_ops.transpose(t, [1, 0, 2]))(y_2)
+
+ model_2 = keras.models.Model(x, y_2)
+ rnn_2.set_weights(rnn_1.get_weights())
+
+ y_np_2 = model_2.predict(x_np)
+ self.assertAllClose(y_np_1, y_np_2, atol=1e-4)
+
def test_rnn_cell_with_constants_layer(self):
class RNNCellWithConstants(keras.layers.Layer):
diff --git a/tensorflow/python/keras/models.py b/tensorflow/python/keras/models.py
index b04b4df257..2883c9ad74 100644
--- a/tensorflow/python/keras/models.py
+++ b/tensorflow/python/keras/models.py
@@ -96,6 +96,8 @@ def _clone_functional_model(model, input_tensors=None):
else:
# Make sure that all input tensors come from a Keras layer.
# If tensor comes from an input layer: cache the input layer.
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
input_tensors = generic_utils.to_list(input_tensors)
input_tensors_ = []
for i, x in enumerate(input_tensors):
@@ -212,6 +214,9 @@ def _clone_sequential_model(model, input_tensors=None):
raise ValueError('To clone a `Sequential` model, we expect '
' at most one tensor '
'as part of `input_tensors`.')
+
+ if isinstance(input_tensors, tuple):
+ input_tensors = list(input_tensors)
x = generic_utils.to_list(input_tensors)[0]
if K.is_keras_tensor(x):
origin_layer = x._keras_history[0]
diff --git a/tensorflow/python/keras/preprocessing/image_test.py b/tensorflow/python/keras/preprocessing/image_test.py
index 362cbc1dc9..4abaadfcd3 100644
--- a/tensorflow/python/keras/preprocessing/image_test.py
+++ b/tensorflow/python/keras/preprocessing/image_test.py
@@ -94,43 +94,6 @@ class TestImage(test.TestCase):
self.assertEqual(x.shape[1:], images.shape[1:])
break
- def test_image_data_generator_with_validation_split(self):
- if PIL is None:
- return # Skip test if PIL is not available.
-
- for test_images in _generate_test_images():
- img_list = []
- for im in test_images:
- img_list.append(keras.preprocessing.image.img_to_array(im)[None, ...])
-
- images = np.vstack(img_list)
- generator = keras.preprocessing.image.ImageDataGenerator(
- validation_split=0.5)
- seq = generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='validation')
- _, y = seq[0]
- self.assertEqual(list(y), [0, 1, 2])
- seq = generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='training')
- _, y2 = seq[0]
- self.assertEqual(list(y2), [4, 5, 6])
-
- with self.assertRaises(ValueError):
- generator.flow(
- images,
- np.arange(images.shape[0]),
- shuffle=False,
- batch_size=3,
- subset='foo')
-
def test_image_data_generator_with_split_value_error(self):
with self.assertRaises(ValueError):
keras.preprocessing.image.ImageDataGenerator(validation_split=5)
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 280c18ec00..9490746fd9 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1480,7 +1480,7 @@ cuda_py_test(
name = "control_flow_ops_py_test",
# TODO(b/70473603): change this back to "small" once the C API is
# permanently enabled
- size = "medium",
+ size = "large",
srcs = ["control_flow_ops_py_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -1512,6 +1512,7 @@ cuda_py_test(
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python:while_v2",
],
)
@@ -2358,7 +2359,7 @@ cuda_py_test(
cuda_py_test(
name = "transpose_op_test",
- size = "large",
+ size = "medium",
srcs = ["transpose_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2366,10 +2367,11 @@ cuda_py_test(
"//tensorflow/python:client_testlib",
"//tensorflow/python:framework_for_generated_wrappers",
],
- shard_count = 2,
+ shard_count = 4,
tags = [
"no_gpu",
"no_oss",
+ "optonly", # times out
],
)
@@ -2488,6 +2490,7 @@ cuda_py_test(
"//tensorflow/python:nn_grad",
"//tensorflow/python:nn_ops",
],
+ shard_count = 2,
tags = [
"optonly", # flaky timeouts unless optimized
],
@@ -2508,7 +2511,7 @@ cuda_py_test(
cuda_py_test(
name = "conv_ops_test",
- size = "large",
+ size = "medium",
srcs = ["conv_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2527,6 +2530,9 @@ cuda_py_test(
"//tensorflow/python:variables",
],
shard_count = 4,
+ tags = [
+ "optonly", # times out
+ ],
)
cuda_py_test(
@@ -2586,7 +2592,7 @@ cuda_py_test(
cuda_py_test(
name = "fft_ops_test",
- size = "large",
+ size = "medium",
srcs = ["fft_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2596,7 +2602,8 @@ cuda_py_test(
"//tensorflow/python:spectral_ops",
"//tensorflow/python:spectral_ops_test_util",
],
- shard_count = 3,
+ shard_count = 4,
+ tags = ["optonly"],
)
cuda_py_test(
@@ -2661,7 +2668,7 @@ cuda_py_test(
cuda_py_test(
name = "scatter_ops_test",
- size = "large", # NOTE: This is not run by default.
+ size = "medium", # NOTE: This is not run by default.
srcs = ["scatter_ops_test.py"],
additional_deps = [
"//third_party/py/numpy",
@@ -2670,11 +2677,13 @@ cuda_py_test(
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
],
+ shard_count = 2,
+ tags = ["optonly"],
)
cuda_py_test(
name = "slice_op_test",
- size = "large",
+ size = "medium",
srcs = ["slice_op_test.py"],
additional_deps = [
"//third_party/py/numpy",
diff --git a/tensorflow/python/kernel_tests/array_ops_test.py b/tensorflow/python/kernel_tests/array_ops_test.py
index c5547b19be..dcc594789e 100644
--- a/tensorflow/python/kernel_tests/array_ops_test.py
+++ b/tensorflow/python/kernel_tests/array_ops_test.py
@@ -615,6 +615,14 @@ class StridedSliceTest(test_util.TensorFlowTestCase):
_ = checker[:, 0]
_ = checker[:, :, 0]
+ def testBothNewAxisAndShrink(self):
+ with self.test_session(use_gpu=True):
+ ones = array_ops.placeholder(shape=[2, 2], dtype=dtypes.int16)
+ self.assertAllEqual(
+ ones[array_ops.newaxis, :, 0].eval(
+ feed_dict={ones: [[1, 1], [1, 1]]}),
+ [[1, 1]])
+
def testTensorIndexing(self):
with self.test_session(use_gpu=True):
raw = [[[[[1, 2, 4, 5], [5, 6, 7, 8], [9, 10, 11, 12]]],
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index 083de84775..655fece5ff 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -23,7 +23,6 @@ from __future__ import print_function
import collections
import math
import time
-import unittest
import numpy as np
from six.moves import xrange # pylint: disable=redefined-builtin
@@ -32,6 +31,7 @@ from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
+from tensorflow.python.eager import function as eager_function
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
@@ -63,6 +63,7 @@ from tensorflow.python.ops import script_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
+from tensorflow.python.ops import while_v2 # pylint: disable=unused-import
# pylint: disable=unused-import
import tensorflow.python.ops.tensor_array_grad
# pylint: enable=unused-import
@@ -125,7 +126,7 @@ def isum(s, maximum_iterations=None):
return r_s
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class ControlFlowTest(test.TestCase):
def testRefIdentity(self):
@@ -332,10 +333,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesOpError("has inputs from different frames"):
res.eval(feed_dict={data: 1.0})
+ @test_util.disable_control_flow_v2("b/113294340")
def testCondBool(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296297")
-
values = constant_op.constant(10)
fn1 = lambda: math_ops.add(values, 1)
fn2 = lambda: math_ops.subtract(values, 1)
@@ -366,6 +365,7 @@ class ControlFlowTest(test.TestCase):
"has been marked as not fetchable"):
sess.run(t, feed_dict={x: 3})
+ @test_util.disable_control_flow_v2("Not relevant")
def testFeedable(self):
with self.cached_session() as sess:
c = constant_op.constant(2)
@@ -383,10 +383,8 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(ValueError, "may not be fed"):
sess.run(r, feed_dict={t: 3})
+ @test_util.disable_control_flow_v2("b/113296180 (IndexedSlices)")
def testCondIndexedSlices(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296180")
-
with self.cached_session():
values = constant_op.constant(10)
indices = constant_op.constant(0)
@@ -401,10 +399,8 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(11, val)
self.assertAllEqual(0, ind)
+ @test_util.disable_control_flow_v2("b/113296161 (SparseTensors)")
def testCondSparseTensor(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113296161 (SparseTensors)")
-
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
indices = constant_op.constant(
@@ -435,10 +431,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(1.0, control_flow_ops.cond(rv, case, lambda: t).eval())
+ @test_util.disable_control_flow_v2("b/113293074")
def testCondIndexedSlicesDifferentTypes(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113293074")
-
with self.cached_session():
values = constant_op.constant(10)
i_32 = ops.convert_to_tensor(0, name="one", dtype=dtypes.int32)
@@ -510,10 +504,8 @@ class ControlFlowTest(test.TestCase):
result = r.eval()
self.assertAllEqual(12, result)
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testCond_4(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v1 = variables.Variable(7)
v2 = variables.Variable(7)
@@ -587,10 +579,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(constant_op.constant(False), true_fn, false_fn)
self.assertAllEqual([2.0], r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/79881896")
-
with self.cached_session():
control_holder = array_ops.placeholder(dtypes.float32, shape=())
a = constant_op.constant(3)
@@ -629,10 +619,9 @@ class ControlFlowTest(test.TestCase):
merged_op = control_flow_ops.merge([assign_v, orig_v])
self.assertAllEqual([1.0], sess.run(merged_op.output))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondSwitchIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the recv identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
pred = constant_op.constant(True)
@@ -646,10 +635,9 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCondRecvIdentity(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
# Make sure the switch identity is not removed by optimization.
with session.Session(config=opt_cfg()) as sess:
with ops.device(test.gpu_device_name()):
@@ -665,10 +653,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(pred, fn1, fn2)
sess.run(r)
+ @test_util.disable_control_flow_v2("b/113346829 (gpu failure)")
def testCondGrad_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
-
graph = ops.Graph()
with graph.as_default():
x = constant_op.constant(10.0, name="x")
@@ -694,10 +680,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(42.0, grad.eval(feed_dict={c: 1}))
self.assertAllEqual(3.0, grad.eval(feed_dict={c: 3}))
+ @test_util.disable_control_flow_v2(
+ "b/110550782 (gradient w.r.t external variable)")
def testCondGrad_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
-
with self.cached_session():
c = array_ops.placeholder(dtypes.int32, shape=[])
ox = constant_op.constant(10.0)
@@ -729,10 +714,8 @@ class ControlFlowTest(test.TestCase):
result = gradients_impl.gradients(z, x)[0]
self.assertEqual(1.0, result.eval())
+ @test_util.disable_control_flow_v2("b/113327884")
def testCondGrad_Gather(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113327884")
-
with self.cached_session() as sess:
v1 = variables.Variable([1.0, 42.0])
c = array_ops.placeholder(dtypes.int32, shape=[])
@@ -756,6 +739,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(dense_gv, [0.0, 2.0])
# Microbenchmark: 256,000 iterations/s.
+ @test_util.disable_control_flow_v2("b/116630618 (Times out)")
def testWhile_1(self):
with self.cached_session():
n = constant_op.constant(0)
@@ -764,6 +748,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependencies(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -779,6 +764,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(result.eval(), 2)
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/79881896 (control deps)")
def testWhileExternalControlDependenciesNoInput(self):
with self.cached_session():
v = variables.Variable(0.0)
@@ -794,6 +780,7 @@ class ControlFlowTest(test.TestCase):
result.eval()
self.assertAllEqual(v.eval(), 1.0)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefs_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0)._ref() # pylint: disable=protected-access
@@ -824,18 +811,22 @@ class ControlFlowTest(test.TestCase):
r = isum(s)
self.assertAllEqual(45, r.eval())
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testWhileWithMaximumIterations(self):
with self.cached_session():
s = constant_op.constant([1, 2, 3, 4, 5])
r = isum(s, maximum_iterations=3)
self.assertAllEqual([1 + 3, 2 + 3, 3 + 3, 4 + 3, 5 + 3], r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithMaximumIterationsAndSingleArgument(self):
with self.cached_session():
r = control_flow_ops.while_loop(
lambda i: i < 3, lambda i: i + 1, [0], maximum_iterations=1)
self.assertEqual(1, r.eval())
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nested), b/115920078 (gradients)")
def testSingleNestedMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -861,6 +852,7 @@ class ControlFlowTest(test.TestCase):
# Should execute without issue.
self.assertEqual(3, self.evaluate(loop_execute))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while_loop)")
def testInvalidMaximumIterationsWhileLoopGradientInXLAContext(self):
v = constant_op.constant(1.0)
@@ -904,10 +896,8 @@ class ControlFlowTest(test.TestCase):
r"context '.*' \(currently defined in '.*'\)"):
_ = gradients_impl.gradients(loop_with_maxiter, v)
+ @test_util.disable_control_flow_v2("b/115776323 (max_iters)")
def testInvalidMaximumIterationsFromSiblingContextWhileLoopInXLAContext(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
v = constant_op.constant(1.0)
def create_while_loop():
@@ -939,6 +929,8 @@ class ControlFlowTest(test.TestCase):
r"while loop context '' \(currently defined in 'cond/.+'\)"):
_ = gradients_impl.gradients(loop, v)
+ @test_util.disable_control_flow_v2(
+ "b/116248044 (nesting), b/115776323 (max_iters)")
def testNestedWhileLoopWithMaxItersFromOuterContextInXLAContext(self):
v = constant_op.constant(1.0)
@@ -1072,6 +1064,7 @@ class ControlFlowTest(test.TestCase):
result = r[2].eval()
self.assertAllEqual(np.array([0, 1, 2, 3, 4, 5, 6]), result)
+ @test_util.disable_control_flow_v2("b/116338794 (buffer_reuse)")
def testBufferForwarding(self):
run_options = config_pb2.RunOptions(
trace_level=config_pb2.RunOptions.FULL_TRACE)
@@ -1139,6 +1132,7 @@ class ControlFlowTest(test.TestCase):
r = r[1] * array_ops.ones([8, 8])
self.assertAllEqual(np.ones((8, 8)), r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Scalar(self):
with self.cached_session():
n = 0
@@ -1147,6 +1141,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n], parallel_iterations=20)
self.assertEqual(10000, r.eval())
+ @test_util.disable_control_flow_v2("b/116339888 (non-tensor loop var)")
def testWhileWithNonTensorInput_Vector(self):
with self.cached_session():
n = np.array([0]) # Note, [0] would not work here; that is a list
@@ -1169,7 +1164,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(
c, b, [i, m],
[i.get_shape(), tensor_shape.TensorShape([None, 2])])
- self.assertTrue(r[1].get_shape()[0].value is None)
+ self.assertIsNone(r[1].get_shape()[0].value)
self.assertEqual(r[1].get_shape()[1], tensor_shape.Dimension(2))
with self.assertRaisesRegexp(
@@ -1180,6 +1175,7 @@ class ControlFlowTest(test.TestCase):
r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileShapeInferenceSparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -1211,6 +1207,7 @@ class ControlFlowTest(test.TestCase):
c, b, [i, x],
[i.get_shape(), tensor_shape.TensorShape([5])])
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileShapeInferenceIndexedSlices(self):
with self.cached_session():
values = constant_op.constant([[2.0, 4.0], [3.0, 5.0]], name="values")
@@ -1265,6 +1262,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertEqual(225, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_1(self):
self._testNestedWhile_1(use_gpu=False)
self._testNestedWhile_1(use_gpu=True)
@@ -1297,6 +1295,7 @@ class ControlFlowTest(test.TestCase):
outer_c, outer_b, [s0], parallel_iterations=1)
self.assertEqual(1048576.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhile_2(self):
self._testNestedWhile_2(use_gpu=False)
self._testNestedWhile_2(use_gpu=True)
@@ -1350,6 +1349,7 @@ class ControlFlowTest(test.TestCase):
lambda x: x < 10, lambda x: x + array_ops.identity(c), [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/79881896 (control_deps)")
def testWhileWithControl_5(self):
with self.cached_session() as sess:
b = array_ops.placeholder(dtypes.bool)
@@ -1363,10 +1363,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(lambda x: x < 10, body, [x0])
self.assertEqual(10, sess.run(r, {b: True}))
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCondWithControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
# Ensure that no control edges by an outer control dependency context are
# added to nodes inside cond/while contexts.
with self.cached_session() as sess:
@@ -1380,10 +1378,8 @@ class ControlFlowTest(test.TestCase):
(constant_op.constant(5),))
self.assertEqual(0, sess.run(loop))
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondWithControl_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variable_scope.get_variable(
"v", [], initializer=init_ops.constant_initializer(2))
@@ -1405,9 +1401,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(4, r.eval())
self.assertAllClose(65536.0, v.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testWhileCondExitControl(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
v = variables.Variable(1)
@@ -1432,8 +1427,6 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(99, v.eval())
def testCondWhile_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1445,8 +1438,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def testCondWhile_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1458,9 +1449,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
def _testCondWhile_3(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294340 (enable while_v2)")
-
with self.test_session(use_gpu=use_gpu) as sess:
p = array_ops.placeholder(dtypes.bool)
n = constant_op.constant(0.0)
@@ -1477,18 +1465,18 @@ class ControlFlowTest(test.TestCase):
lambda: control_flow_ops.while_loop(c, b, [n]),
lambda: math_ops.multiply(n, 2.0))
r1 = gradients_impl.gradients(r, [n])
- self.assertEqual(10, sess.run(r, {p: True}))
+ self.assertEqual(10., sess.run(r, {p: True}))
self.assertEqual([1.0], sess.run(r1, {p: True}))
self.assertEqual(0.0, sess.run(r, {p: False}))
self.assertEqual([2.0], sess.run(r1, {p: False}))
+ @test_util.disable_control_flow_v2("b/116743589")
def testCondWhile_3(self):
self._testCondWhile_3(use_gpu=False)
self._testCondWhile_3(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_1(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
i = ops.convert_to_tensor(0, name="i")
@@ -1504,9 +1492,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [i])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_2(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0, name="n")
@@ -1515,9 +1502,8 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.while_loop(c, b, [n])
self.assertAllEqual(10, r.eval())
+ @test_util.disable_control_flow_v2("b/116134862 (cond output shape)")
def testWhileCond_3(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.cached_session():
n = ops.convert_to_tensor(0)
@@ -1532,6 +1518,7 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(10, r.eval())
# NOTE: It is ok to have parallel_iterations > 1
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_1(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1554,6 +1541,7 @@ class ControlFlowTest(test.TestCase):
result = select.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_2(self):
with self.cached_session():
select1 = variables.Variable([3.0, 4.0, 5.0])
@@ -1580,6 +1568,7 @@ class ControlFlowTest(test.TestCase):
result2 = select2.eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result2)
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_3(self):
with self.cached_session():
select = variables.Variable([3.0, 4.0, 5.0])
@@ -1601,7 +1590,7 @@ class ControlFlowTest(test.TestCase):
result = r[1].eval()
self.assertAllClose(np.array([10.0, 10.0, 10.0]), result)
- # b/24814703
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_4(self):
with self.cached_session():
var_a = variables.Variable(0, name="a")
@@ -1629,7 +1618,7 @@ class ControlFlowTest(test.TestCase):
lpa.eval() # Run the loop
self.assertEqual(10, var_b.eval())
- # b/24736492
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_5(self):
with self.cached_session():
# Create some variables.
@@ -1659,7 +1648,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(10, var_a.eval())
self.assertEqual(10, var_b.eval())
- # b/24814668
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileUpdateVariable_6(self):
with self.cached_session():
# Create some variables.
@@ -1689,6 +1678,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(55, var_b.eval())
self.assertEqual(10, var_a.eval())
+ @test_util.disable_control_flow_v2("b/116742472 (resource accumulator)")
def testWhileQueue_1(self):
with self.cached_session():
q = data_flow_ops.FIFOQueue(-1, dtypes.int32)
@@ -1707,6 +1697,7 @@ class ControlFlowTest(test.TestCase):
for i in xrange(10):
self.assertEqual([i], q.dequeue().eval())
+ @test_util.disable_control_flow_v2("b/117119329 (stack)")
def testWhileStack_1(self):
with self.cached_session():
s = gen_data_flow_ops.stack_v2(-1, dtypes.int32, stack_name="foo")
@@ -1775,6 +1766,7 @@ class ControlFlowTest(test.TestCase):
with self.session(graph=graph) as sess:
self.assertAllClose(1024.0, sess.run(r))
+ @test_util.disable_control_flow_v2("b/116351701 (colocation)")
def testWhileGrad_ColocateGradients(self):
self._testWhileGrad_ColocateGradients(colocate=False)
self._testWhileGrad_ColocateGradients(colocate=True)
@@ -1861,8 +1853,6 @@ class ControlFlowTest(test.TestCase):
self._testWhileGrad_Mul(use_gpu=True, p_iters=10)
def _testNestedWhileCondWhileGrad(self, use_gpu):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
with self.test_session(use_gpu=use_gpu):
v = constant_op.constant(1.0)
@@ -1885,10 +1875,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testNestedWhileCondWhileGrad(self):
self._testNestedWhileCondWhileGrad(use_gpu=False)
self._testNestedWhileCondWhileGrad(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116823782")
def testWhileGrad_Variable(self):
with self.cached_session():
a = variables.Variable(3.0)
@@ -1902,8 +1894,6 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(216.0, r[0].eval())
def testWhileGradInCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/110550782 (gradient w.r.t external variable)")
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1919,6 +1909,7 @@ class ControlFlowTest(test.TestCase):
r = control_flow_ops.cond(math_ops.less(1, 2), fn1, lambda: x)
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116340060")
def testGradInWhileWrtInitialLoopVal(self):
with self.cached_session():
x = array_ops.placeholder(dtypes.float32, shape=(), name="x")
@@ -1936,6 +1927,7 @@ class ControlFlowTest(test.TestCase):
"loop invariants or wrt the input parameters to the loop body."):
control_flow_ops.while_loop(lambda i, x: i < 3, body, [0, y])
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testWhileGradInWhile(self):
with self.cached_session():
n = ops.convert_to_tensor(1.0, name="n")
@@ -1952,9 +1944,8 @@ class ControlFlowTest(test.TestCase):
[tensor_shape.unknown_shape()])
self.assertAllClose(9.0, r.eval(feed_dict={x: 1.0}))
+ @test_util.disable_control_flow_v2("b/116248044 (nested while)")
def testCondGradInNestedWhiles(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113346829 (gpu failure)")
def outer_body(i, x):
_, x = control_flow_ops.while_loop(
@@ -1972,6 +1963,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(i_val, 3)
self.assertAllClose(x_val, 1.0)
+ @test_util.disable_control_flow_v2("b/116255781 (flat_args)")
def testWhile_NestedInput(self):
with self.cached_session() as sess:
named = collections.namedtuple("named", ("a", "b"))
@@ -1999,6 +1991,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual([100.0, 1.0, 102.0, 3.0, 4.0 + 100 * 2.0],
sess.run(r_flattened))
+ @test_util.disable_control_flow_v2("b/116255781(flat_args)")
def testWhile_NestedBadArityFails(self):
with self.cached_session():
named = collections.namedtuple("named", ("a", "b"))
@@ -2057,6 +2050,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients([rx], x)
self.assertAllClose(1024.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/116355153 (back_prop flag)")
def testWhileGrad_NoGradient(self):
with self.cached_session():
v = constant_op.constant(2.0, name="v")
@@ -2067,6 +2061,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)
self.assertAllClose(1.0, r[0].eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGrad_NoDependency(self):
with self.cached_session() as sess:
variable = variables.Variable(array_ops.ones([2, 3]))
@@ -2180,10 +2175,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(8.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_Simple(self):
self._testNestedWhileGrad_Simple(use_gpu=False)
self._testNestedWhileGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_SerialInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2207,6 +2204,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(256.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116248044 (nested)")
def testNestedWhileGrad_ParallelInner(self):
with self.cached_session():
v = constant_op.constant(1.0)
@@ -2230,6 +2228,8 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(512.0, r.eval())
+ @test_util.disable_control_flow_v2(
+ "Nested loops and TensorArrays not supported")
def testNestedWhileGrad_ParallelIterations(self):
# Make sure the stack pushes and pops of an inner loop are executed in
# the sequential order of the iterations of its outer loop.
@@ -2268,13 +2268,12 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, v)[0]
self.assertAllClose(1024.0, r.eval())
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_Simple(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113294377 (unknown shape)")
-
self._testWhileCondGrad_Simple(use_gpu=False)
self._testWhileCondGrad_Simple(use_gpu=True)
+ @test_util.disable_control_flow_v2("b/116272044 (cond_in_while)")
def testWhileCondGrad_UnknownShape(self):
with self.cached_session() as sess:
v = array_ops.placeholder(dtypes.float32)
@@ -2315,6 +2314,7 @@ class ControlFlowTest(test.TestCase):
sess.run(op)
self.assertAllClose([[0.98000002, 1.98000002]], sess.run(x))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileWithRefsWithGradients_1(self):
with self.cached_session() as sess:
x = variables.VariableV1(0.)._ref() # pylint: disable=protected-access
@@ -2343,6 +2343,7 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, value_x)
self.assertEqual(73, value_x_grad)
+ @test_util.disable_control_flow_v2("b/116282023 (IndexedSlices)")
def testWhileGrad_IndexedSlices(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2364,6 +2365,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/116328420 (SparseTensor)")
def testWhileGrad_SparseTensor(self):
with self.cached_session():
values = constant_op.constant([2.0, 4.0], name="values")
@@ -2386,6 +2388,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r.values, values)[0]
self.assertAllClose(np.array([1024.0, 1024.0]), r.eval())
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testCallGradInLoop(self):
with self.cached_session() as sess:
i0 = constant_op.constant(0)
@@ -2405,6 +2408,8 @@ class ControlFlowTest(test.TestCase):
c, b, [i0, constant_op.constant(0.0)])
self.assertAllClose(600.0, sess.run(output_grad)[1])
+ @test_util.disable_control_flow_v2(
+ "b/116255781 (flat_args), b/115660901 (TensorArray)")
def testWhileAndTensorArray(self):
with self.cached_session() as sess:
param = constant_op.constant(2.0)
@@ -2509,6 +2514,7 @@ class ControlFlowTest(test.TestCase):
all_ops = x.graph.get_operations()
self.assertFalse(any([name in op.name for op in all_ops]))
+ @test_util.disable_control_flow_v2("b/116255781 (flat args)")
def testWhileGradGradFail(self):
theta = variables.Variable(initial_value=1.)
@@ -2538,6 +2544,7 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, y)[0]
self.assertEqual(388.0, r.eval())
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath1(self):
q = variables.Variable([7., 8.])
@@ -2555,6 +2562,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([0., 0.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/113324949 (RefVariable)")
def testWhileGradientWithNontrainablePath2(self):
q = variables.Variable([7., 8.])
@@ -2572,6 +2580,7 @@ class ControlFlowTest(test.TestCase):
sess.run(q.initializer)
self.assertAllClose([1., 1.], sess.run(dy_dq))
+ @test_util.disable_control_flow_v2("b/115920078 (gradients)")
def testIssue16504(self):
c = constant_op.constant(np.arange(100), dtype=dtypes.float32)
w = variables.Variable(
@@ -2595,6 +2604,7 @@ class ControlFlowTest(test.TestCase):
grad, = gradients_impl.gradients(w, c)
self.assertIsNotNone(grad)
+ @test_util.disable_control_flow_v2("b/116270461 (resource)")
def testStopGradMultiFlows(self):
with self.cached_session():
@@ -2653,10 +2663,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllClose(4.0, i.eval(feed_dict={d: 1}))
self.assertAllClose(2.0 * math.sqrt(2), i.eval(feed_dict={d: 2}))
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCase(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session():
x = constant_op.constant(1)
y = constant_op.constant(2)
@@ -2708,10 +2717,9 @@ class ControlFlowTest(test.TestCase):
self.assertAllEqual(r6.eval(), 0)
+ @test_util.disable_control_flow_v2(
+ "b/112477618 (Operation returned from cond)")
def testCaseSideEffects(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/112477618 (Operation returned from cond)")
-
with self.cached_session() as sess:
v0 = variables.Variable(-1)
v1 = variables.Variable(-1)
@@ -2746,10 +2754,8 @@ class ControlFlowTest(test.TestCase):
self.assertEqual(0, r0.eval())
self.assertAllEqual(sess.run([v0, v1, v2]), [0, -1, -1])
+ @test_util.disable_control_flow_v2("b/113324949 (ref vars)")
def testOneOpCond(self):
- if control_flow_ops.ENABLE_COND_V2:
- return unittest.skip("b/113324949 (ref vars)")
-
with self.cached_session():
v = variables.Variable(0)
c = ops.convert_to_tensor(0)
@@ -3031,9 +3037,11 @@ class ControlFlowTest(test.TestCase):
r = gradients_impl.gradients(r, x)[0]
self.assertEqual(r.eval(), 524288.0)
- self.assertEqual(
- len([op for op in x.graph.get_operations() if op.type == "StackV2"]),
- 1)
+ # while_v2 does not have stacks.
+ if not control_flow_ops.ENABLE_WHILE_V2:
+ self.assertEqual(
+ len([op for op in x.graph.get_operations() if op.type == "StackV2"
+ ]), 1)
class ControlFlowContextCheckTest(test.TestCase):
@@ -3393,7 +3401,7 @@ class WhileOpBenchmark(test.Benchmark):
name="unroll_same_device", iters=iters, wall_time=duration)
-@test_util.with_cond_v2
+@test_util.with_control_flow_v2
class EagerTest(test.TestCase):
def testCond(self):
@@ -3406,6 +3414,27 @@ class EagerTest(test.TestCase):
self.assertAllEqual(r.numpy(), 10)
self.assertFalse(isinstance(r, list))
+ def testCondInDefun(self):
+ if "GPU" in [d.device_type for d in device_lib.list_local_devices()]:
+ return unittest.skip("b/113346829 (gpu failure)")
+
+ with context.eager_mode():
+
+ @eager_function.defun
+ def foo(pred):
+ # TODO(b/111124878): this only needs to output one element.
+ fn1 = lambda: (constant_op.constant(10), constant_op.constant(100))
+ fn2 = lambda: (constant_op.constant(20), constant_op.constant(200))
+ return control_flow_ops.cond(constant_op.constant(pred), fn1, fn2)
+
+ r = foo(True)
+ self.assertAllEqual(r[0].numpy(), 10)
+ self.assertNotIsInstance(r, list)
+
+ r = foo(False)
+ self.assertAllEqual(r[0].numpy(), 20)
+ self.assertFalse(isinstance(r, list))
+
def testWhileLoop(self):
with context.eager_mode():
tensor = constant_op.constant([1, 2, 3, 4, 5])
diff --git a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
index 6d1ead20be..9c02b69180 100644
--- a/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
+++ b/tensorflow/python/kernel_tests/depthwise_conv_op_test.py
@@ -131,8 +131,8 @@ class DepthwiseConv2DTest(test.TestCase):
with self.session(graph=graph, use_gpu=use_gpu) as sess:
tolerance = {
dtypes.float16: 4e-2,
- dtypes.float32: 1e-8,
- dtypes.float64: 1e-13,
+ dtypes.float32: 1e-7,
+ dtypes.float64: 1e-12,
}[data_type]
t1 = constant_op.constant(x1, shape=tensor_in_sizes, dtype=data_type)
diff --git a/tensorflow/python/kernel_tests/distributions/beta_test.py b/tensorflow/python/kernel_tests/distributions/beta_test.py
index d580a415dd..42e81bd658 100644
--- a/tensorflow/python/kernel_tests/distributions/beta_test.py
+++ b/tensorflow/python/kernel_tests/distributions/beta_test.py
@@ -167,6 +167,11 @@ class BetaTest(test.TestCase):
self.assertAllClose([[1., 3. / 2], [3. / 2, 15. / 8]], self.evaluate(pdf))
self.assertEqual((2, 2), pdf.get_shape())
+ def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
+ b = [[0.01, 0.1, 1., 2], [5., 10., 2., 3]]
+ pdf = self.evaluate(beta_lib.Beta(1., b).prob(0.))
+ self.assertAllEqual(np.ones_like(pdf, dtype=np.bool), np.isfinite(pdf))
+
def testBetaMean(self):
a = [1., 2, 3]
b = [2., 4, 1.2]
diff --git a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
index cace5b3ba2..0f96382453 100644
--- a/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
+++ b/tensorflow/python/kernel_tests/distributions/dirichlet_test.py
@@ -83,6 +83,23 @@ class DirichletTest(test.TestCase):
with self.assertRaisesOpError("sample last-dimension must sum to `1`"):
self.evaluate(dist.prob([.1, .2, .8]))
+ def testLogPdfOnBoundaryIsFiniteWhenAlphaIsOne(self):
+ # Test concentration = 1. for each dimension.
+ concentration = 3 * np.ones((10, 10)).astype(np.float32)
+ concentration[range(10), range(10)] = 1.
+ x = 1 / 9. * np.ones((10, 10)).astype(np.float32)
+ x[range(10), range(10)] = 0.
+ dist = dirichlet_lib.Dirichlet(concentration)
+ log_prob = self.evaluate(dist.log_prob(x))
+ self.assertAllEqual(
+ np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob))
+
+ # Test when concentration[k] = 1., and x is zero at various dimensions.
+ dist = dirichlet_lib.Dirichlet(10 * [1.])
+ log_prob = self.evaluate(dist.log_prob(x))
+ self.assertAllEqual(
+ np.ones_like(log_prob, dtype=np.bool), np.isfinite(log_prob))
+
def testPdfZeroBatches(self):
alpha = [1., 2]
x = [.5, .5]
diff --git a/tensorflow/python/kernel_tests/distributions/exponential_test.py b/tensorflow/python/kernel_tests/distributions/exponential_test.py
index 27d1291912..1600387585 100644
--- a/tensorflow/python/kernel_tests/distributions/exponential_test.py
+++ b/tensorflow/python/kernel_tests/distributions/exponential_test.py
@@ -65,6 +65,13 @@ class ExponentialTest(test.TestCase):
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ def testExponentialLogPDFBoundary(self):
+ # Check that Log PDF is finite at 0.
+ rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
+ exponential = exponential_lib.Exponential(rate=rate)
+ log_pdf = exponential.log_prob(0.)
+ self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
+
def testExponentialCDF(self):
batch_size = 6
lam = constant_op.constant([2.0] * batch_size)
@@ -81,6 +88,22 @@ class ExponentialTest(test.TestCase):
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)
self.assertAllClose(self.evaluate(cdf), expected_cdf)
+ def testExponentialLogSurvival(self):
+ batch_size = 7
+ lam = constant_op.constant([2.0] * batch_size)
+ lam_v = 2.0
+ x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0, 10.0], dtype=np.float32)
+
+ exponential = exponential_lib.Exponential(rate=lam)
+
+ log_survival = exponential.log_survival_function(x)
+ self.assertEqual(log_survival.get_shape(), (7,))
+
+ if not stats:
+ return
+ expected_log_survival = stats.expon.logsf(x, scale=1 / lam_v)
+ self.assertAllClose(self.evaluate(log_survival), expected_log_survival)
+
def testExponentialMean(self):
lam_v = np.array([1.0, 4.0, 2.5])
exponential = exponential_lib.Exponential(rate=lam_v)
diff --git a/tensorflow/python/kernel_tests/distributions/gamma_test.py b/tensorflow/python/kernel_tests/distributions/gamma_test.py
index 4eff40b029..4c5b9c3ea3 100644
--- a/tensorflow/python/kernel_tests/distributions/gamma_test.py
+++ b/tensorflow/python/kernel_tests/distributions/gamma_test.py
@@ -77,6 +77,14 @@ class GammaTest(test.TestCase):
self.assertAllClose(self.evaluate(log_pdf), expected_log_pdf)
self.assertAllClose(self.evaluate(pdf), np.exp(expected_log_pdf))
+ def testGammaLogPDFBoundary(self):
+ # When concentration = 1, we have an exponential distribution. Check that at
+ # 0 we have finite log prob.
+ rate = np.array([0.1, 0.5, 1., 2., 5., 10.], dtype=np.float32)
+ gamma = gamma_lib.Gamma(concentration=1., rate=rate)
+ log_pdf = gamma.log_prob(0.)
+ self.assertAllClose(np.log(rate), self.evaluate(log_pdf))
+
def testGammaLogPDFMultidimensional(self):
batch_size = 6
alpha = constant_op.constant([[2.0, 4.0]] * batch_size)
diff --git a/tensorflow/python/kernel_tests/rnn_test.py b/tensorflow/python/kernel_tests/rnn_test.py
index 05ad9f6336..2f6963f6b8 100644
--- a/tensorflow/python/kernel_tests/rnn_test.py
+++ b/tensorflow/python/kernel_tests/rnn_test.py
@@ -535,6 +535,45 @@ class RNNTest(test.TestCase):
self.assertAllClose(tf_out, k_out)
self.assertAllClose(tf_state, k_state)
+ def testSimpleRNNCellAndBasicRNNCellComparison(self):
+ input_shape = 10
+ output_shape = 5
+ timestep = 4
+ batch = 20
+ (x_train, _), _ = testing_utils.get_test_data(
+ train_samples=batch,
+ test_samples=0,
+ input_shape=(timestep, input_shape),
+ num_classes=output_shape)
+ fix_weights_generator = keras.layers.SimpleRNNCell(output_shape)
+ fix_weights_generator.build((None, input_shape))
+ # The SimpleRNNCell contains 3 weights: kernel, recurrent_kernel, and bias
+ # The BasicRNNCell contains 2 weight: kernel and bias, where kernel is
+ # zipped [kernel, recurrent_kernel] in SimpleRNNCell.
+ keras_weights = fix_weights_generator.get_weights()
+ kernel, recurrent_kernel, bias = keras_weights
+ tf_weights = [np.concatenate((kernel, recurrent_kernel)), bias]
+
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = keras.layers.SimpleRNNCell(output_shape)
+ k_out, k_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(keras_weights)
+ [k_out, k_state] = sess.run([k_out, k_state], {inputs: x_train})
+ with self.test_session(graph=ops_lib.Graph()) as sess:
+ inputs = array_ops.placeholder(
+ dtypes.float32, shape=(None, timestep, input_shape))
+ cell = rnn_cell_impl.BasicRNNCell(output_shape)
+ tf_out, tf_state = rnn.dynamic_rnn(
+ cell, inputs, dtype=dtypes.float32)
+ cell.set_weights(tf_weights)
+ [tf_out, tf_state] = sess.run([tf_out, tf_state], {inputs: x_train})
+
+ self.assertAllClose(tf_out, k_out)
+ self.assertAllClose(tf_state, k_state)
+
def testBasicLSTMCellInterchangeWithLSTMCell(self):
with self.session(graph=ops_lib.Graph()) as sess:
basic_cell = rnn_cell_impl.BasicLSTMCell(1)
diff --git a/tensorflow/python/kernel_tests/variables_test.py b/tensorflow/python/kernel_tests/variables_test.py
index 942ceedc8b..c2b86089f4 100644
--- a/tensorflow/python/kernel_tests/variables_test.py
+++ b/tensorflow/python/kernel_tests/variables_test.py
@@ -696,6 +696,48 @@ class PartitionedVariableTest(test.TestCase):
variable_list=[v0],
partitions=partitions)
+ def testPartitionedVariableAssignments(self):
+ with ops.Graph().as_default(), self.cached_session() as sess:
+ v0 = variables.Variable(initial_value=[0.0])
+ v1 = variables.Variable(initial_value=[1.0])
+ v0._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [0], [1]))
+ v1._set_save_slice_info(
+ variables.Variable.SaveSliceInfo(v0.name, [2], [1], [1]))
+ partitions = [2]
+
+ # Pass variable_list as [v1, v0] to ensure they are properly
+ # re-sorted to [v0, v1] based on their slice info offsets.
+ partitioned_variable = variables.PartitionedVariable(
+ name="two_vars",
+ shape=[2],
+ dtype=v0.dtype,
+ variable_list=[v0, v1],
+ partitions=partitions)
+
+ deltas_a = constant_op.constant([1.0, 2.0])
+ deltas_b = constant_op.constant([3.0, 4.0])
+ ones = array_ops.ones([2])
+ plus_delta = partitioned_variable.assign_add(deltas_a)
+ minus_delta = partitioned_variable.assign_sub(deltas_b)
+ assign_ones = partitioned_variable.assign(ones)
+ variables.global_variables_initializer().run()
+
+ self.assertEqual([1.0], plus_delta[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([3.0], plus_delta[1].eval())
+ self.assertEqual([3.0], v1.eval())
+
+ self.assertEqual([-2.0], minus_delta[0].eval())
+ self.assertEqual([-2.0], v0.eval())
+ self.assertEqual([-1.0], minus_delta[1].eval())
+ self.assertEqual([-1.0], v1.eval())
+
+ self.assertEqual([1.0], assign_ones[0].eval())
+ self.assertEqual([1.0], v0.eval())
+ self.assertEqual([1.0], assign_ones[1].eval())
+ self.assertEqual([1.0], v1.eval())
+
class VariableContainerTest(test.TestCase):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 3ba880d7a1..e399ece232 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -131,10 +131,20 @@ class Layer(base_layer.Layer):
def add_loss(self, losses, inputs=None):
previous_losses_length = len(self._losses)
+ previous_callable_losses_length = len(self._callable_losses)
super(Layer, self).add_loss(losses, inputs=inputs)
- # TODO(fchollet): deprecate collection below.
- new_losses = self._losses[previous_losses_length:]
- _add_elements_to_collection(new_losses, ops.GraphKeys.REGULARIZATION_LOSSES)
+ if not context.executing_eagerly():
+ # TODO(fchollet): deprecate collection below.
+ new_losses = self._losses[previous_losses_length:]
+ new_callable_losses = self._callable_losses[
+ previous_callable_losses_length:]
+ for regularizer in new_callable_losses:
+ loss_tensor = regularizer()
+ if loss_tensor is not None:
+ new_losses.append(loss_tensor)
+ _add_elements_to_collection(
+ new_losses,
+ ops.GraphKeys.REGULARIZATION_LOSSES)
def _name_scope(self):
"""Determines op naming for the Layer."""
diff --git a/tensorflow/python/layers/convolutional_test.py b/tensorflow/python/layers/convolutional_test.py
index d61d3b6dba..257fa27156 100644
--- a/tensorflow/python/layers/convolutional_test.py
+++ b/tensorflow/python/layers/convolutional_test.py
@@ -207,7 +207,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -217,7 +218,8 @@ class ConvTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DNoBias(self):
height, width = 7, 9
@@ -445,7 +447,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DPointwiseRegularizer(self):
length = 9
@@ -455,7 +458,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DBiasRegularizer(self):
length = 9
@@ -465,7 +469,8 @@ class SeparableConv1DTest(test.TestCase):
layer.apply(data)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv1DNoBias(self):
length = 9
@@ -682,7 +687,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DPointwiseRegularizer(self):
height, width = 7, 9
@@ -692,7 +698,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DBiasRegularizer(self):
height, width = 7, 9
@@ -702,7 +709,8 @@ class SeparableConv2DTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testSeparableConv2DNoBias(self):
height, width = 7, 9
@@ -839,7 +847,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeBiasRegularizer(self):
height, width = 7, 9
@@ -849,7 +858,8 @@ class Conv2DTransposeTest(test.TestCase):
layer.apply(images)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv2DTransposeNoBias(self):
height, width = 7, 9
@@ -1017,7 +1027,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeBiasRegularizer(self):
depth, height, width = 5, 7, 9
@@ -1027,7 +1038,8 @@ class Conv3DTransposeTest(test.TestCase):
layer.apply(volumes)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(layer.losses, loss_keys)
+ self.evaluate([v.initializer for v in layer.variables])
+ self.assertListEqual(self.evaluate(layer.losses), self.evaluate(loss_keys))
def testConv3DTransposeNoBias(self):
depth, height, width = 5, 7, 9
diff --git a/tensorflow/python/layers/core.py b/tensorflow/python/layers/core.py
index 9879e5020f..e06e9aba4a 100644
--- a/tensorflow/python/layers/core.py
+++ b/tensorflow/python/layers/core.py
@@ -269,6 +269,13 @@ def dropout(inputs,
class Flatten(keras_layers.Flatten, base.Layer):
"""Flattens an input tensor while preserving the batch axis (axis 0).
+ Arguments:
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, ..., channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, ...)`.
+
Examples:
```
@@ -285,12 +292,17 @@ class Flatten(keras_layers.Flatten, base.Layer):
@tf_export('layers.flatten')
-def flatten(inputs, name=None):
+def flatten(inputs, name=None, data_format='channels_last'):
"""Flattens an input tensor while preserving the batch axis (axis 0).
Arguments:
inputs: Tensor input.
name: The name of the layer (string).
+ data_format: A string, one of `channels_last` (default) or `channels_first`.
+ The ordering of the dimensions in the inputs.
+ `channels_last` corresponds to inputs with shape
+ `(batch, height, width, channels)` while `channels_first` corresponds to
+ inputs with shape `(batch, channels, height, width)`.
Returns:
Reshaped tensor.
@@ -307,7 +319,7 @@ def flatten(inputs, name=None):
# now `y` has shape `(None, None)`
```
"""
- layer = Flatten(name=name)
+ layer = Flatten(name=name, data_format=data_format)
return layer.apply(inputs)
diff --git a/tensorflow/python/layers/core_test.py b/tensorflow/python/layers/core_test.py
index 46009a30ac..0343bfa8bd 100644
--- a/tensorflow/python/layers/core_test.py
+++ b/tensorflow/python/layers/core_test.py
@@ -197,7 +197,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testKernelRegularizerWithReuse(self):
regularizer = lambda x: math_ops.reduce_sum(x) * 1e-3
@@ -218,7 +219,8 @@ class DenseTest(test.TestCase):
_ = dense(inputs)
loss_keys = ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES)
self.assertEqual(len(loss_keys), 1)
- self.assertListEqual(dense.losses, loss_keys)
+ self.evaluate([v.initializer for v in dense.variables])
+ self.assertAllEqual(self.evaluate(dense.losses), self.evaluate(loss_keys))
def testFunctionalDense(self):
with self.cached_session():
@@ -474,6 +476,40 @@ class FlattenTest(test.TestCase):
shape = core_layers.Flatten().compute_output_shape((None, 3, None))
self.assertEqual(shape.as_list(), [None, None])
+ def testDataFormat5d(self):
+ np_input_channels_last = np.arange(
+ 120, dtype='float32').reshape([1, 5, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 5, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 5, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 4, 1, 2, 3])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
+ def testDataFormat4d(self):
+ np_input_channels_last = np.arange(
+ 24, dtype='float32').reshape([1, 4, 3, 2])
+
+ with self.test_session() as sess:
+ x = array_ops.placeholder(shape=(1, 4, 3, 2), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_last')(x)
+ np_output_cl = sess.run(y, feed_dict={x: np_input_channels_last})
+
+ x = array_ops.placeholder(shape=(1, 2, 4, 3), dtype='float32')
+ y = core_layers.Flatten(data_format='channels_first')(x)
+ np_input_channels_first = np.transpose(np_input_channels_last,
+ [0, 3, 1, 2])
+ np_output_cf = sess.run(y, feed_dict={x: np_input_channels_first})
+
+ self.assertAllEqual(np_output_cl, np_output_cf)
+
def testFunctionalFlatten(self):
x = array_ops.placeholder(shape=(None, 2, 3), dtype='float32')
y = core_layers.flatten(x, name='flatten')
diff --git a/tensorflow/python/lib/io/tf_record.py b/tensorflow/python/lib/io/tf_record.py
index cce71a2bab..9ab683d96a 100644
--- a/tensorflow/python/lib/io/tf_record.py
+++ b/tensorflow/python/lib/io/tf_record.py
@@ -22,10 +22,12 @@ from __future__ import print_function
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.framework import errors
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export("python_io.TFRecordCompressionType")
+@tf_export("io.TFRecordCompressionType", "python_io.TFRecordCompressionType")
+@deprecation.deprecated_endpoints("python_io.TFRecordCompressionType")
class TFRecordCompressionType(object):
"""The type of compression for the record."""
NONE = 0
@@ -33,7 +35,8 @@ class TFRecordCompressionType(object):
GZIP = 2
-@tf_export("python_io.TFRecordOptions")
+@tf_export("io.TFRecordOptions", "python_io.TFRecordOptions")
+@deprecation.deprecated_endpoints("python_io.TFRecordOptions")
class TFRecordOptions(object):
"""Options used for manipulating TFRecord files."""
compression_type_map = {
@@ -143,7 +146,8 @@ class TFRecordOptions(object):
return options
-@tf_export("python_io.tf_record_iterator")
+@tf_export("io.tf_record_iterator", "python_io.tf_record_iterator")
+@deprecation.deprecated_endpoints("python_io.tf_record_iterator")
def tf_record_iterator(path, options=None):
"""An iterator that read the records from a TFRecords file.
@@ -175,7 +179,8 @@ def tf_record_iterator(path, options=None):
reader.Close()
-@tf_export("python_io.TFRecordWriter")
+@tf_export("io.TFRecordWriter", "python_io.TFRecordWriter")
+@deprecation.deprecated_endpoints("python_io.TFRecordWriter")
class TFRecordWriter(object):
"""A class to write records to a TFRecords file.
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index a7f57e94e3..9f5149d5ac 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -1204,7 +1204,8 @@ def boolean_mask(tensor, mask, name="boolean_mask", axis=None):
return _apply_mask_1d(tensor, mask, axis)
-@tf_export("sparse_mask")
+@tf_export("sparse.mask", "sparse_mask")
+@deprecation.deprecated_endpoints("sparse_mask")
def sparse_mask(a, mask_indices, name=None):
"""Masks elements of `IndexedSlices`.
@@ -1226,7 +1227,7 @@ def sparse_mask(a, mask_indices, name=None):
# `b` will be the subset of `a` slices at its second and third indices, so
# we want to mask its first and last indices (which are at absolute
# indices 12, 45)
- b = tf.sparse_mask(a, [12, 45])
+ b = tf.sparse.mask(a, [12, 45])
b.indices # [26, 37]
tf.shape(b.values) # [2, 10]
@@ -1382,7 +1383,7 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
[10, 11, 12]]])
# Take the transpose of the matrices in dimension-0
- # (this common operation has a shorthand `matrix_transpose`)
+ # (this common operation has a shorthand `linalg.transpose`)
tf.transpose(x, perm=[0, 2, 1]) # [[[1, 4],
# [2, 5],
# [3, 6]],
@@ -1421,7 +1422,8 @@ def transpose(a, perm=None, name="transpose", conjugate=False):
# pylint: disable=invalid-name
-@tf_export("matrix_transpose", "linalg.transpose")
+@tf_export("linalg.transpose", "matrix_transpose")
+@deprecation.deprecated_endpoints("matrix_transpose")
def matrix_transpose(a, name="matrix_transpose", conjugate=False):
"""Transposes last two dimensions of tensor `a`.
@@ -1429,19 +1431,19 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
```python
x = tf.constant([[1, 2, 3], [4, 5, 6]])
- tf.matrix_transpose(x) # [[1, 4],
+ tf.linalg.transpose(x) # [[1, 4],
# [2, 5],
# [3, 6]]
x = tf.constant([[1 + 1j, 2 + 2j, 3 + 3j],
[4 + 4j, 5 + 5j, 6 + 6j]])
- tf.matrix_transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
+ tf.linalg.transpose(x, conjugate=True) # [[1 - 1j, 4 - 4j],
# [2 - 2j, 5 - 5j],
# [3 - 3j, 6 - 6j]]
# Matrix with two batch dimensions.
# x.shape is [1, 2, 3, 4]
- # tf.matrix_transpose(x) is shape [1, 2, 4, 3]
+ # tf.linalg.transpose(x) is shape [1, 2, 4, 3]
```
Note that `tf.matmul` provides kwargs allowing for transpose of arguments.
@@ -1452,14 +1454,14 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
tf.matmul(matrix, b, transpose_b=True)
# Inefficient!
- tf.matmul(matrix, tf.matrix_transpose(b))
+ tf.matmul(matrix, tf.linalg.transpose(b))
```
@compatibility(numpy)
In `numpy` transposes are memory-efficient constant time operations as they
simply return a new view of the same data with adjusted `strides`.
- TensorFlow does not support strides, `matrix_transposes` return a new tensor
+ TensorFlow does not support strides, `linalg.transposes` return a new tensor
with the items permuted.
@end_compatibility
@@ -1467,7 +1469,7 @@ def matrix_transpose(a, name="matrix_transpose", conjugate=False):
a: A `Tensor` with `rank >= 2`.
name: A name for the operation (optional).
conjugate: Optional bool. Setting it to `True` is mathematically equivalent
- to tf.conj(tf.matrix_transpose(input)).
+ to tf.conj(tf.linalg.transpose(input)).
Returns:
A transposed batch matrix `Tensor`.
@@ -1756,7 +1758,8 @@ def _normalize_sparse_shape(shape, name):
return (ops.convert_to_tensor(shape, dtype=dtypes.int64, name=name), rank)
-@tf_export("sparse_placeholder")
+@tf_export("sparse.placeholder", "sparse_placeholder")
+@deprecation.deprecated_endpoints("sparse_placeholder")
def sparse_placeholder(dtype, shape=None, name=None):
"""Inserts a placeholder for a sparse tensor that will be always fed.
@@ -1767,8 +1770,8 @@ def sparse_placeholder(dtype, shape=None, name=None):
For example:
```python
- x = tf.sparse_placeholder(tf.float32)
- y = tf.sparse_reduce_sum(x)
+ x = tf.sparse.placeholder(tf.float32)
+ y = tf.sparse.reduce_sum(x)
with tf.Session() as sess:
print(sess.run(y)) # ERROR: will fail because x was not fed.
@@ -2250,7 +2253,8 @@ def required_space_to_batch_paddings(input_shape,
return result_paddings, result_crops
-@tf_export("space_to_batch")
+@tf_export("nn.space_to_batch", "space_to_batch")
+@deprecation.deprecated_endpoints("space_to_batch")
def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=redefined-builtin
result = space_to_batch_nd(
input,
@@ -2264,7 +2268,8 @@ def space_to_batch(input, paddings, block_size, name=None): # pylint: disable=r
space_to_batch.__doc__ = gen_array_ops.space_to_batch.__doc__
-@tf_export("space_to_depth")
+@tf_export("nn.space_to_depth", "space_to_depth")
+@deprecation.deprecated_endpoints("space_to_depth")
def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.space_to_depth(input, block_size, data_format, name=name)
@@ -2272,7 +2277,8 @@ def space_to_depth(input, block_size, name=None, data_format="NHWC"): # pylint:
space_to_depth.__doc__ = gen_array_ops.space_to_depth.__doc__
-@tf_export("depth_to_space")
+@tf_export("nn.depth_to_space", "depth_to_space")
+@deprecation.deprecated_endpoints("depth_to_space")
def depth_to_space(input, block_size, name=None, data_format="NHWC"): # pylint: disable=redefined-builtin
return gen_array_ops.depth_to_space(input, block_size, data_format, name=name)
@@ -2747,7 +2753,8 @@ def batch_gather(params, indices, name=None):
@tf_export("quantize_v2")
@deprecation.deprecated(
"2017-10-25",
- "`tf.quantize_v2` is deprecated, please use `tf.quantize` instead.")
+ "`tf.quantize_v2` is deprecated, please use `tf.quantization.quantize` "
+ "instead.") # pylint: disable=missing-docstring
def quantize_v2(input, # pylint: disable=redefined-builtin
min_range,
max_range,
@@ -2769,7 +2776,8 @@ quantize_v2.__doc__ = """Please use `tf.quantize` instead."""
# We want to expose tf.quantize instead of tf.quantize_v2; we can deprecate
# tf.quantize_v2 in next version of TensorFlow.
-@tf_export("quantize")
+@tf_export("quantization.quantize", "quantize")
+@deprecation.deprecated_endpoints("quantize")
def quantize(input, # pylint: disable=redefined-builtin
min_range,
max_range,
diff --git a/tensorflow/python/ops/candidate_sampling_ops.py b/tensorflow/python/ops/candidate_sampling_ops.py
index 9ea1ea9c92..98dde995c9 100644
--- a/tensorflow/python/ops/candidate_sampling_ops.py
+++ b/tensorflow/python/ops/candidate_sampling_ops.py
@@ -23,10 +23,12 @@ from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops # pylint: disable=unused-import
from tensorflow.python.ops import gen_candidate_sampling_ops
from tensorflow.python.ops import math_ops # pylint: disable=unused-import
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export('nn.uniform_candidate_sampler')
+@tf_export('random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler')
+@deprecation.deprecated_endpoints('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a uniform base distribution.
@@ -82,7 +84,9 @@ def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
seed2=seed2, name=name)
-@tf_export('nn.log_uniform_candidate_sampler')
+@tf_export('random.log_uniform_candidate_sampler',
+ 'nn.log_uniform_candidate_sampler')
+@deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
range_max, seed=None, name=None):
"""Samples a set of classes using a log-uniform (Zipfian) base distribution.
diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py
index c3cf6e61f2..d607f1d9fb 100644
--- a/tensorflow/python/ops/check_ops.py
+++ b/tensorflow/python/ops/check_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
NUMERIC_TYPES = frozenset(
@@ -91,7 +92,8 @@ def _shape_and_dtype_str(tensor):
return 'shape=%s dtype=%s' % (tensor.shape, tensor.dtype.name)
-@tf_export('assert_proper_iterable')
+@tf_export('debugging.assert_proper_iterable', 'assert_proper_iterable')
+@deprecation.deprecated_endpoints('assert_proper_iterable')
def assert_proper_iterable(values):
"""Static assert that values is a "proper" iterable.
@@ -119,7 +121,8 @@ def assert_proper_iterable(values):
'Expected argument "values" to be iterable. Found: %s' % type(values))
-@tf_export('assert_negative')
+@tf_export('debugging.assert_negative', 'assert_negative')
+@deprecation.deprecated_endpoints('assert_negative')
def assert_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < 0` holds element-wise.
@@ -160,7 +163,8 @@ def assert_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less(x, zero, data=data, summarize=summarize)
-@tf_export('assert_positive')
+@tf_export('debugging.assert_positive', 'assert_positive')
+@deprecation.deprecated_endpoints('assert_positive')
def assert_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > 0` holds element-wise.
@@ -200,7 +204,8 @@ def assert_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less(zero, x, data=data, summarize=summarize)
-@tf_export('assert_non_negative')
+@tf_export('debugging.assert_non_negative', 'assert_non_negative')
+@deprecation.deprecated_endpoints('assert_non_negative')
def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x >= 0` holds element-wise.
@@ -242,7 +247,8 @@ def assert_non_negative(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(zero, x, data=data, summarize=summarize)
-@tf_export('assert_non_positive')
+@tf_export('debugging.assert_non_positive', 'assert_non_positive')
+@deprecation.deprecated_endpoints('assert_non_positive')
def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= 0` holds element-wise.
@@ -284,7 +290,7 @@ def assert_non_positive(x, data=None, summarize=None, message=None, name=None):
return assert_less_equal(x, zero, data=data, summarize=summarize)
-@tf_export('assert_equal')
+@tf_export('debugging.assert_equal', 'assert_equal')
def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x == y` holds element-wise.
@@ -384,7 +390,8 @@ def assert_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_none_equal')
+@tf_export('debugging.assert_none_equal', 'assert_none_equal')
+@deprecation.deprecated_endpoints('assert_none_equal')
def assert_none_equal(
x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x != y` holds for all elements.
@@ -435,7 +442,8 @@ def assert_none_equal(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_near')
+@tf_export('debugging.assert_near', 'assert_near')
+@deprecation.deprecated_endpoints('assert_near')
def assert_near(
x, y, rtol=None, atol=None, data=None, summarize=None, message=None,
name=None):
@@ -513,7 +521,7 @@ def assert_near(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_less')
+@tf_export('debugging.assert_less', 'assert_less')
def assert_less(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x < y` holds element-wise.
@@ -561,7 +569,8 @@ def assert_less(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_less_equal')
+@tf_export('debugging.assert_less_equal', 'assert_less_equal')
+@deprecation.deprecated_endpoints('assert_less_equal')
def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x <= y` holds element-wise.
@@ -609,7 +618,7 @@ def assert_less_equal(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_greater')
+@tf_export('debugging.assert_greater', 'assert_greater')
def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
"""Assert the condition `x > y` holds element-wise.
@@ -657,7 +666,8 @@ def assert_greater(x, y, data=None, summarize=None, message=None, name=None):
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_greater_equal')
+@tf_export('debugging.assert_greater_equal', 'assert_greater_equal')
+@deprecation.deprecated_endpoints('assert_greater_equal')
def assert_greater_equal(x, y, data=None, summarize=None, message=None,
name=None):
"""Assert the condition `x >= y` holds element-wise.
@@ -755,7 +765,7 @@ def _assert_rank_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_rank')
+@tf_export('debugging.assert_rank', 'assert_rank')
def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank`.
@@ -817,7 +827,8 @@ def assert_rank(x, rank, data=None, summarize=None, message=None, name=None):
return assert_op
-@tf_export('assert_rank_at_least')
+@tf_export('debugging.assert_rank_at_least', 'assert_rank_at_least')
+@deprecation.deprecated_endpoints('assert_rank_at_least')
def assert_rank_at_least(
x, rank, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank equal to `rank` or higher.
@@ -948,7 +959,8 @@ def _assert_ranks_condition(
return control_flow_ops.Assert(condition, data, summarize=summarize)
-@tf_export('assert_rank_in')
+@tf_export('debugging.assert_rank_in', 'assert_rank_in')
+@deprecation.deprecated_endpoints('assert_rank_in')
def assert_rank_in(
x, ranks, data=None, summarize=None, message=None, name=None):
"""Assert `x` has rank in `ranks`.
@@ -1010,7 +1022,8 @@ def assert_rank_in(
return assert_op
-@tf_export('assert_integer')
+@tf_export('debugging.assert_integer', 'assert_integer')
+@deprecation.deprecated_endpoints('assert_integer')
def assert_integer(x, message=None, name=None):
"""Assert that `x` is of integer dtype.
@@ -1048,7 +1061,8 @@ def assert_integer(x, message=None, name=None):
return control_flow_ops.no_op('statically_determined_was_integer')
-@tf_export('assert_type')
+@tf_export('debugging.assert_type', 'assert_type')
+@deprecation.deprecated_endpoints('assert_type')
def assert_type(tensor, tf_type, message=None, name=None):
"""Statically asserts that the given `Tensor` is of the specified type.
@@ -1095,12 +1109,14 @@ def _get_diff_for_monotonic_comparison(x):
return control_flow_ops.cond(is_shorter_than_two, short_result, diff)
-@tf_export('is_numeric_tensor')
+@tf_export('debugging.is_numeric_tensor', 'is_numeric_tensor')
+@deprecation.deprecated_endpoints('is_numeric_tensor')
def is_numeric_tensor(tensor):
return isinstance(tensor, ops.Tensor) and tensor.dtype in NUMERIC_TYPES
-@tf_export('is_non_decreasing')
+@tf_export('debugging.is_non_decreasing', 'is_non_decreasing')
+@deprecation.deprecated_endpoints('is_non_decreasing')
def is_non_decreasing(x, name=None):
"""Returns `True` if `x` is non-decreasing.
@@ -1127,7 +1143,8 @@ def is_non_decreasing(x, name=None):
return math_ops.reduce_all(math_ops.less_equal(zero, diff))
-@tf_export('is_strictly_increasing')
+@tf_export('debugging.is_strictly_increasing', 'is_strictly_increasing')
+@deprecation.deprecated_endpoints('is_strictly_increasing')
def is_strictly_increasing(x, name=None):
"""Returns `True` if `x` is strictly increasing.
@@ -1202,7 +1219,8 @@ def _assert_same_base_type(items, expected_type=None):
return expected_type
-@tf_export('assert_same_float_dtype')
+@tf_export('debugging.assert_same_float_dtype', 'assert_same_float_dtype')
+@deprecation.deprecated_endpoints('assert_same_float_dtype')
def assert_same_float_dtype(tensors=None, dtype=None):
"""Validate and return float type based on `tensors` and `dtype`.
@@ -1231,7 +1249,8 @@ def assert_same_float_dtype(tensors=None, dtype=None):
return dtype
-@tf_export('assert_scalar')
+@tf_export('debugging.assert_scalar', 'assert_scalar')
+@deprecation.deprecated_endpoints('assert_scalar')
def assert_scalar(tensor, name=None):
with ops.name_scope(name, 'assert_scalar', [tensor]) as name_scope:
tensor = ops.convert_to_tensor(tensor, name=name_scope)
diff --git a/tensorflow/python/ops/clip_ops.py b/tensorflow/python/ops/clip_ops.py
index 29468431b3..45516068f4 100644
--- a/tensorflow/python/ops/clip_ops.py
+++ b/tensorflow/python/ops/clip_ops.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import numerics
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -76,8 +77,8 @@ def clip_by_value(t, clip_value_min, clip_value_max,
return t_max
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
- # return gen_math_ops.clip_by_value(
- # t, clip_value_min, clip_value_max, name=name)
+ # return gen_math_ops.clip_by_value(
+ # t, clip_value_min, clip_value_max, name=name)
# TODO(scottzhu): switch to use new implmentation in 2 weeks.
@@ -159,7 +160,8 @@ def clip_by_norm(t, clip_norm, axes=None, name=None):
return tclip
-@tf_export("global_norm")
+@tf_export("linalg.global_norm", "global_norm")
+@deprecation.deprecated_endpoints("global_norm")
def global_norm(t_list, name=None):
"""Computes the global norm of multiple tensors.
diff --git a/tensorflow/python/ops/confusion_matrix.py b/tensorflow/python/ops/confusion_matrix.py
index c09154129f..8259142456 100644
--- a/tensorflow/python/ops/confusion_matrix.py
+++ b/tensorflow/python/ops/confusion_matrix.py
@@ -26,6 +26,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -89,7 +90,8 @@ def remove_squeezable_dimensions(
return labels, predictions
-@tf_export('confusion_matrix')
+@tf_export('train.confusion_matrix', 'confusion_matrix')
+@deprecation.deprecated_endpoints('confusion_matrix')
def confusion_matrix(labels, predictions, num_classes=None, dtype=dtypes.int32,
name=None, weights=None):
"""Computes the confusion matrix from predictions and labels.
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index 87f8bd85a5..f779c3d273 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -60,8 +60,17 @@ from tensorflow.python.util import nest
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.tf_export import tf_export
+# The while_v2 module.
+_while_v2 = None
ENABLE_COND_V2 = os.getenv("TF_ENABLE_COND_V2", "0") != "0"
+# Note: Setting this to True is not sufficient to switch to the v2 while_loop.
+# Users must also import the while_v2 module to set the _while_v2 module
+# variable above. We do this to avoid a circular dependency:
+# control_flow_ops -> while_v2 -> gradients_impl -> control_flow_ops
+# A ValueError is raised in tf.while_loop if this is set to True and the
+# `_while_v2` module is not set.
+ENABLE_WHILE_V2 = os.getenv("TF_ENABLE_WHILE_V2", "0") != "0"
# We override the 'tuple' for a control flow op, so we keep python's
@@ -97,7 +106,7 @@ def _summarize_eager(tensor, summarize=None):
# Assert and Print are special symbols in python, so we must
# use an upper-case version of them.
-@tf_export("Assert")
+@tf_export("debugging.Assert", "Assert")
@tf_should_use.should_use_result
def Assert(condition, data, summarize=None, name=None):
"""Asserts that the given condition is true.
@@ -3211,6 +3220,14 @@ def while_loop(cond,
```
"""
+ if ENABLE_WHILE_V2 and not context.executing_eagerly():
+ if not _while_v2:
+ raise ValueError("The while_v2 module is not set. Did you forget to "
+ "import tensorflow.python.ops."
+ "while_v2?")
+ return _while_v2.while_loop(
+ cond, body, loop_vars, shape_invariants=shape_invariants, name=name)
+
with ops.name_scope(name, "while", loop_vars):
if not loop_vars:
raise ValueError("No loop variables provided")
diff --git a/tensorflow/python/ops/conv2d_benchmark.py b/tensorflow/python/ops/conv2d_benchmark.py
index 28111c2730..f40488afbe 100644
--- a/tensorflow/python/ops/conv2d_benchmark.py
+++ b/tensorflow/python/ops/conv2d_benchmark.py
@@ -63,9 +63,9 @@ def build_graph(device, dtype, data_format, input_shape, filter_shape, strides,
An array of tensors to run()
"""
with ops.device("/%s:0" % device):
- inp = variables.Variable(
+ inp = variables.VariableV1(
random_ops.truncated_normal(input_shape, dtype=dtype))
- filt = variables.Variable(
+ filt = variables.VariableV1(
random_ops.truncated_normal(filter_shape, dtype=dtype))
outputs = []
diff --git a/tensorflow/python/ops/data_flow_ops.py b/tensorflow/python/ops/data_flow_ops.py
index 69c0fcbbee..97b6f3bd9c 100644
--- a/tensorflow/python/ops/data_flow_ops.py
+++ b/tensorflow/python/ops/data_flow_ops.py
@@ -39,6 +39,7 @@ from tensorflow.python.ops import resource_variable_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -112,7 +113,8 @@ def _shape_common(s1, s2):
# pylint: disable=protected-access
-@tf_export("QueueBase")
+@tf_export("io.QueueBase", "QueueBase")
+@deprecation.deprecated_endpoints("QueueBase")
class QueueBase(object):
"""Base class for queue implementations.
@@ -604,7 +606,8 @@ def _shared_name(shared_name):
return shared_name
-@tf_export("RandomShuffleQueue")
+@tf_export("io.RandomShuffleQueue", "RandomShuffleQueue")
+@deprecation.deprecated_endpoints("RandomShuffleQueue")
class RandomShuffleQueue(QueueBase):
"""A queue implementation that dequeues elements in a random order.
@@ -746,7 +749,8 @@ class FIFOQueue(QueueBase):
super(FIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
-@tf_export("PaddingFIFOQueue")
+@tf_export("io.PaddingFIFOQueue", "PaddingFIFOQueue")
+@deprecation.deprecated_endpoints("PaddingFIFOQueue")
class PaddingFIFOQueue(QueueBase):
"""A FIFOQueue that supports batching variable-sized tensors by padding.
@@ -820,7 +824,8 @@ class PaddingFIFOQueue(QueueBase):
super(PaddingFIFOQueue, self).__init__(dtypes, shapes, names, queue_ref)
-@tf_export("PriorityQueue")
+@tf_export("io.PriorityQueue", "PriorityQueue")
+@deprecation.deprecated_endpoints("PriorityQueue")
class PriorityQueue(QueueBase):
"""A queue implementation that dequeues elements in prioritized order.
@@ -1300,7 +1305,9 @@ class ConditionalAccumulator(ConditionalAccumulatorBase):
return out
-@tf_export("SparseConditionalAccumulator")
+@tf_export("sparse.SparseConditionalAccumulator",
+ "SparseConditionalAccumulator")
+@deprecation.deprecated_endpoints("SparseConditionalAccumulator")
class SparseConditionalAccumulator(ConditionalAccumulatorBase):
"""A conditional accumulator for aggregating sparse gradients.
diff --git a/tensorflow/python/ops/distributions/BUILD b/tensorflow/python/ops/distributions/BUILD
index e7ad028376..59ba9aee59 100644
--- a/tensorflow/python/ops/distributions/BUILD
+++ b/tensorflow/python/ops/distributions/BUILD
@@ -12,6 +12,13 @@ py_library(
["*.py"],
exclude = ["util.py"],
),
+ deprecation = ("TensorFlow Distributions has migrated to " +
+ "TensorFlow Probability " +
+ "(https://github.com/tensorflow/probability). " +
+ "Deprecated copies remaining in tf.distributions " +
+ "will not receive new features, and will be removed by " +
+ "early 2019. You should update all usage of " +
+ "`tf.distributions` to `tfp.distributions`."),
srcs_version = "PY2AND3",
deps = [
":util",
diff --git a/tensorflow/python/ops/distributions/bernoulli.py b/tensorflow/python/ops/distributions/bernoulli.py
index 84d9d40a35..baecc321d3 100644
--- a/tensorflow/python/ops/distributions/bernoulli.py
+++ b/tensorflow/python/ops/distributions/bernoulli.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -39,6 +40,14 @@ class Bernoulli(distribution.Distribution):
`1` outcome (vs a `0` outcome).
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
logits=None,
probs=None,
diff --git a/tensorflow/python/ops/distributions/beta.py b/tensorflow/python/ops/distributions/beta.py
index 2ba1ea6744..51c4f6eb3d 100644
--- a/tensorflow/python/ops/distributions/beta.py
+++ b/tensorflow/python/ops/distributions/beta.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -150,6 +151,14 @@ class Beta(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration1=None,
concentration0=None,
@@ -267,8 +276,8 @@ class Beta(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return ((self.concentration1 - 1.) * math_ops.log(x)
- + (self.concentration0 - 1.) * math_ops.log1p(-x))
+ return (math_ops.xlogy(self.concentration1 - 1., x) +
+ (self.concentration0 - 1.) * math_ops.log1p(-x))
def _log_normalization(self):
return (math_ops.lgamma(self.concentration1)
@@ -341,6 +350,11 @@ class Beta(distribution.Distribution):
class BetaWithSoftplusConcentration(Beta):
"""Beta with softplus transform of `concentration1` and `concentration0`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Beta(tf.nn.softplus(concentration1), "
+ "tf.nn.softplus(concentration2))` instead.",
+ warn_once=True)
def __init__(self,
concentration1,
concentration0,
diff --git a/tensorflow/python/ops/distributions/categorical.py b/tensorflow/python/ops/distributions/categorical.py
index fbbacf2521..26a3da2fb6 100644
--- a/tensorflow/python/ops/distributions/categorical.py
+++ b/tensorflow/python/ops/distributions/categorical.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -149,6 +150,14 @@ class Categorical(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(
self,
logits=None,
diff --git a/tensorflow/python/ops/distributions/dirichlet.py b/tensorflow/python/ops/distributions/dirichlet.py
index 415249a958..675c30b383 100644
--- a/tensorflow/python/ops/distributions/dirichlet.py
+++ b/tensorflow/python/ops/distributions/dirichlet.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -156,6 +157,14 @@ class Dirichlet(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
validate_args=False,
@@ -236,7 +245,7 @@ class Dirichlet(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return math_ops.reduce_sum((self.concentration - 1.) * math_ops.log(x), -1)
+ return math_ops.reduce_sum(math_ops.xlogy(self.concentration - 1., x), -1)
def _log_normalization(self):
return special_math_ops.lbeta(self.concentration)
diff --git a/tensorflow/python/ops/distributions/dirichlet_multinomial.py b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
index 5350c82847..2e3151a5ab 100644
--- a/tensorflow/python/ops/distributions/dirichlet_multinomial.py
+++ b/tensorflow/python/ops/distributions/dirichlet_multinomial.py
@@ -28,6 +28,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -163,6 +164,14 @@ class DirichletMultinomial(distribution.Distribution):
# TODO(b/27419586) Change docstring for dtype of concentration once int
# allowed.
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
concentration,
diff --git a/tensorflow/python/ops/distributions/distribution.py b/tensorflow/python/ops/distributions/distribution.py
index 12fd039392..4741370cd8 100644
--- a/tensorflow/python/ops/distributions/distribution.py
+++ b/tensorflow/python/ops/distributions/distribution.py
@@ -34,6 +34,7 @@ from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -229,6 +230,14 @@ class ReparameterizationType(object):
gradients / surrogate loss instead.
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, rep_type):
self._rep_type = rep_type
@@ -405,6 +414,14 @@ class Distribution(_BaseDistribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
dtype,
reparameterization_type,
diff --git a/tensorflow/python/ops/distributions/exponential.py b/tensorflow/python/ops/distributions/exponential.py
index 4325a14449..6a52af8c33 100644
--- a/tensorflow/python/ops/distributions/exponential.py
+++ b/tensorflow/python/ops/distributions/exponential.py
@@ -27,6 +27,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import gamma
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -70,6 +71,14 @@ class Exponential(gamma.Gamma):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
rate,
validate_args=False,
@@ -114,6 +123,9 @@ class Exponential(gamma.Gamma):
def rate(self):
return self._rate
+ def _log_survival_function(self, value):
+ return self._log_prob(value) - math_ops.log(self._rate)
+
def _sample_n(self, n, seed=None):
shape = array_ops.concat([[n], array_ops.shape(self._rate)], 0)
# Uniform variates must be sampled from the open-interval `(0, 1)` rather
@@ -135,6 +147,10 @@ class Exponential(gamma.Gamma):
class ExponentialWithSoftplusRate(Exponential):
"""Exponential with softplus transform on `rate`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Exponential(tf.nn.softplus(rate)).",
+ warn_once=True)
def __init__(self,
rate,
validate_args=False,
diff --git a/tensorflow/python/ops/distributions/gamma.py b/tensorflow/python/ops/distributions/gamma.py
index 3293cda874..4a2db208d4 100644
--- a/tensorflow/python/ops/distributions/gamma.py
+++ b/tensorflow/python/ops/distributions/gamma.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -121,6 +122,14 @@ class Gamma(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
@@ -225,7 +234,7 @@ class Gamma(distribution.Distribution):
def _log_unnormalized_prob(self, x):
x = self._maybe_assert_valid_sample(x)
- return (self.concentration - 1.) * math_ops.log(x) - self.rate * x
+ return math_ops.xlogy(self.concentration - 1., x) - self.rate * x
def _log_normalization(self):
return (math_ops.lgamma(self.concentration)
@@ -279,6 +288,11 @@ class Gamma(distribution.Distribution):
class GammaWithSoftplusConcentrationRate(Gamma):
"""`Gamma` with softplus of `concentration` and `rate`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Gamma(tf.nn.softplus(concentration), "
+ "tf.nn.softplus(rate))` instead.",
+ warn_once=True)
def __init__(self,
concentration,
rate,
diff --git a/tensorflow/python/ops/distributions/identity_bijector.py b/tensorflow/python/ops/distributions/identity_bijector.py
index 8628e68f96..eded96f5bc 100644
--- a/tensorflow/python/ops/distributions/identity_bijector.py
+++ b/tensorflow/python/ops/distributions/identity_bijector.py
@@ -20,6 +20,7 @@ from __future__ import print_function
from tensorflow.python.framework import constant_op
from tensorflow.python.ops.distributions import bijector
+from tensorflow.python.util import deprecation
__all__ = [
@@ -43,6 +44,14 @@ class Identity(bijector.Bijector):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, validate_args=False, name="identity"):
super(Identity, self).__init__(
forward_min_event_ndims=0,
diff --git a/tensorflow/python/ops/distributions/kullback_leibler.py b/tensorflow/python/ops/distributions/kullback_leibler.py
index fdeb97bf64..12743fa23d 100644
--- a/tensorflow/python/ops/distributions/kullback_leibler.py
+++ b/tensorflow/python/ops/distributions/kullback_leibler.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -51,6 +52,14 @@ def _registered_kl(type_a, type_b):
return kl_fn
+@deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
@tf_export("distributions.kl_divergence")
def kl_divergence(distribution_a, distribution_b,
allow_nan_stats=True, name=None):
@@ -112,6 +121,14 @@ def kl_divergence(distribution_a, distribution_b,
return array_ops.identity(kl_t, name="checked_kl")
+@deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def cross_entropy(ref, other,
allow_nan_stats=True, name=None):
"""Computes the (Shannon) cross entropy.
@@ -155,6 +172,14 @@ class RegisterKL(object):
# Return KL(norm_a || norm_b)
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self, dist_cls_a, dist_cls_b):
"""Initialize the KL registrar.
diff --git a/tensorflow/python/ops/distributions/laplace.py b/tensorflow/python/ops/distributions/laplace.py
index be17cf2527..4f6a8f587d 100644
--- a/tensorflow/python/ops/distributions/laplace.py
+++ b/tensorflow/python/ops/distributions/laplace.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -71,6 +72,14 @@ class Laplace(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
@@ -211,6 +220,11 @@ class Laplace(distribution.Distribution):
class LaplaceWithSoftplusScale(Laplace):
"""Laplace with softplus applied to `scale`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Laplace(loc, tf.nn.softplus(scale)) "
+ "instead.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/python/ops/distributions/multinomial.py b/tensorflow/python/ops/distributions/multinomial.py
index d0943e8eee..8397353cd5 100644
--- a/tensorflow/python/ops/distributions/multinomial.py
+++ b/tensorflow/python/ops/distributions/multinomial.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -148,6 +149,14 @@ class Multinomial(distribution.Distribution):
```
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
total_count,
logits=None,
diff --git a/tensorflow/python/ops/distributions/normal.py b/tensorflow/python/ops/distributions/normal.py
index 2feaf806c0..9f511709b9 100644
--- a/tensorflow/python/ops/distributions/normal.py
+++ b/tensorflow/python/ops/distributions/normal.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import kullback_leibler
from tensorflow.python.ops.distributions import special_math
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -106,6 +107,14 @@ class Normal(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
loc,
scale,
@@ -240,6 +249,11 @@ class Normal(distribution.Distribution):
class NormalWithSoftplusScale(Normal):
"""Normal with softplus applied to `scale`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.Normal(loc, tf.nn.softplus(scale)) "
+ "instead.",
+ warn_once=True)
def __init__(self,
loc,
scale,
diff --git a/tensorflow/python/ops/distributions/special_math.py b/tensorflow/python/ops/distributions/special_math.py
index 31b7a36fd3..ccc667cae3 100644
--- a/tensorflow/python/ops/distributions/special_math.py
+++ b/tensorflow/python/ops/distributions/special_math.py
@@ -12,6 +12,62 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+
+# Functions "ndtr" and "ndtri" are derived from calculations made in:
+# https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
+# In the following email exchange, the author gives his consent to redistribute
+# derived works under an Apache 2.0 license.
+#
+# From: Stephen Moshier <steve@moshier.net>
+# Date: Sat, Jun 9, 2018 at 2:36 PM
+# Subject: Re: Licensing cephes under Apache (BSD-like) license.
+# To: rif <rif@google.com>
+#
+#
+#
+# Hello Rif,
+#
+# Yes, Google may distribute Cephes files under the Apache 2 license.
+#
+# If clarification is needed, I do not favor BSD over other free licenses.
+# I would agree that Apache 2 seems to cover the concern you mentioned
+# about sublicensees.
+#
+# Best wishes for good luck with your projects!
+# Steve Moshier
+#
+#
+#
+# On Thu, 31 May 2018, rif wrote:
+#
+# > Hello Steve.
+# > My name is Rif. I work on machine learning software at Google.
+# >
+# > Your cephes software continues to be incredibly useful and widely used. I
+# > was wondering whether it would be permissible for us to use the Cephes code
+# > under the Apache 2.0 license, which is extremely similar in permissions to
+# > the BSD license (Wikipedia comparisons). This would be quite helpful to us
+# > in terms of avoiding multiple licenses on software.
+# >
+# > I'm sorry to bother you with this (I can imagine you're sick of hearing
+# > about this by now), but I want to be absolutely clear we're on the level and
+# > not misusing your important software. In former conversation with Eugene
+# > Brevdo (ebrevdo@google.com), you wrote "If your licensing is similar to BSD,
+# > the formal way that has been handled is simply to add a statement to the
+# > effect that you are incorporating the Cephes software by permission of the
+# > author." I wanted to confirm that (a) we could use the Apache license, (b)
+# > that we don't need to (and probably you don't want to) keep getting
+# > contacted about individual uses, because your intent is generally to allow
+# > this software to be reused under "BSD-like" license, and (c) you're OK
+# > letting incorporators decide whether a license is sufficiently BSD-like?
+# >
+# > Best,
+# >
+# > rif
+# >
+# >
+# >
+
"""Special Math Ops."""
from __future__ import absolute_import
@@ -135,7 +191,7 @@ def _ndtri(p):
# Constants used in piece-wise rational approximations. Taken from the cephes
# library:
- # https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
+ # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
p0 = list(reversed([-5.99633501014107895267E1,
9.80010754185999661536E1,
-5.66762857469070293439E1,
@@ -305,7 +361,8 @@ def log_ndtr(x, series_order=3, name="log_ndtr"):
else:
raise TypeError("x.dtype=%s is not supported." % x.dtype)
- # The basic idea here was ported from py/scipy/special/cephes/ndtr.c.
+ # The basic idea here was ported from:
+ # https://root.cern.ch/doc/v608/SpecFuncCephesInv_8cxx_source.html
# We copy the main idea, with a few changes
# * For x >> 1, and X ~ Normal(0, 1),
# Log[P[X < x]] = Log[1 - P[X < -x]] approx -P[X < -x],
diff --git a/tensorflow/python/ops/distributions/student_t.py b/tensorflow/python/ops/distributions/student_t.py
index e8d214bbe0..b69e61925c 100644
--- a/tensorflow/python/ops/distributions/student_t.py
+++ b/tensorflow/python/ops/distributions/student_t.py
@@ -33,6 +33,7 @@ from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.distributions import distribution
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -140,6 +141,14 @@ class StudentT(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
df,
loc,
@@ -361,6 +370,11 @@ class StudentT(distribution.Distribution):
class StudentTWithAbsDfSoftplusScale(StudentT):
"""StudentT with `df = floor(abs(df))` and `scale = softplus(scale)`."""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "Use `tfd.StudentT(tf.floor(tf.abs(df)), loc, "
+ "tf.nn.softplus(scale)) instead.",
+ warn_once=True)
def __init__(self,
df,
loc,
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index e80bf9ee42..1becfc1877 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -30,6 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops.distributions import distribution as distribution_lib
from tensorflow.python.ops.distributions import identity_bijector
from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.util import deprecation
__all__ = [
"TransformedDistribution",
@@ -227,6 +228,14 @@ class TransformedDistribution(distribution_lib.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
distribution,
bijector=None,
diff --git a/tensorflow/python/ops/distributions/uniform.py b/tensorflow/python/ops/distributions/uniform.py
index e66c4a37e7..b6b24187cc 100644
--- a/tensorflow/python/ops/distributions/uniform.py
+++ b/tensorflow/python/ops/distributions/uniform.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.distributions import distribution
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -76,6 +77,14 @@ class Uniform(distribution.Distribution):
"""
+ @deprecation.deprecated(
+ "2019-01-01",
+ "The TensorFlow Distributions library has moved to "
+ "TensorFlow Probability "
+ "(https://github.com/tensorflow/probability). You "
+ "should update all references to use `tfp.distributions` "
+ "instead of `tf.distributions`.",
+ warn_once=True)
def __init__(self,
low=0.,
high=1.,
diff --git a/tensorflow/python/ops/init_ops.py b/tensorflow/python/ops/init_ops.py
index fff3d9b930..65bb77b474 100644
--- a/tensorflow/python/ops/init_ops.py
+++ b/tensorflow/python/ops/init_ops.py
@@ -43,6 +43,7 @@ from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import linalg_ops_impl
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_arg_values
from tensorflow.python.util.tf_export import tf_export
@@ -341,6 +342,7 @@ class TruncatedNormal(Initializer):
@tf_export("initializers.uniform_unit_scaling",
"uniform_unit_scaling_initializer")
+@deprecation.deprecated_endpoints("uniform_unit_scaling_initializer")
class UniformUnitScaling(Initializer):
"""Initializer that generates tensors without scaling variance.
@@ -401,6 +403,7 @@ class UniformUnitScaling(Initializer):
@tf_export("keras.initializers.VarianceScaling",
"initializers.variance_scaling", "variance_scaling_initializer")
+@deprecation.deprecated_endpoints("variance_scaling_initializer")
class VarianceScaling(Initializer):
"""Initializer capable of adapting its scale to the shape of weights tensors.
@@ -494,6 +497,7 @@ class VarianceScaling(Initializer):
@tf_export("keras.initializers.Orthogonal", "initializers.orthogonal",
"orthogonal_initializer", "keras.initializers.orthogonal")
+@deprecation.deprecated_endpoints("orthogonal_initializer")
class Orthogonal(Initializer):
"""Initializer that generates an orthogonal matrix.
@@ -1149,6 +1153,7 @@ class GlorotUniform(VarianceScaling):
@tf_export("glorot_normal_initializer", "keras.initializers.glorot_normal",
"initializers.glorot_normal")
+@deprecation.deprecated_endpoints("glorot_normal_initializer")
class GlorotNormal(VarianceScaling):
"""The Glorot normal initializer, also called Xavier normal initializer.
diff --git a/tensorflow/python/ops/linalg_ops.py b/tensorflow/python/ops/linalg_ops.py
index f4a93560be..bf4354fa73 100644
--- a/tensorflow/python/ops/linalg_ops.py
+++ b/tensorflow/python/ops/linalg_ops.py
@@ -80,6 +80,7 @@ def _RegularizedGramianCholesky(matrix, l2_regularizer, first_kind):
@tf_export('cholesky_solve', 'linalg.cholesky_solve')
+@deprecation.deprecated_endpoints('cholesky_solve')
def cholesky_solve(chol, rhs, name=None):
"""Solves systems of linear eqns `A X = RHS`, given Cholesky factorizations.
@@ -167,7 +168,8 @@ def eye(num_rows,
name=name)
-@tf_export('matrix_solve_ls', 'linalg.lstsq')
+@tf_export('linalg.lstsq', 'matrix_solve_ls')
+@deprecation.deprecated_endpoints('matrix_solve_ls')
def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
r"""Solves one or more linear least-squares problems.
@@ -220,7 +222,7 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
squares sense.
Raises:
- NotImplementedError: matrix_solve_ls is currently disabled for complex128
+ NotImplementedError: linalg.lstsq is currently disabled for complex128
and l2_regularizer != 0 due to poor accuracy.
"""
@@ -303,7 +305,8 @@ def matrix_solve_ls(matrix, rhs, l2_regularizer=0.0, fast=True, name=None):
matrix, rhs, l2_regularizer, fast=fast, name=name)
-@tf_export('self_adjoint_eig', 'linalg.eigh')
+@tf_export('linalg.eigh', 'self_adjoint_eig')
+@deprecation.deprecated_endpoints('self_adjoint_eig')
def self_adjoint_eig(tensor, name=None):
"""Computes the eigen decomposition of a batch of self-adjoint matrices.
@@ -325,12 +328,13 @@ def self_adjoint_eig(tensor, name=None):
return e, v
-@tf_export('self_adjoint_eigvals', 'linalg.eigvalsh')
+@tf_export('linalg.eigvalsh', 'self_adjoint_eigvals')
+@deprecation.deprecated_endpoints('self_adjoint_eigvals')
def self_adjoint_eigvals(tensor, name=None):
"""Computes the eigenvalues of one or more self-adjoint matrices.
Note: If your program backpropagates through this function, you should replace
- it with a call to tf.self_adjoint_eig (possibly ignoring the second output) to
+ it with a call to tf.linalg.eigvalsh (possibly ignoring the second output) to
avoid computing the eigen decomposition twice. This is because the
eigenvectors are used to compute the gradient w.r.t. the eigenvalues. See
_SelfAdjointEigV2Grad in linalg_grad.py.
@@ -348,6 +352,7 @@ def self_adjoint_eigvals(tensor, name=None):
@tf_export('svd', 'linalg.svd')
+@deprecation.deprecated_endpoints('svd')
def svd(tensor, full_matrices=False, compute_uv=True, name=None):
r"""Computes the singular value decompositions of one or more matrices.
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 5443699ddd..cffaa983d4 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -59,7 +59,7 @@ def initialize_all_tables(name="init_all_tables"):
return tables_initializer(name)
-@tf_export("tables_initializer")
+@tf_export("initializers.tables_initializer", "tables_initializer")
def tables_initializer(name="init_all_tables"):
"""Returns an Op that initializes all tables of the default graph.
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
index 6633565a64..d9d0728287 100644
--- a/tensorflow/python/ops/manip_ops.py
+++ b/tensorflow/python/ops/manip_ops.py
@@ -19,11 +19,13 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: disable=protected-access
-@tf_export('manip.roll')
+@tf_export('roll', 'manip.roll')
+@deprecation.deprecated_endpoints('manip.roll')
def roll(input, shift, axis): # pylint: disable=redefined-builtin
return _gen_manip_ops.roll(input, shift, axis)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index f57abf6704..83b8b5a3a4 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -70,7 +70,7 @@ def _set_doc(doc):
# pylint: disable=redefined-builtin
-@tf_export("argmax")
+@tf_export("math.argmax", "argmax")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -88,7 +88,7 @@ def argmax(input,
return gen_math_ops.arg_max(input, axis, name=name, output_type=output_type)
-@tf_export("argmin")
+@tf_export("math.argmin", "argmin")
@deprecation.deprecated_args(None, "Use the `axis` argument instead",
"dimension")
@_set_doc(
@@ -111,7 +111,7 @@ def argmin(input,
# pylint: disable=anomalous-backslash-in-string,protected-access
# pylint: disable=g-docstring-has-escape
-@tf_export("abs")
+@tf_export("math.abs", "abs")
def abs(x, name=None): # pylint: disable=redefined-builtin
r"""Computes the absolute value of a tensor.
@@ -186,7 +186,7 @@ class DivideDelegateWithName(object):
return _div_python2(self.x, y, self.name)
-@tf_export("divide")
+@tf_export("math.divide", "divide")
def divide(x, y, name=None):
"""Computes Python style division of `x` by `y`."""
@@ -198,7 +198,7 @@ def divide(x, y, name=None):
return x / y
-@tf_export("multiply")
+@tf_export("math.multiply", "multiply")
def multiply(x, y, name=None):
return gen_math_ops.mul(x, y, name)
@@ -218,7 +218,7 @@ _mul.__doc__ = (
gen_math_ops.mul.__doc__ + ("" if _mul.__doc__ is None else _mul.__doc__))
-@tf_export("subtract")
+@tf_export("math.subtract", "subtract")
def subtract(x, y, name=None):
return gen_math_ops.sub(x, y, name)
@@ -239,7 +239,7 @@ _sub.__doc__ = (
# pylint: disable=g-docstring-has-escape
-@tf_export("negative")
+@tf_export("math.negative", "negative")
def negative(x, name=None):
"""Computes numerical negative value element-wise.
@@ -288,7 +288,7 @@ def _neg(x, name=None):
# pylint: enable=g-docstring-has-escape
-@tf_export("sign")
+@tf_export("math.sign", "sign")
def sign(x, name=None):
"""Returns an element-wise indication of the sign of a number.
@@ -319,7 +319,7 @@ def sign(x, name=None):
return gen_math_ops.sign(x, name=name)
-@tf_export("square")
+@tf_export("math.square", "square")
def square(x, name=None):
r"""Computes square of x element-wise.
@@ -342,7 +342,7 @@ def square(x, name=None):
return gen_math_ops.square(x, name=name)
-@tf_export("sqrt")
+@tf_export("math.sqrt", "sqrt")
def sqrt(x, name=None):
r"""Computes square root of x element-wise.
@@ -365,7 +365,8 @@ def sqrt(x, name=None):
return gen_math_ops.sqrt(x, name=name)
-@tf_export("erf")
+@tf_export("math.erf", "erf")
+@deprecation.deprecated_endpoints("erf")
def erf(x, name=None):
"""Computes the Gauss error function of `x` element-wise.
@@ -386,7 +387,7 @@ def erf(x, name=None):
return gen_math_ops.erf(x, name=name)
-@tf_export("scalar_mul")
+@tf_export("math.scalar_mul", "scalar_mul")
def scalar_mul(scalar, x):
"""Multiplies a scalar times a `Tensor` or `IndexedSlices` object.
@@ -416,7 +417,7 @@ def scalar_mul(scalar, x):
raise ValueError("Only scalar multiply works, got shape %s" % shape)
-@tf_export("pow")
+@tf_export("math.pow", "pow")
def pow(x, y, name=None): # pylint: disable=redefined-builtin
r"""Computes the power of one value to another.
@@ -444,7 +445,7 @@ def pow(x, y, name=None): # pylint: disable=redefined-builtin
# pylint: disable=redefined-builtin,redefined-outer-name
-@tf_export("complex")
+@tf_export("dtypes.complex", "complex")
def complex(real, imag, name=None):
r"""Converts two real numbers to a complex number.
@@ -486,7 +487,8 @@ def complex(real, imag, name=None):
return gen_math_ops._complex(real, imag, Tout=Tout, name=name)
-@tf_export("real")
+@tf_export("math.real", "real")
+@deprecation.deprecated_endpoints("real")
def real(input, name=None):
r"""Returns the real part of a complex (or real) tensor.
@@ -517,7 +519,8 @@ def real(input, name=None):
return input
-@tf_export("imag")
+@tf_export("math.imag", "imag")
+@deprecation.deprecated_endpoints("imag")
def imag(input, name=None):
r"""Returns the imaginary part of a complex (or real) tensor.
@@ -547,7 +550,8 @@ def imag(input, name=None):
return array_ops.zeros_like(input)
-@tf_export("angle")
+@tf_export("math.angle", "angle")
+@deprecation.deprecated_endpoints("angle")
def angle(input, name=None):
r"""Returns the element-wise argument of a complex (or real) tensor.
@@ -586,7 +590,7 @@ def angle(input, name=None):
# pylint: enable=redefined-outer-name,redefined-builtin
-@tf_export("round")
+@tf_export("math.round", "round")
def round(x, name=None): # pylint: disable=redefined-builtin
"""Rounds the values of a tensor to the nearest integer, element-wise.
@@ -613,7 +617,7 @@ def round(x, name=None): # pylint: disable=redefined-builtin
return gen_math_ops.round(x, name=name)
-@tf_export("cast")
+@tf_export("dtypes.cast", "cast")
def cast(x, dtype, name=None):
"""Casts a tensor to a new type.
@@ -676,7 +680,7 @@ def cast(x, dtype, name=None):
return x
-@tf_export("saturate_cast")
+@tf_export("dtypes.saturate_cast", "saturate_cast")
def saturate_cast(value, dtype, name=None):
"""Performs a safe saturating cast of `value` to `dtype`.
@@ -995,7 +999,7 @@ def _div_python2(x, y, name=None):
return gen_math_ops.floor_div(x, y, name=name)
-@tf_export("truediv")
+@tf_export("math.truediv", "truediv")
def truediv(x, y, name=None):
"""Divides x / y elementwise (using Python 3 division operator semantics).
@@ -1006,7 +1010,7 @@ def truediv(x, y, name=None):
arguments are cast to floating types first. This op is generated by normal
`x / y` division in Python 3 and in Python 2.7 with
`from __future__ import division`. If you want integer division that rounds
- down, use `x // y` or `tf.floordiv`.
+ down, use `x // y` or `tf.math.floordiv`.
`x` and `y` must have the same numeric type. If the inputs are floating
point, the output will have the same type. If the inputs are integral, the
@@ -1078,7 +1082,8 @@ mod = gen_math_ops.floor_mod
# TODO(aselle): Deprecate this once all internal functionality uses
# tf.truncatediv
-@tf_export("floordiv")
+@tf_export("math.floordiv", "floordiv")
+@deprecation.deprecated_endpoints("floordiv")
def floordiv(x, y, name=None):
"""Divides `x / y` elementwise, rounding toward the most negative integer.
@@ -1151,7 +1156,8 @@ _OverrideBinaryOperatorHelper(gen_math_ops.floor_mod, "mod")
_OverrideBinaryOperatorHelper(pow, "pow")
-@tf_export("logical_xor")
+@tf_export("math.logical_xor", "logical_xor")
+@deprecation.deprecated_endpoints("logical_xor")
def logical_xor(x, y, name="LogicalXor"):
"""x ^ y = (x | y) & ~(x & y)."""
# TODO(alemi) Make this a cwise op if people end up relying on it.
@@ -1277,7 +1283,7 @@ def _may_reduce_to_scalar(keepdims, axis, reduction_indices, output):
return output
-@tf_export("reduce_sum")
+@tf_export("math.reduce_sum", "reduce_sum")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_sum(input_tensor,
@@ -1339,7 +1345,7 @@ def reduce_sum(input_tensor,
name=name))
-@tf_export("count_nonzero")
+@tf_export("math.count_nonzero", "count_nonzero")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def count_nonzero(input_tensor,
@@ -1417,7 +1423,7 @@ def count_nonzero(input_tensor,
dtype=dtype)
-@tf_export("reduce_mean")
+@tf_export("math.reduce_mean", "reduce_mean")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_mean(input_tensor,
@@ -1489,7 +1495,7 @@ def reduce_mean(input_tensor,
name=name))
-@tf_export("reduce_prod")
+@tf_export("math.reduce_prod", "reduce_prod")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_prod(input_tensor,
@@ -1539,7 +1545,7 @@ def reduce_prod(input_tensor,
name=name))
-@tf_export("reduce_min")
+@tf_export("math.reduce_min", "reduce_min")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_min(input_tensor,
@@ -1588,7 +1594,7 @@ def reduce_min(input_tensor,
name=name))
-@tf_export("reduce_max")
+@tf_export("math.reduce_max", "reduce_max")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_max(input_tensor,
@@ -1637,7 +1643,7 @@ def reduce_max(input_tensor,
name=name))
-@tf_export("reduce_all")
+@tf_export("math.reduce_all", "reduce_all")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_all(input_tensor,
@@ -1695,7 +1701,7 @@ def reduce_all(input_tensor,
name=name))
-@tf_export("reduce_any")
+@tf_export("math.reduce_any", "reduce_any")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_any(input_tensor,
@@ -1753,7 +1759,7 @@ def reduce_any(input_tensor,
name=name))
-@tf_export("reduce_logsumexp")
+@tf_export("math.reduce_logsumexp", "reduce_logsumexp")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def reduce_logsumexp(input_tensor,
@@ -1827,7 +1833,8 @@ def reduce_logsumexp(input_tensor,
return _may_reduce_to_scalar(keepdims, axis, reduction_indices, result)
-@tf_export("trace", "linalg.trace")
+@tf_export("linalg.trace", "trace")
+@deprecation.deprecated_endpoints("trace")
def trace(x, name=None):
"""Compute the trace of a tensor `x`.
@@ -1841,12 +1848,12 @@ def trace(x, name=None):
```python
x = tf.constant([[1, 2], [3, 4]])
- tf.trace(x) # 5
+ tf.linalg.trace(x) # 5
x = tf.constant([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
- tf.trace(x) # 15
+ tf.linalg.trace(x) # 15
x = tf.constant([[[1, 2, 3],
[4, 5, 6],
@@ -1854,7 +1861,7 @@ def trace(x, name=None):
[[-1, -2, -3],
[-4, -5, -6],
[-7, -8, -9]]])
- tf.trace(x) # [15, -15]
+ tf.linalg.trace(x) # [15, -15]
```
Args:
@@ -1869,7 +1876,7 @@ def trace(x, name=None):
return reduce_sum(array_ops.matrix_diag_part(x), [-1], name=name)
-@tf_export("matmul")
+@tf_export("linalg.matmul", "matmul")
def matmul(a,
b,
transpose_a=False,
@@ -2131,7 +2138,7 @@ def _as_indexed_slices_list(inputs, optimize=True):
return casted_outputs
-@tf_export("add_n")
+@tf_export("math.add_n", "add_n")
def add_n(inputs, name=None):
"""Adds all input tensors element-wise.
@@ -2166,14 +2173,15 @@ def add_n(inputs, name=None):
return gen_math_ops.add_n(inputs, name=name)
-@tf_export("accumulate_n")
+@tf_export("math.accumulate_n", "accumulate_n")
+@deprecation.deprecated_endpoints("accumulate_n")
def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
"""Returns the element-wise sum of a list of tensors.
Optionally, pass `shape` and `tensor_dtype` for shape and type checking,
otherwise, these are inferred.
- `tf.accumulate_n` performs the same operation as `tf.add_n`, but does not
+ `tf.math.accumulate_n` performs the same operation as `tf.add_n`, but does not
wait for all of its inputs to be ready before beginning to sum. This can
save memory if inputs are ready at different times, since minimum temporary
storage is proportional to the output size rather than the inputs size.
@@ -2185,10 +2193,10 @@ def accumulate_n(inputs, shape=None, tensor_dtype=None, name=None):
```python
a = tf.constant([[1, 2], [3, 4]])
b = tf.constant([[5, 0], [0, 6]])
- tf.accumulate_n([a, b, a]) # [[7, 4], [6, 14]]
+ tf.math.accumulate_n([a, b, a]) # [[7, 4], [6, 14]]
# Explicitly pass shape and type
- tf.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
+ tf.math.accumulate_n([a, b, a], shape=[2, 2], tensor_dtype=tf.int32)
# [[7, 4],
# [6, 14]]
```
@@ -2252,7 +2260,7 @@ def _accumulate_n_grad(op, grad):
return [grad] * len(op.inputs)
-@tf_export("nn.sigmoid", "sigmoid")
+@tf_export("math.sigmoid", "nn.sigmoid", "sigmoid")
def sigmoid(x, name=None):
"""Computes sigmoid of `x` element-wise.
@@ -2275,7 +2283,8 @@ def sigmoid(x, name=None):
return gen_math_ops.sigmoid(x, name=name)
-@tf_export("log_sigmoid")
+@tf_export("math.log_sigmoid", "log_sigmoid")
+@deprecation.deprecated_endpoints("log_sigmoid")
def log_sigmoid(x, name=None):
"""Computes log sigmoid of `x` element-wise.
@@ -2294,7 +2303,7 @@ def log_sigmoid(x, name=None):
return gen_math_ops.neg(gen_nn_ops.softplus(-x), name=name)
-@tf_export("nn.tanh", "tanh")
+@tf_export("math.tanh", "nn.tanh", "tanh")
def tanh(x, name=None):
"""Computes hyperbolic tangent of `x` element-wise.
@@ -2315,7 +2324,8 @@ def tanh(x, name=None):
return gen_math_ops.tanh(x, name=name)
-@tf_export("bincount")
+@tf_export("math.bincount", "bincount")
+@deprecation.deprecated_endpoints("bincount")
def bincount(arr,
weights=None,
minlength=None,
@@ -2362,7 +2372,7 @@ def bincount(arr,
return gen_math_ops.bincount(arr, output_size, weights)
-@tf_export("cumsum")
+@tf_export("math.cumsum", "cumsum")
def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative sum of the tensor `x` along `axis`.
@@ -2414,7 +2424,8 @@ def cumsum(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
-@tf_export("cumprod")
+@tf_export("math.cumprod", "cumprod")
+@deprecation.deprecated_endpoints("cumprod")
def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
"""Compute the cumulative product of the tensor `x` along `axis`.
@@ -2422,7 +2433,7 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
first element of the input is identical to the first element of the output:
```python
- tf.cumprod([a, b, c]) # [a, a * b, a * b * c]
+ tf.math.cumprod([a, b, c]) # [a, a * b, a * b * c]
```
By setting the `exclusive` kwarg to `True`, an exclusive cumprod is
@@ -2430,21 +2441,21 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
instead:
```python
- tf.cumprod([a, b, c], exclusive=True) # [1, a, a * b]
+ tf.math.cumprod([a, b, c], exclusive=True) # [1, a, a * b]
```
By setting the `reverse` kwarg to `True`, the cumprod is performed in the
opposite direction:
```python
- tf.cumprod([a, b, c], reverse=True) # [a * b * c, b * c, c]
+ tf.math.cumprod([a, b, c], reverse=True) # [a * b * c, b * c, c]
```
This is more efficient than using separate `tf.reverse` ops.
The `reverse` and `exclusive` kwargs can also be combined:
```python
- tf.cumprod([a, b, c], exclusive=True, reverse=True) # [b * c, c, 1]
+ tf.math.cumprod([a, b, c], exclusive=True, reverse=True) # [b * c, c, 1]
```
Args:
@@ -2466,7 +2477,8 @@ def cumprod(x, axis=0, exclusive=False, reverse=False, name=None):
x, axis, exclusive=exclusive, reverse=reverse, name=name)
-@tf_export("conj")
+@tf_export("math.conj", "conj")
+@deprecation.deprecated_endpoints("conj")
def conj(x, name=None):
r"""Returns the complex conjugate of a complex number.
@@ -2480,7 +2492,7 @@ def conj(x, name=None):
For example:
# tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
- tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+ tf.math.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
If `x` is real, it is returned unchanged.
@@ -2566,7 +2578,8 @@ def _unsorted_segment_N(data, segment_ids, num_segments):
return gen_math_ops.maximum(N, 1)
-@tf_export("unsorted_segment_mean")
+@tf_export("math.unsorted_segment_mean", "unsorted_segment_mean")
+@deprecation.deprecated_endpoints("unsorted_segment_mean")
def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
r"""Computes the mean along segments of a tensor.
@@ -2608,7 +2621,8 @@ def unsorted_segment_mean(data, segment_ids, num_segments, name=None):
return summed / N
-@tf_export("unsorted_segment_sqrt_n")
+@tf_export("math.unsorted_segment_sqrt_n", "unsorted_segment_sqrt_n")
+@deprecation.deprecated_endpoints("unsorted_segment_sqrt_n")
def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
r"""Computes the sum along segments of a tensor divided by the sqrt(N).
@@ -2653,7 +2667,8 @@ def unsorted_segment_sqrt_n(data, segment_ids, num_segments, name=None):
return summed / gen_math_ops.sqrt(N)
-@tf_export("sparse_segment_sum")
+@tf_export("sparse.segment_sum", "sparse_segment_sum")
+@deprecation.deprecated_endpoints("sparse_segment_sum")
def sparse_segment_sum(data, indices, segment_ids, name=None,
num_segments=None):
r"""Computes the sum along sparse segments of a tensor.
@@ -2674,16 +2689,16 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
c = tf.constant([[1,2,3,4], [-1,-2,-3,-4], [5,6,7,8]])
# Select two rows, one segment.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 0]))
# => [[0 0 0 0]]
# Select two rows, two segment.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 1]))
# => [[ 1 2 3 4]
# [-1 -2 -3 -4]]
# With missing segment ids.
- tf.sparse_segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
+ tf.sparse.segment_sum(c, tf.constant([0, 1]), tf.constant([0, 2]),
num_segments=4)
# => [[ 1 2 3 4]
# [ 0 0 0 0]
@@ -2691,7 +2706,7 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
# [ 0 0 0 0]]
# Select all rows, two segments.
- tf.sparse_segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
+ tf.sparse.segment_sum(c, tf.constant([0, 1, 2]), tf.constant([0, 0, 1]))
# => [[0 0 0 0]
# [5 6 7 8]]
@@ -2726,7 +2741,8 @@ def sparse_segment_sum(data, indices, segment_ids, name=None,
data=data, indices=indices, segment_ids=segment_ids, name=name)
-@tf_export("sparse_segment_mean")
+@tf_export("sparse.segment_mean", "sparse_segment_mean")
+@deprecation.deprecated_endpoints("sparse_segment_mean")
def sparse_segment_mean(data,
indices,
segment_ids,
@@ -2771,7 +2787,8 @@ def sparse_segment_mean(data,
data=data, indices=indices, segment_ids=segment_ids, name=name)
-@tf_export("sparse_segment_sqrt_n")
+@tf_export("sparse.segment_sqrt_n", "sparse_segment_sqrt_n")
+@deprecation.deprecated_endpoints("sparse_segment_sqrt_n")
def sparse_segment_sqrt_n(data,
indices,
segment_ids,
diff --git a/tensorflow/python/ops/nn_impl.py b/tensorflow/python/ops/nn_impl.py
index 2a1919e66f..453848fc00 100644
--- a/tensorflow/python/ops/nn_impl.py
+++ b/tensorflow/python/ops/nn_impl.py
@@ -328,7 +328,7 @@ def swish(features):
return features * math_ops.sigmoid(features)
-@tf_export("nn.l2_normalize")
+@tf_export("math.l2_normalize", "linalg.l2_normalize", "nn.l2_normalize")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
"""Normalizes along dimension `axis` using an L2 norm.
@@ -360,7 +360,7 @@ def l2_normalize(x, axis=None, epsilon=1e-12, name=None, dim=None):
return math_ops.multiply(x, x_inv_norm, name=name)
-@tf_export("nn.zero_fraction")
+@tf_export("math.zero_fraction", "nn.zero_fraction")
def zero_fraction(value, name=None):
"""Returns the fraction of zeros in `value`.
@@ -689,7 +689,7 @@ def moments(
# Compute true mean while keeping the dims for proper broadcasting.
mean = math_ops.reduce_mean(y, axes, keepdims=True, name="mean")
# sample variance, not unbiased variance
- # Note: stop_gradient does not change the gradient that gets
+ # Note: stop_gradient does not change the gradient that gets
# backpropagated to the mean from the variance calculation,
# because that gradient is zero
variance = math_ops.reduce_mean(
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index 9ef177e97b..1fbe31a098 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -426,8 +426,8 @@ class _WithSpaceToBatch(object):
try:
input_shape.with_rank_at_least(expected_input_rank)
except ValueError:
- ValueError("input tensor must have rank %d at least" %
- (expected_input_rank))
+ raise ValueError(
+ "input tensor must have rank %d at least" % (expected_input_rank))
const_rate = tensor_util.constant_value(dilation_rate)
rate_or_const_rate = dilation_rate
@@ -817,12 +817,14 @@ class Convolution(object):
try:
input_shape.with_rank(num_spatial_dims + 2)
except ValueError:
- ValueError("input tensor must have rank %d" % (num_spatial_dims + 2))
+ raise ValueError(
+ "input tensor must have rank %d" % (num_spatial_dims + 2))
try:
filter_shape.with_rank(num_spatial_dims + 2)
except ValueError:
- ValueError("filter tensor must have rank %d" % (num_spatial_dims + 2))
+ raise ValueError(
+ "filter tensor must have rank %d" % (num_spatial_dims + 2))
if data_format is None or not data_format.startswith("NC"):
input_channels_dim = input_shape[num_spatial_dims + 1]
@@ -1692,7 +1694,7 @@ def _softmax(logits, compute_op, dim=-1, name=None):
return output
-@tf_export("nn.softmax")
+@tf_export("nn.softmax", "math.softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
@@ -1722,7 +1724,7 @@ def softmax(logits, axis=None, name=None, dim=None):
return _softmax(logits, gen_nn_ops.softmax, axis, name)
-@tf_export("nn.log_softmax")
+@tf_export("nn.log_softmax", "math.log_softmax")
@deprecation.deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def log_softmax(logits, axis=None, name=None, dim=None):
"""Computes log softmax activations.
@@ -2329,7 +2331,7 @@ def dropout(x, keep_prob, noise_shape=None, seed=None, name=None): # pylint: di
return ret
-@tf_export("nn.top_k")
+@tf_export("math.top_k", "nn.top_k")
def top_k(input, k=1, sorted=True, name=None): # pylint: disable=redefined-builtin
"""Finds values and indices of the `k` largest entries for the last dimension.
@@ -2644,7 +2646,7 @@ def erosion2d(value, kernel, strides, rates, padding, name=None):
name=name))
-@tf_export("nn.in_top_k")
+@tf_export("math.in_top_k", "nn.in_top_k")
def in_top_k(predictions, targets, k, name=None):
r"""Says whether the targets are in the top `K` predictions.
diff --git a/tensorflow/python/ops/numerics.py b/tensorflow/python/ops/numerics.py
index 8fcbd7d834..002e87b411 100644
--- a/tensorflow/python/ops/numerics.py
+++ b/tensorflow/python/ops/numerics.py
@@ -24,10 +24,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export("verify_tensor_all_finite")
+@tf_export("debugging.assert_all_finite", "verify_tensor_all_finite")
+@deprecation.deprecated_endpoints("verify_tensor_all_finite")
def verify_tensor_all_finite(t, msg, name=None):
"""Assert that the tensor does not contain any NaN's or Inf's.
diff --git a/tensorflow/python/ops/parallel_for/pfor.py b/tensorflow/python/ops/parallel_for/pfor.py
index e0f6d51881..83cbe64ff2 100644
--- a/tensorflow/python/ops/parallel_for/pfor.py
+++ b/tensorflow/python/ops/parallel_for/pfor.py
@@ -1987,14 +1987,12 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Pow", math_ops.pow)
@RegisterPForWithArgs("RealDiv", math_ops.divide)
@RegisterPForWithArgs("Real", math_ops.real)
-@RegisterPForWithArgs("ReciprocalGrad", math_ops.reciprocal_grad)
@RegisterPForWithArgs("Reciprocal", math_ops.reciprocal)
@RegisterPForWithArgs("Relu6", nn_ops.relu6)
@RegisterPForWithArgs("Relu", nn_ops.relu)
@RegisterPForWithArgs("RightShift", bitwise_ops.right_shift)
@RegisterPForWithArgs("Rint", math_ops.rint)
@RegisterPForWithArgs("Round", math_ops.round)
-@RegisterPForWithArgs("RsqrtGrad", math_ops.rsqrt_grad)
@RegisterPForWithArgs("Rsqrt", math_ops.rsqrt)
@RegisterPForWithArgs("Selu", nn_ops.selu)
@RegisterPForWithArgs("Sigmoid", math_ops.sigmoid)
@@ -2003,7 +2001,6 @@ def _convert_cast(pfor_input):
@RegisterPForWithArgs("Sin", math_ops.sin)
@RegisterPForWithArgs("Softplus", nn_ops.softplus)
@RegisterPForWithArgs("Softsign", nn_ops.softsign)
-@RegisterPForWithArgs("SqrtGrad", math_ops.sqrt_grad)
@RegisterPForWithArgs("Sqrt", math_ops.sqrt)
@RegisterPForWithArgs("SquaredDifference", math_ops.squared_difference)
@RegisterPForWithArgs("Square", math_ops.square)
@@ -2095,6 +2092,9 @@ def _convert_biasaddgrad(pfor_input):
@RegisterPForWithArgs("SoftplusGrad")
@RegisterPForWithArgs("SoftsignGrad")
@RegisterPForWithArgs("TanhGrad")
+@RegisterPForWithArgs("SqrtGrad")
+@RegisterPForWithArgs("RsqrtGrad")
+@RegisterPForWithArgs("ReciprocalGrad")
def _convert_grads(pfor_input, op_type, *args, **kw_args):
del args
del kw_args
diff --git a/tensorflow/python/ops/parsing_ops.py b/tensorflow/python/ops/parsing_ops.py
index b3e03a0135..ff50fe0d09 100644
--- a/tensorflow/python/ops/parsing_ops.py
+++ b/tensorflow/python/ops/parsing_ops.py
@@ -36,6 +36,7 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.gen_parsing_ops import *
# pylint: enable=wildcard-import,undefined-variable
from tensorflow.python.platform import tf_logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -45,7 +46,7 @@ ops.NotDifferentiable("SerializeTensor")
ops.NotDifferentiable("StringToNumber")
-@tf_export("VarLenFeature")
+@tf_export("io.VarLenFeature", "VarLenFeature")
class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
"""Configuration for parsing a variable-length input feature.
@@ -55,7 +56,7 @@ class VarLenFeature(collections.namedtuple("VarLenFeature", ["dtype"])):
pass
-@tf_export("SparseFeature")
+@tf_export("io.SparseFeature", "SparseFeature")
class SparseFeature(
collections.namedtuple(
"SparseFeature",
@@ -130,7 +131,7 @@ class SparseFeature(
cls, index_key, value_key, dtype, size, already_sorted)
-@tf_export("FixedLenFeature")
+@tf_export("io.FixedLenFeature", "FixedLenFeature")
class FixedLenFeature(collections.namedtuple(
"FixedLenFeature", ["shape", "dtype", "default_value"])):
"""Configuration for parsing a fixed-length input feature.
@@ -150,7 +151,7 @@ class FixedLenFeature(collections.namedtuple(
cls, shape, dtype, default_value)
-@tf_export("FixedLenSequenceFeature")
+@tf_export("io.FixedLenSequenceFeature", "FixedLenSequenceFeature")
class FixedLenSequenceFeature(collections.namedtuple(
"FixedLenSequenceFeature",
["shape", "dtype", "allow_missing", "default_value"])):
@@ -360,7 +361,7 @@ def _prepend_none_dimension(features):
return features
-@tf_export("parse_example")
+@tf_export("io.parse_example", "parse_example")
def parse_example(serialized, features, name=None, example_names=None):
# pylint: disable=line-too-long
"""Parses `Example` protos into a `dict` of tensors.
@@ -761,7 +762,7 @@ def _process_raw_parameters(names, dense_defaults, sparse_keys, sparse_types,
dense_shapes_as_proto, dense_shapes)
-@tf_export("parse_single_example")
+@tf_export("io.parse_single_example", "parse_single_example")
def parse_single_example(serialized, features, name=None, example_names=None):
"""Parses a single `Example` proto.
@@ -1244,7 +1245,7 @@ def _parse_sequence_example_raw(serialized,
# TODO(sundberg): rewrite this method to call the batch version, which is more
# efficient especially for large inputs.
-@tf_export("parse_single_sequence_example")
+@tf_export("io.parse_single_sequence_example", "parse_single_sequence_example")
def parse_single_sequence_example(
serialized, context_features=None, sequence_features=None,
example_name=None, name=None):
@@ -1564,7 +1565,8 @@ def _parse_single_sequence_example_raw(serialized,
# Swap `name` and `na_value` for backward compatibility.
-@tf_export("decode_csv")
+@tf_export("io.decode_csv", "decode_csv")
+@deprecation.deprecated_endpoints("decode_csv")
def decode_csv(records,
record_defaults,
field_delim=",",
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 4baf506385..c2eb9dfc5d 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -29,6 +29,7 @@ from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_random_ops import *
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# pylint: enable=wildcard-import
@@ -43,7 +44,7 @@ def _ShapeTensor(shape):
return ops.convert_to_tensor(shape, dtype=dtype, name="shape")
-@tf_export("random_normal")
+@tf_export("random.normal", "random_normal")
def random_normal(shape,
mean=0.0,
stddev=1.0,
@@ -136,7 +137,7 @@ def parameterized_truncated_normal(shape,
return rnd
-@tf_export("truncated_normal")
+@tf_export("random.truncated_normal", "truncated_normal")
def truncated_normal(shape,
mean=0.0,
stddev=1.0,
@@ -181,7 +182,7 @@ ops.NotDifferentiable("ParameterizedTruncatedNormal")
ops.NotDifferentiable("TruncatedNormal")
-@tf_export("random_uniform")
+@tf_export("random.uniform", "random_uniform")
def random_uniform(shape,
minval=0,
maxval=None,
@@ -246,7 +247,7 @@ def random_uniform(shape,
ops.NotDifferentiable("RandomUniform")
-@tf_export("random_shuffle")
+@tf_export("random.shuffle", "random_shuffle")
def random_shuffle(value, seed=None, name=None):
"""Randomly shuffles a tensor along its first dimension.
@@ -277,7 +278,7 @@ def random_shuffle(value, seed=None, name=None):
value, seed=seed1, seed2=seed2, name=name)
-@tf_export("random_crop")
+@tf_export("image.random_crop", "random_crop")
def random_crop(value, size, seed=None, name=None):
"""Randomly crops a tensor to a given size.
@@ -320,7 +321,7 @@ def random_crop(value, size, seed=None, name=None):
return array_ops.slice(value, offset, size, name=name)
-@tf_export("multinomial")
+@tf_export("random.multinomial", "multinomial")
def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
"""Draws samples from a multinomial distribution.
@@ -356,7 +357,8 @@ def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
ops.NotDifferentiable("Multinomial")
-@tf_export("random_gamma")
+@tf_export("random.gamma", "random_gamma")
+@deprecation.deprecated_endpoints("random_gamma")
def random_gamma(shape,
alpha,
beta=None,
@@ -439,7 +441,8 @@ def random_gamma(shape,
shape, alpha_broadcast, seed=seed1, seed2=seed2) / beta)
-@tf_export("random_poisson")
+@tf_export("random.poisson", "random_poisson")
+@deprecation.deprecated_endpoints("random_poisson")
def random_poisson(lam, shape, dtype=dtypes.float32, seed=None, name=None):
"""Draws `shape` samples from each of the given Poisson distribution(s).
diff --git a/tensorflow/python/ops/rnn_cell_impl.py b/tensorflow/python/ops/rnn_cell_impl.py
index 43cca1a498..dd4f3d7a99 100644
--- a/tensorflow/python/ops/rnn_cell_impl.py
+++ b/tensorflow/python/ops/rnn_cell_impl.py
@@ -370,7 +370,7 @@ class LayerRNNCell(RNNCell):
*args, **kwargs)
-@tf_export("nn.rnn_cell.BasicRNNCell")
+@tf_export(v1=["nn.rnn_cell.BasicRNNCell"])
class BasicRNNCell(LayerRNNCell):
"""The most basic RNN cell.
@@ -393,6 +393,8 @@ class BasicRNNCell(LayerRNNCell):
`trainable` etc when constructing the cell from configs of get_config().
"""
+ @deprecated(None, "This class is equivalent as tf.keras.layers.SimpleRNNCell,"
+ " and will be replaced by that in Tensorflow 2.0.")
def __init__(self,
num_units,
activation=None,
@@ -611,7 +613,7 @@ class LSTMStateTuple(_LSTMStateTuple):
# TODO(scottzhu): Stop exporting this class in TF 2.0.
@tf_export("nn.rnn_cell.BasicLSTMCell")
class BasicLSTMCell(LayerRNNCell):
- """DEPRECATED: Please use @{tf.nn.rnn_cell.LSTMCell} instead.
+ """DEPRECATED: Please use `tf.nn.rnn_cell.LSTMCell` instead.
Basic LSTM recurrent network cell.
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index 400a42a3c0..7e3dbdbad4 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -185,7 +185,8 @@ def sparse_eye(num_rows,
# pylint: disable=protected-access
-@tf_export("sparse_concat")
+@tf_export("sparse.concat", "sparse_concat")
+@deprecation.deprecated_endpoints("sparse_concat")
@deprecation.deprecated_args(
None, "concat_dim is deprecated, use axis instead", "concat_dim")
def sparse_concat(axis,
@@ -317,7 +318,8 @@ def sparse_concat(axis,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_add")
+@tf_export("sparse.add", "sparse_add")
+@deprecation.deprecated_endpoints("sparse_add")
def sparse_add(a, b, thresh=0):
"""Adds two tensors, at least one of each is a `SparseTensor`.
@@ -557,7 +559,8 @@ def sparse_dense_cwise_add(sp_t, dense_t):
return sparse_tensor.SparseTensor(sp_t.indices, result, sp_t.dense_shape)
-@tf_export("sparse_reorder")
+@tf_export("sparse.reorder", "sparse_reorder")
+@deprecation.deprecated_endpoints("sparse_reorder")
def sparse_reorder(sp_input, name=None):
"""Reorders a `SparseTensor` into the canonical, row-major ordering.
@@ -607,7 +610,8 @@ def sparse_reorder(sp_input, name=None):
return sparse_tensor.SparseTensor(reordered_ind, reordered_val, dense_shape)
-@tf_export("sparse_reshape")
+@tf_export("sparse.reshape", "sparse_reshape")
+@deprecation.deprecated_endpoints("sparse_reshape")
def sparse_reshape(sp_input, shape, name=None):
"""Reshapes a `SparseTensor` to represent values in a new dense shape.
@@ -700,7 +704,8 @@ class KeywordRequired(object):
return "KeywordRequired()"
-@tf_export("sparse_split")
+@tf_export("sparse.split", "sparse_split")
+@deprecation.deprecated_endpoints("sparse_split")
@deprecation.deprecated_args(
None, "split_dim is deprecated, use axis instead", "split_dim")
def sparse_split(keyword_required=KeywordRequired(),
@@ -773,7 +778,8 @@ def sparse_split(keyword_required=KeywordRequired(),
return sparse_tensors
-@tf_export("sparse_slice")
+@tf_export("sparse.slice", "sparse_slice")
+@deprecation.deprecated_endpoints("sparse_slice")
def sparse_slice(sp_input, start, size, name=None):
"""Slice a `SparseTensor` based on the `start` and `size.
@@ -785,11 +791,11 @@ def sparse_slice(sp_input, start, size, name=None):
Graphically the output tensors are:
- sparse_slice([0, 0], [2, 4]) = shape = [2, 4]
+ sparse.slice([0, 0], [2, 4]) = shape = [2, 4]
[ a ]
[b c ]
- sparse_slice([0, 4], [2, 3]) = shape = [2, 3]
+ sparse.slice([0, 4], [2, 3]) = shape = [2, 3]
[ d e ]
[ ]
@@ -823,6 +829,9 @@ def sparse_slice(sp_input, start, size, name=None):
@tf_export("sparse_to_dense")
+@deprecation.deprecated(
+ None,
+ "Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.")
def sparse_to_dense(sparse_indices,
output_shape,
sparse_values,
@@ -878,7 +887,8 @@ def sparse_to_dense(sparse_indices,
name=name)
-@tf_export("sparse_reduce_max")
+@tf_export("sparse.reduce_max", "sparse_reduce_max")
+@deprecation.deprecated_endpoints("sparse_reduce_max")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_max(sp_input, axis=None, keepdims=None,
@@ -912,16 +922,16 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
# 'x' represents [[1, ?, 2]
# [?, 3, ?]]
# where ? is implicitly-zero.
- tf.sparse_reduce_max(x) ==> 3
- tf.sparse_reduce_max(x, 0) ==> [1, 3, 2]
- tf.sparse_reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis.
- tf.sparse_reduce_max(x, 1, keepdims=True) ==> [[2], [3]]
- tf.sparse_reduce_max(x, [0, 1]) ==> 3
+ tf.sparse.reduce_max(x) ==> 3
+ tf.sparse.reduce_max(x, 0) ==> [1, 3, 2]
+ tf.sparse.reduce_max(x, 1) ==> [2, 3] # Can also use -1 as the axis.
+ tf.sparse.reduce_max(x, 1, keepdims=True) ==> [[2], [3]]
+ tf.sparse.reduce_max(x, [0, 1]) ==> 3
# 'y' represents [[-7, ?]
# [ 4, 3]
# [ ?, ?]
- tf.sparse_reduce_max(x, 1) ==> [-7, 4, 0]
+ tf.sparse.reduce_max(x, 1) ==> [-7, 4, 0]
```
Args:
@@ -945,7 +955,8 @@ def sparse_reduce_max(sp_input, axis=None, keepdims=None,
math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)
-@tf_export("sparse_reduce_max_sparse")
+@tf_export("sparse.reduce_max_sparse", "sparse_reduce_max_sparse")
+@deprecation.deprecated_endpoints("sparse_reduce_max_sparse")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_max_sparse(sp_input,
@@ -995,7 +1006,8 @@ def sparse_reduce_max_sparse(sp_input,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_reduce_sum")
+@tf_export("sparse.reduce_sum", "sparse_reduce_sum")
+@deprecation.deprecated_endpoints("sparse_reduce_sum")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
@@ -1021,11 +1033,11 @@ def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
# 'x' represents [[1, ?, 1]
# [?, 1, ?]]
# where ? is implicitly-zero.
- tf.sparse_reduce_sum(x) ==> 3
- tf.sparse_reduce_sum(x, 0) ==> [1, 1, 1]
- tf.sparse_reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis.
- tf.sparse_reduce_sum(x, 1, keepdims=True) ==> [[2], [1]]
- tf.sparse_reduce_sum(x, [0, 1]) ==> 3
+ tf.sparse.reduce_sum(x) ==> 3
+ tf.sparse.reduce_sum(x, 0) ==> [1, 1, 1]
+ tf.sparse.reduce_sum(x, 1) ==> [2, 1] # Can also use -1 as the axis.
+ tf.sparse.reduce_sum(x, 1, keepdims=True) ==> [[2], [1]]
+ tf.sparse.reduce_sum(x, [0, 1]) ==> 3
```
Args:
@@ -1049,7 +1061,8 @@ def sparse_reduce_sum(sp_input, axis=None, keepdims=None,
math_ops._ReductionDims(sp_input, axis, reduction_axes), keepdims)
-@tf_export("sparse_reduce_sum_sparse")
+@tf_export("sparse.reduce_sum_sparse", "sparse_reduce_sum_sparse")
+@deprecation.deprecated_endpoints("sparse_reduce_sum_sparse")
@deprecation.deprecated_args(
None, "keep_dims is deprecated, use keepdims instead", "keep_dims")
def sparse_reduce_sum_sparse(sp_input,
@@ -1099,7 +1112,8 @@ def sparse_reduce_sum_sparse(sp_input,
return sparse_tensor.SparseTensor(output_ind, output_val, output_shape)
-@tf_export("sparse_tensor_to_dense")
+@tf_export("sparse.to_dense", "sparse_tensor_to_dense")
+@deprecation.deprecated_endpoints("sparse_tensor_to_dense")
def sparse_tensor_to_dense(sp_input,
default_value=0,
validate_indices=True,
@@ -1151,7 +1165,8 @@ def sparse_tensor_to_dense(sp_input,
name=name)
-@tf_export("sparse_to_indicator")
+@tf_export("sparse.to_indicator", "sparse_to_indicator")
+@deprecation.deprecated_endpoints("sparse_to_indicator")
def sparse_to_indicator(sp_input, vocab_size, name=None):
"""Converts a `SparseTensor` of ids into a dense bool indicator tensor.
@@ -1214,7 +1229,8 @@ def sparse_to_indicator(sp_input, vocab_size, name=None):
sp_new, default_value=False, validate_indices=False, name=name)
-@tf_export("sparse_merge")
+@tf_export("sparse.merge", "sparse_merge")
+@deprecation.deprecated_endpoints("sparse_merge")
def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
already_sorted=False):
"""Combines a batch of feature ids and values into a single `SparseTensor`.
@@ -1358,7 +1374,8 @@ def sparse_merge(sp_ids, sp_values, vocab_size, name=None,
sorted_result.indices, sorted_result.values, new_shape)
-@tf_export("sparse_retain")
+@tf_export("sparse.retain", "sparse_retain")
+@deprecation.deprecated_endpoints("sparse_retain")
def sparse_retain(sp_input, to_retain):
"""Retains specified non-empty values within a `SparseTensor`.
@@ -1402,7 +1419,8 @@ def sparse_retain(sp_input, to_retain):
array_ops.identity(sp_input.dense_shape))
-@tf_export("sparse_reset_shape")
+@tf_export("sparse.reset_shape", "sparse_reset_shape")
+@deprecation.deprecated_endpoints("sparse_reset_shape")
def sparse_reset_shape(sp_input, new_shape=None):
"""Resets the shape of a `SparseTensor` with indices and values unchanged.
@@ -1503,7 +1521,8 @@ def sparse_reset_shape(sp_input, new_shape=None):
return sparse_tensor.SparseTensor(in_indices, in_values, output_shape_tensor)
-@tf_export("sparse_fill_empty_rows")
+@tf_export("sparse.fill_empty_rows", "sparse_fill_empty_rows")
+@deprecation.deprecated_endpoints("sparse_fill_empty_rows")
def sparse_fill_empty_rows(sp_input, default_value, name=None):
"""Fills empty rows in the input 2-D `SparseTensor` with a default value.
@@ -1567,7 +1586,8 @@ def sparse_fill_empty_rows(sp_input, default_value, name=None):
dense_shape=sp_input.dense_shape), empty_row_indicator)
-@tf_export("serialize_sparse")
+@tf_export("io.serialize_sparse", "serialize_sparse")
+@deprecation.deprecated_endpoints("serialize_sparse")
def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize a `SparseTensor` into a 3-vector (1-D `Tensor`) object.
@@ -1593,7 +1613,8 @@ def serialize_sparse(sp_input, name=None, out_type=dtypes.string):
out_type=out_type)
-@tf_export("serialize_many_sparse")
+@tf_export("io.serialize_many_sparse", "serialize_many_sparse")
+@deprecation.deprecated_endpoints("serialize_many_sparse")
def serialize_many_sparse(sp_input, name=None, out_type=dtypes.string):
"""Serialize `N`-minibatch `SparseTensor` into an `[N, 3]` `Tensor`.
@@ -1694,7 +1715,8 @@ def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
-@tf_export("deserialize_many_sparse")
+@tf_export("io.deserialize_many_sparse", "deserialize_many_sparse")
+@deprecation.deprecated_endpoints("deserialize_many_sparse")
def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize and concatenate `SparseTensors` from a serialized minibatch.
@@ -1712,7 +1734,7 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
- step run `sparse_reorder` to restore index ordering.
+ step run `sparse.reorder` to restore index ordering.
For example, if the serialized input is a `[2, 3]` matrix representing two
original `SparseTensor` objects:
@@ -1764,7 +1786,8 @@ def deserialize_many_sparse(serialized_sparse, dtype, rank=None, name=None):
return sparse_tensor.SparseTensor(output_indices, output_values, output_shape)
-@tf_export("sparse_tensor_dense_matmul")
+@tf_export("sparse.matmul", "sparse_tensor_dense_matmul")
+@deprecation.deprecated_endpoints("sparse_tensor_dense_matmul")
def sparse_tensor_dense_matmul(sp_a,
b,
adjoint_a=False,
@@ -1777,7 +1800,7 @@ def sparse_tensor_dense_matmul(sp_a,
following input format is recommended for optimal behavior:
* If `adjoint_a == false`: `A` should be sorted in lexicographically
- increasing order. Use `sparse_reorder` if you're not sure.
+ increasing order. Use `sparse.reorder` if you're not sure.
* If `adjoint_a == true`: `A` should be sorted in order of increasing
dimension 1 (i.e., "column major" order instead of "row major" order).
@@ -1981,7 +2004,8 @@ def sparse_tensor_dense_matmul(sp_a,
adjoint_b=adjoint_b)
-@tf_export("sparse_softmax")
+@tf_export("sparse.softmax", "sparse_softmax")
+@deprecation.deprecated_endpoints("sparse_softmax")
def sparse_softmax(sp_input, name=None):
"""Applies softmax to a batched N-D `SparseTensor`.
@@ -2036,7 +2060,8 @@ def sparse_softmax(sp_input, name=None):
sp_input.dense_shape)
-@tf_export("sparse_maximum")
+@tf_export("sparse.maximum", "sparse_maximum")
+@deprecation.deprecated_endpoints("sparse_maximum")
def sparse_maximum(sp_a, sp_b, name=None):
"""Returns the element-wise max of two SparseTensors.
@@ -2073,7 +2098,8 @@ def sparse_maximum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
-@tf_export("sparse_minimum")
+@tf_export("sparse.minimum", "sparse_minimum")
+@deprecation.deprecated_endpoints("sparse_minimum")
def sparse_minimum(sp_a, sp_b, name=None):
"""Returns the element-wise min of two SparseTensors.
@@ -2110,7 +2136,8 @@ def sparse_minimum(sp_a, sp_b, name=None):
return sparse_tensor.SparseTensor(out_indices, out_values, sp_a.dense_shape)
-@tf_export("sparse_transpose")
+@tf_export("sparse.transpose", "sparse_transpose")
+@deprecation.deprecated_endpoints("sparse_transpose")
def sparse_transpose(sp_input, perm=None, name=None):
"""Transposes a `SparseTensor`
@@ -2259,7 +2286,7 @@ def _take_many_sparse_from_tensors_map(sparse_map_op,
The input `SparseTensor` objects' indices are assumed ordered in
standard lexicographic order. If this is not the case, after this
- step run `sparse_reorder` to restore index ordering.
+ step run `sparse.reorder` to restore index ordering.
For example, if the serialized input is a `[2, 3]` matrix representing two
original `SparseTensor` objects:
diff --git a/tensorflow/python/ops/special_math_ops.py b/tensorflow/python/ops/special_math_ops.py
index 9a10abfcf7..cfab943896 100644
--- a/tensorflow/python/ops/special_math_ops.py
+++ b/tensorflow/python/ops/special_math_ops.py
@@ -29,11 +29,13 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
-@tf_export('lbeta')
+@tf_export('math.lbeta', 'lbeta')
+@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.
diff --git a/tensorflow/python/ops/string_ops.py b/tensorflow/python/ops/string_ops.py
index 046a48d192..0812f901a2 100644
--- a/tensorflow/python/ops/string_ops.py
+++ b/tensorflow/python/ops/string_ops.py
@@ -46,6 +46,7 @@ from tensorflow.python.util.tf_export import tf_export
# pylint: disable=redefined-builtin
+@tf_export("strings.regex_full_match")
def regex_full_match(input, pattern, name=None):
r"""Match elements of `input` with regex `pattern`.
@@ -73,15 +74,14 @@ def regex_full_match(input, pattern, name=None):
regex_full_match.__doc__ = gen_string_ops.regex_full_match.__doc__
-# Expose regex_full_match in strings namespace
-tf_export("strings.regex_full_match")(regex_full_match)
-
-def regex_replace(source, pattern, rewrite, replace_global=True):
- r"""Replace elements of `source` matching regex `pattern` with `rewrite`.
+@tf_export("strings.regex_replace", "regex_replace")
+@deprecation.deprecated_endpoints("regex_replace")
+def regex_replace(input, pattern, rewrite, replace_global=True, name=None):
+ r"""Replace elements of `input` matching regex `pattern` with `rewrite`.
Args:
- source: string `Tensor`, the source strings to process.
+ input: string `Tensor`, the source strings to process.
pattern: string or scalar string `Tensor`, regular expression to use,
see more details at https://github.com/google/re2/wiki/Syntax
rewrite: string or scalar string `Tensor`, value to use in match
@@ -89,9 +89,10 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
text matching corresponding parenthesized group.
replace_global: `bool`, if `True` replace all non-overlapping matches,
else replace only the first match.
+ name: A name for the operation (optional).
Returns:
- string `Tensor` of the same shape as `source` with specified replacements.
+ string `Tensor` of the same shape as `input` with specified replacements.
"""
if (isinstance(pattern, util_compat.bytes_or_text_types) and
isinstance(rewrite, util_compat.bytes_or_text_types)):
@@ -99,11 +100,13 @@ def regex_replace(source, pattern, rewrite, replace_global=True):
# use a version which performs the expensive regex compilation once at
# creation time.
return gen_string_ops.static_regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
+ input=input, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global,
+ name=name)
return gen_string_ops.regex_replace(
- input=source, pattern=pattern,
- rewrite=rewrite, replace_global=replace_global)
+ input=input, pattern=pattern,
+ rewrite=rewrite, replace_global=replace_global,
+ name=name)
@tf_export("strings.format")
@@ -310,8 +313,9 @@ def _reduce_join_reduction_dims(x, axis, reduction_indices):
return math_ops.range(array_ops.rank(x) - 1, -1, -1)
-@tf_export("reduce_join")
-def reduce_join(inputs, axis=None,
+@tf_export("strings.reduce_join", "reduce_join")
+@deprecation.deprecated_endpoints("reduce_join")
+def reduce_join(inputs, axis=None, # pylint: disable=missing-docstring
keep_dims=False,
separator="",
name=None,
@@ -329,6 +333,8 @@ def reduce_join(inputs, axis=None,
reduce_join.__doc__ = deprecation.rewrite_argument_docstring(
gen_string_ops.reduce_join.__doc__, "reduction_indices", "axis")
+reduce_join.__doc__ = reduce_join.__doc__.replace("tf.reduce_join(",
+ "tf.strings.reduce_join(")
# This wrapper provides backwards compatibility for code that predates the
diff --git a/tensorflow/python/ops/variable_scope.py b/tensorflow/python/ops/variable_scope.py
index af5c7d4050..5032ca79f9 100644
--- a/tensorflow/python/ops/variable_scope.py
+++ b/tensorflow/python/ops/variable_scope.py
@@ -939,7 +939,8 @@ class _VariableStore(object):
if regularizer:
with ops.colocate_with(v):
with ops.name_scope(name + "/Regularizer/"):
- loss = regularizer(v)
+ with ops.init_scope():
+ loss = regularizer(v)
if loss is not None:
if context.executing_eagerly():
v_name = "v_%s" % type(v)
diff --git a/tensorflow/python/ops/variables.py b/tensorflow/python/ops/variables.py
index 8da1e9fe56..45c8618610 100644
--- a/tensorflow/python/ops/variables.py
+++ b/tensorflow/python/ops/variables.py
@@ -2620,10 +2620,50 @@ class PartitionedVariable(object):
def _get_partitions(self):
return self._partitions
- def assign(self, value, use_locking=False):
- _ = value, use_locking
- raise NotImplementedError(
- "assign() has not been implemented for PartitionedVariable.")
+ def _apply_assign_fn(self, assign_fn, value):
+ partition_axes = self._partition_axes()
+ if len(partition_axes) > 1:
+ raise NotImplementedError(
+ "Cannot do assign action along more than one dimension: %s. "
+ "Multi-axis partition assign action is not supported " %
+ str(partition_axes))
+ partition_ix = partition_axes[0]
+ size_splits_list = [
+ var.shape[partition_ix].value for var in self._variable_list
+ ]
+ value_list = array_ops.split(value, size_splits_list, axis=partition_ix)
+ op_list = [
+ assign_fn(var, value_list[idx], idx)
+ for idx, var in enumerate(self._variable_list)
+ ]
+ return op_list
+
+ def assign(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_add(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_add(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
+
+ def assign_sub(self, value, use_locking=False, name=None, read_value=True):
+ assign_fn = lambda var, r_value, idx: var.assign_sub(
+ r_value, use_locking=use_locking,
+ name="%s_%d" % (name, idx), read_value=read_value)
+ assign_list = self._apply_assign_fn(assign_fn, value)
+ if read_value:
+ return assign_list
+ return [assign.op for assign in assign_list]
@tf_export(v1=["global_variables"])
diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py
index 875be31602..8e88a84d60 100644
--- a/tensorflow/python/ops/while_v2.py
+++ b/tensorflow/python/ops/while_v2.py
@@ -24,6 +24,7 @@ from __future__ import division
from __future__ import print_function
import collections
+import sys
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.eager import function
@@ -31,8 +32,10 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import function_def_to_graph
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_spec
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import cond_v2_impl as cond_v2
+from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import control_flow_util
from tensorflow.python.ops import gen_functional_ops
from tensorflow.python.ops import gradients_impl
@@ -41,6 +44,8 @@ from tensorflow.python.util import nest
# pylint: disable=protected-access
+control_flow_ops._while_v2 = sys.modules[__name__]
+
# TODO(b/79881896): Handle external control dependencies. tf.while_loop allows
# control dependencies on external nodes with at least 1 output.
# Another idea is to create const nodes outside the loop and add control edges
@@ -48,8 +53,17 @@ from tensorflow.python.util import nest
# handled in the CapturingGraph itself.
-def while_loop(cond, body, loop_vars, name=None):
+def while_loop(cond, body, loop_vars, shape_invariants=None, name=None):
"""Like tf.while_loop, except emits a single While op."""
+ flattened_loop_vars = nest.flatten(loop_vars)
+ if shape_invariants is not None:
+ nest.assert_same_structure(loop_vars, shape_invariants)
+ flattened_shapes = nest.flatten(shape_invariants)
+ else:
+ flattened_shapes = [t.shape for t in flattened_loop_vars]
+
+ del shape_invariants
+
if not name:
name = "while"
@@ -58,25 +72,33 @@ def while_loop(cond, body, loop_vars, name=None):
cond_name = _get_unique_name(("%scond" % scope).replace("/", "_"))
body_name = _get_unique_name(("%sbody" % scope).replace("/", "_"))
- flattened_loop_vars = nest.flatten(loop_vars)
num_outputs = len(flattened_loop_vars)
# Add loop counter needed for computing gradients.
flattened_loop_vars = [constant_op.constant(0., name="loop_counter")
] + flattened_loop_vars
+ flattened_shapes = [tensor_shape.scalar()] + flattened_shapes
+
# Build a `cond` wrapper that can handle the extra counter loop_var.
def wrapped_cond(unused_loop_counter, *loop_vars):
return cond(*loop_vars)
- cond_graph = function.func_graph_from_py_func(cond_name, wrapped_cond,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ cond_graph = function.func_graph_from_py_func(
+ cond_name, wrapped_cond, flattened_loop_vars, {}, signature=signature)
# Add external_captures of cond to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
# beginning of the loop execution.
flattened_loop_vars = flattened_loop_vars + cond_graph.external_captures
+ flattened_shapes = flattened_shapes + [
+ t.shape for t in cond_graph.external_captures
+ ]
def wrapped_body(loop_counter, *args):
"""Loop body augmented with counter update.
@@ -101,8 +123,12 @@ def while_loop(cond, body, loop_vars, name=None):
# is_constant=True for inputs that are directly passed to outputs.
return [loop_counter + 1] + list(outputs) + list(args[num_outputs:])
- body_graph = function.func_graph_from_py_func(body_name, wrapped_body,
- flattened_loop_vars, {})
+ signature = [
+ tensor_spec.TensorSpec(shape, t.dtype)
+ for shape, t in zip(flattened_shapes, flattened_loop_vars)
+ ]
+ body_graph = function.func_graph_from_py_func(
+ body_name, wrapped_body, flattened_loop_vars, {}, signature=signature)
# Add external captures of body to the list of loop vars.
# Note that external tensors will be treated as loop invariants, i.e.,
# the value of that tensor in each iteration is the same as it was at the
@@ -145,10 +171,17 @@ def while_loop(cond, body, loop_vars, name=None):
# Add this modified tensor list to the list of outputs.
body_graph.outputs.append(appended_tensor_list)
+ # Make sure that the shapes of the loop outputs are compatible with the
+ # shape invariants, or the shapes of the loop vars if the invariants are not
+ # specified.
+ _check_shapes_compat(body_graph.outputs[1:1 + num_outputs],
+ flattened_shapes[1:1 + num_outputs],
+ flattened_loop_vars[1:1 + num_outputs])
outputs = gen_functional_ops._while(
flattened_loop_vars,
cond_v2._create_new_tf_function(cond_graph),
cond_v2._create_new_tf_function(body_graph),
+ output_shapes=[t.shape for t in body_graph.outputs],
name=scope)
_copy_handle_data(body_graph.outputs, outputs)
@@ -212,6 +245,7 @@ def _WhileGrad(op, *grads): # pylint: disable=invalid-name
loop_vars,
cond_v2._create_new_tf_function(cond_grad_graph),
cond_v2._create_new_tf_function(body_grad_graph),
+ output_shapes=[t.shape for t in body_grad_graph.outputs],
name=_get_unique_name("%s_grad" % op.name))
_copy_handle_data(body_grad_graph.outputs, outputs)
@@ -232,8 +266,10 @@ def _get_body_graph(while_op):
Returns:
`FuncGraph` for the while body.
"""
- extra_inputs = list(while_op.inputs)
- input_shapes = [t.shape for t in extra_inputs]
+ # TODO(srbs): Handle TensorShapeProto in function_def_to_graph.input_shapes.
+ input_shapes = [
+ tensor_shape.TensorShape(s) for s in while_op.get_attr("output_shapes")
+ ]
func_name = while_op.get_attr("body").name
fdef = while_op.graph._get_function(func_name).definition
func_graph = function_def_to_graph.function_def_to_graph(fdef, input_shapes)
@@ -531,6 +567,17 @@ class _WhileBodyGradFuncGraph(function.FuncGraph):
return captured_tensor
+def _check_shapes_compat(output_tensors, shape_invariants, input_tensors):
+ for (t, shape, input_t) in zip(output_tensors, shape_invariants,
+ input_tensors):
+ if not control_flow_ops._ShapeLessThanOrEqual(t.shape, shape):
+ raise ValueError(
+ "Input tensor '%s' enters the loop with shape %s, but has "
+ "shape %s after one iteration. To allow the shape to vary across "
+ "iterations, use the `shape_invariants` argument of tf.while_loop to "
+ "specify a less-specific shape." % (input_t.name, shape, t.shape))
+
+
def _copy_handle_data(src_tensors, tgt_tensors):
for src_t, tgt_t in zip(src_tensors, tgt_tensors):
function._copy_handle_data(src_t, tgt_t)
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index c411a58b70..61e0abbfcb 100755
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -67,6 +67,7 @@ limitations under the License.
%rename("%s") TFE_ContextStartStep;
%rename("%s") TFE_ContextEndStep;
%rename("%s") TFE_Py_RegisterVSpace;
+%rename("%s") TFE_Py_EncodeArg;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/saved_model/builder_impl.py b/tensorflow/python/saved_model/builder_impl.py
index 8e7f123a85..8bf057f69d 100644
--- a/tensorflow/python/saved_model/builder_impl.py
+++ b/tensorflow/python/saved_model/builder_impl.py
@@ -36,10 +36,13 @@ from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
from tensorflow.python.util.deprecation import deprecated_args
+from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util.tf_export import tf_export
-@tf_export("saved_model.builder.SavedModelBuilder")
+@tf_export("saved_model.Builder",
+ "saved_model.builder.SavedModelBuilder")
+@deprecated_endpoints("saved_model.builder.SavedModelBuilder")
class SavedModelBuilder(object):
"""Builds the `SavedModel` protocol buffer and saves variables and assets.
@@ -61,7 +64,7 @@ class SavedModelBuilder(object):
Typical usage for the `SavedModelBuilder`:
```python
...
- builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
+ builder = tf.saved_model.Builder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
diff --git a/tensorflow/python/saved_model/loader_impl.py b/tensorflow/python/saved_model/loader_impl.py
index e8536108e8..895644a030 100644
--- a/tensorflow/python/saved_model/loader_impl.py
+++ b/tensorflow/python/saved_model/loader_impl.py
@@ -34,6 +34,7 @@ from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import utils_impl as saved_model_utils
from tensorflow.python.training import saver as tf_saver
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -144,7 +145,10 @@ def _get_main_op_tensor(
return main_op_tensor
-@tf_export("saved_model.loader.maybe_saved_model_directory")
+@tf_export("saved_model.maybe_saved_model_directory",
+ "saved_model.loader.maybe_saved_model_directory")
+@deprecation.deprecated_endpoints(
+ "saved_model.loader.maybe_saved_model_directory")
def maybe_saved_model_directory(export_dir):
"""Checks whether the provided export directory could contain a SavedModel.
@@ -165,7 +169,7 @@ def maybe_saved_model_directory(export_dir):
return file_io.file_exists(txt_path) or file_io.file_exists(pb_path)
-@tf_export("saved_model.loader.load")
+@tf_export("saved_model.load", "saved_model.loader.load")
def load(sess, tags, export_dir, import_scope=None, **saver_kwargs):
"""Loads the model from a SavedModel as specified by tags.
diff --git a/tensorflow/python/saved_model/main_op_impl.py b/tensorflow/python/saved_model/main_op_impl.py
index 631ee63729..ad4511b28e 100644
--- a/tensorflow/python/saved_model/main_op_impl.py
+++ b/tensorflow/python/saved_model/main_op_impl.py
@@ -22,6 +22,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import variables
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@@ -42,7 +43,9 @@ def main_op():
# TODO(sukritiramesh): Integrate with Saver for complete restore functionality.
-@tf_export('saved_model.main_op.main_op_with_restore')
+@tf_export('saved_model.main_op_with_restore',
+ 'saved_model.main_op.main_op_with_restore')
+@deprecation.deprecated_endpoints('saved_model.main_op.main_op_with_restore')
def main_op_with_restore(restore_op_name):
"""Returns a main op to init variables, tables and restore the graph.
diff --git a/tensorflow/python/saved_model/signature_def_utils_impl.py b/tensorflow/python/saved_model/signature_def_utils_impl.py
index 37f927f381..a1034416e9 100644
--- a/tensorflow/python/saved_model/signature_def_utils_impl.py
+++ b/tensorflow/python/saved_model/signature_def_utils_impl.py
@@ -24,10 +24,14 @@ from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import ops
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import utils
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
-@tf_export('saved_model.signature_def_utils.build_signature_def')
+@tf_export('saved_model.build_signature_def',
+ 'saved_model.signature_def_utils.build_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.build_signature_def')
def build_signature_def(inputs=None, outputs=None, method_name=None):
"""Utility function to build a SignatureDef protocol buffer.
@@ -53,7 +57,10 @@ def build_signature_def(inputs=None, outputs=None, method_name=None):
return signature_def
-@tf_export('saved_model.signature_def_utils.regression_signature_def')
+@tf_export('saved_model.regression_signature_def',
+ 'saved_model.signature_def_utils.regression_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.regression_signature_def')
def regression_signature_def(examples, predictions):
"""Creates regression signature from given examples and predictions.
@@ -95,7 +102,10 @@ def regression_signature_def(examples, predictions):
return signature_def
-@tf_export('saved_model.signature_def_utils.classification_signature_def')
+@tf_export('saved_model.classification_signature_def',
+ 'saved_model.signature_def_utils.classification_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.classification_signature_def')
def classification_signature_def(examples, classes, scores):
"""Creates classification signature from given examples and predictions.
@@ -148,7 +158,10 @@ def classification_signature_def(examples, classes, scores):
return signature_def
-@tf_export('saved_model.signature_def_utils.predict_signature_def')
+@tf_export('saved_model.predict_signature_def',
+ 'saved_model.signature_def_utils.predict_signature_def')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.predict_signature_def')
def predict_signature_def(inputs, outputs):
"""Creates prediction signature from given inputs and outputs.
@@ -239,7 +252,10 @@ def _supervised_signature_def(
return signature_def
-@tf_export('saved_model.signature_def_utils.is_valid_signature')
+@tf_export('saved_model.is_valid_signature',
+ 'saved_model.signature_def_utils.is_valid_signature')
+@deprecation.deprecated_endpoints(
+ 'saved_model.signature_def_utils.is_valid_signature')
def is_valid_signature(signature_def):
"""Determine whether a SignatureDef can be served by TensorFlow Serving."""
if signature_def is None:
@@ -313,4 +329,3 @@ def _is_valid_classification_signature(signature_def):
return False
return True
-
diff --git a/tensorflow/python/saved_model/utils_impl.py b/tensorflow/python/saved_model/utils_impl.py
index 06d09325c8..0bba7b6fac 100644
--- a/tensorflow/python/saved_model/utils_impl.py
+++ b/tensorflow/python/saved_model/utils_impl.py
@@ -27,13 +27,16 @@ from tensorflow.python.framework import sparse_tensor
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import constants
from tensorflow.python.util import compat
+from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
# TensorInfo helpers.
-@tf_export("saved_model.utils.build_tensor_info")
+@tf_export("saved_model.build_tensor_info",
+ "saved_model.utils.build_tensor_info")
+@deprecation.deprecated_endpoints("saved_model.utils.build_tensor_info")
def build_tensor_info(tensor):
"""Utility function to build TensorInfo proto.
@@ -57,7 +60,10 @@ def build_tensor_info(tensor):
return tensor_info
-@tf_export("saved_model.utils.get_tensor_from_tensor_info")
+@tf_export("saved_model.get_tensor_from_tensor_info",
+ "saved_model.utils.get_tensor_from_tensor_info")
+@deprecation.deprecated_endpoints(
+ "saved_model.utils.get_tensor_from_tensor_info")
def get_tensor_from_tensor_info(tensor_info, graph=None, import_scope=None):
"""Returns the Tensor or SparseTensor described by a TensorInfo proto.
diff --git a/tensorflow/python/tools/api/generator/api_init_files.bzl b/tensorflow/python/tools/api/generator/api_init_files.bzl
index 92446e2f8f..533a138a39 100644
--- a/tensorflow/python/tools/api/generator/api_init_files.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
@@ -69,6 +70,7 @@ TENSORFLOW_API_INIT_FILES = [
"profiler/__init__.py",
"python_io/__init__.py",
"quantization/__init__.py",
+ "random/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
diff --git a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
index bc2f3516d1..0747424eab 100644
--- a/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
+++ b/tensorflow/python/tools/api/generator/api_init_files_v1.bzl
@@ -8,6 +8,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"bitwise/__init__.py",
"compat/__init__.py",
"data/__init__.py",
+ "data/experimental/__init__.py",
"debugging/__init__.py",
"distributions/__init__.py",
"dtypes/__init__.py",
@@ -69,6 +70,7 @@ TENSORFLOW_API_INIT_FILES_V1 = [
"profiler/__init__.py",
"python_io/__init__.py",
"quantization/__init__.py",
+ "random/__init__.py",
"resource_loader/__init__.py",
"strings/__init__.py",
"saved_model/__init__.py",
diff --git a/tensorflow/python/tools/saved_model_cli.py b/tensorflow/python/tools/saved_model_cli.py
index 3dbccd1409..2fcb0fa029 100644
--- a/tensorflow/python/tools/saved_model_cli.py
+++ b/tensorflow/python/tools/saved_model_cli.py
@@ -267,7 +267,8 @@ def scan_meta_graph_def(meta_graph_def):
def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
input_tensor_key_feed_dict, outdir,
- overwrite_flag, worker=None, tf_debug=False):
+ overwrite_flag, worker=None, init_tpu=False,
+ tf_debug=False):
"""Runs SavedModel and fetch all outputs.
Runs the input dictionary through the MetaGraphDef within a SavedModel
@@ -287,6 +288,8 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
the same name exists.
worker: If provided, the session will be run on the worker. Valid worker
specification is a bns or gRPC path.
+ init_tpu: If true, the TPU system will be initialized after the session
+ is created.
tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
intermediate Tensor values and runtime GraphDefs while running the
SavedModel.
@@ -328,6 +331,12 @@ def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
]
with session.Session(worker, graph=ops_lib.Graph()) as sess:
+ if init_tpu:
+ print('Initializing TPU System ...')
+ # This is needed for freshly started worker, or if the job
+ # restarts after a preemption.
+ sess.run(tf.contrib.tpu.initialize_system())
+
loader.load(sess, tag_set.split(','), saved_model_dir)
if tf_debug:
@@ -632,7 +641,7 @@ def run(args):
run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
tensor_key_feed_dict, args.outdir,
args.overwrite, worker=args.worker,
- tf_debug=args.tf_debug)
+ init_tpu=args.init_tpu, tf_debug=args.tf_debug)
def scan(args):
@@ -775,6 +784,12 @@ def create_parser():
default=None,
help='if specified, a Session will be run on the worker. '
'Valid worker specification is a bns or gRPC path.')
+ parser_run.add_argument(
+ '--init_tpu',
+ action='store_true',
+ default=None,
+ help='if specified, tpu.initialize_system will be called on the Session. '
+ 'This option should be only used if the worker is a TPU job.')
parser_run.set_defaults(func=run)
# scan command
diff --git a/tensorflow/python/training/distribute.py b/tensorflow/python/training/distribute.py
index 419a9ec12b..a92a1bdee7 100644
--- a/tensorflow/python/training/distribute.py
+++ b/tensorflow/python/training/distribute.py
@@ -26,7 +26,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import resource_variable_ops
-from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses_impl
from tensorflow.python.platform import tf_logging
@@ -807,15 +806,22 @@ class DistributionStrategy(object):
var: Variable, possibly mirrored to multiple devices, to operate on.
fn: Function to call. Should take the variable as the first argument.
*args: Additional positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped.
Returns:
- Merged return value of `fn` across all towers.
+ By default, the merged return value of `fn` across all towers. The merged
+ result has dependencies to make sure that if it is evaluated at all, the
+ side effects (updates) will happen on every tower. If instead
+ "grouped=False" is specified, this function will return a nest of lists
+ where each list has an element per tower, and the caller is responsible
+ for ensuring all elements are executed.
"""
_require_cross_tower_context(self)
- return self._update(var, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update(var, options, fn, *args, **kwargs)
- def _update(self, var, fn, *args, **kwargs):
+ def _update(self, var, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def update_non_slot(self, colocate_with, fn, *args, **kwargs):
@@ -825,15 +831,18 @@ class DistributionStrategy(object):
colocate_with: The return value of `non_slot_devices()`.
fn: Function to execute.
*args: Positional arguments to pass to `fn()`.
- **kwargs: Keyword arguments to pass to `fn()`.
+ **kwargs: Keyword arguments to pass to `fn()`. If "grouped=False" is
+ specified, the return value will be unwrapped and the caller is
+ responsible for ensuring all elements are executed.
Returns:
Return value of `fn`, possibly merged across devices.
"""
_require_cross_tower_context(self)
- return self._update_non_slot(colocate_with, fn, *args, **kwargs)
+ options = {"grouped": kwargs.pop("grouped", True)}
+ return self._update_non_slot(colocate_with, options, fn, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
raise NotImplementedError("must be implemented in descendants")
def unwrap(self, value):
@@ -1134,17 +1143,22 @@ class _DefaultDistributionStrategy(DistributionStrategy):
del aggregation, destinations
return value
- def _update(self, var, fn, *args, **kwargs):
- # TODO(josh11b): Figure out what we should be passing to UpdateContext()
- # once that value is used for something.
- with ops.colocate_with(var), UpdateContext(var):
- return fn(var, *args, **kwargs)
+ def _update(self, var, options, fn, *args, **kwargs):
+ # The implementations of _update() and _update_non_slot() are identical
+ # except _update() passes `var` as the first argument to `fn()`.
+ return self._update_non_slot(var, options, fn, var, *args, **kwargs)
- def _update_non_slot(self, colocate_with, fn, *args, **kwargs):
+ def _update_non_slot(self, colocate_with, options, fn, *args, **kwargs):
+ should_group = options.pop("grouped")
+ assert not options # Validate that we are processing all of the options.
# TODO(josh11b): Figure out what we should be passing to UpdateContext()
# once that value is used for something.
with ops.colocate_with(colocate_with), UpdateContext(colocate_with):
- return fn(*args, **kwargs)
+ result = fn(*args, **kwargs)
+ if should_group:
+ return result
+ else:
+ return nest.map_structure(self._unwrap, result)
def read_var(self, tower_local_var):
return array_ops.identity(tower_local_var)
@@ -1193,13 +1207,10 @@ class _DefaultDistributionStrategy(DistributionStrategy):
def increment_var(v, amount=1):
"""`v += amount`, distributed-aware version."""
def update(vu):
- if isinstance(vu, resource_variable_ops.ResourceVariable):
- return vu.assign_add(amount, read_value=False)
- else:
- return state_ops.assign_add(vu, amount)
+ return vu.assign_add(amount, read_value=False)
def merge_fn(dist, vm):
- return dist.group(dist.update(vm, update))
+ return dist.update(vm, update)
tower_context = distribution_strategy_context.get_tower_context()
return tower_context.merge_call(merge_fn, v)
diff --git a/tensorflow/python/training/distribution_strategy_context.py b/tensorflow/python/training/distribution_strategy_context.py
index 998b5c35ce..ce580a406f 100644
--- a/tensorflow/python/training/distribution_strategy_context.py
+++ b/tensorflow/python/training/distribution_strategy_context.py
@@ -89,6 +89,7 @@ def get_tower_context():
"""Returns the current TowerContext or None if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context (this function
will return the default TowerContext object);
2. switches to cross-tower context (in which case this will return
@@ -121,6 +122,7 @@ def get_cross_tower_context():
"""Returns the current DistributionStrategy if in a cross-tower context.
Note that execution:
+
1. starts in the default (single-tower) tower context;
2. switches to cross-tower context when entering a
`with DistributionStrategy.scope():` block;
diff --git a/tensorflow/python/training/input.py b/tensorflow/python/training/input.py
index 9d9db70890..eb131ac9f7 100644
--- a/tensorflow/python/training/input.py
+++ b/tensorflow/python/training/input.py
@@ -56,7 +56,8 @@ _restore_sparse = sparse_ops._take_many_sparse_from_tensors_map
# pylint: enable=protected-access
-@tf_export("train.match_filenames_once")
+@tf_export("io.match_filenames_once", "train.match_filenames_once")
+@deprecation.deprecated_endpoints("train.match_filenames_once")
def match_filenames_once(pattern, name=None):
"""Save the list of files matching pattern, so it is only computed once.
diff --git a/tensorflow/python/training/moving_averages.py b/tensorflow/python/training/moving_averages.py
index 177a7ddfa5..041266da3e 100644
--- a/tensorflow/python/training/moving_averages.py
+++ b/tensorflow/python/training/moving_averages.py
@@ -372,13 +372,13 @@ class ExponentialMovingAverage(object):
Args:
var_list: A list of Variable or Tensor objects. The variables
- and Tensors must be of types float16, float32, or float64.
+ and Tensors must be of types bfloat16, float16, float32, or float64.
Returns:
An Operation that updates the moving averages.
Raises:
- TypeError: If the arguments are not all float16, float32, or float64.
+ TypeError: If the arguments are not an allowed type.
ValueError: If the moving average of one of the variables is already
being computed.
"""
@@ -387,8 +387,9 @@ class ExponentialMovingAverage(object):
var_list = variables.trainable_variables()
zero_debias_true = set() # set of vars to set `zero_debias=True`
for var in var_list:
- if var.dtype.base_dtype not in [dtypes.float16, dtypes.float32,
- dtypes.float64]:
+ if var.dtype.base_dtype not in [
+ dtypes.bfloat16, dtypes.float16, dtypes.float32, dtypes.float64
+ ]:
raise TypeError("The variables must be half, float, or double: %s" %
var.name)
diff --git a/tensorflow/python/training/moving_averages_test.py b/tensorflow/python/training/moving_averages_test.py
index 93991d0e14..bb2fca66e3 100644
--- a/tensorflow/python/training/moving_averages_test.py
+++ b/tensorflow/python/training/moving_averages_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -110,6 +111,32 @@ class MovingAveragesTest(test.TestCase):
denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
self.assertAllClose(numerator_2 / denominator_2, wma_array)
+ def testWeightedMovingAverageBfloat16(self):
+ bfloat16 = pywrap_tensorflow.TF_bfloat16_type()
+ with self.cached_session() as sess:
+ decay = 0.5
+ weight = array_ops.placeholder(dtypes.bfloat16, [])
+ val = array_ops.placeholder(dtypes.bfloat16, [])
+
+ wma = moving_averages.weighted_moving_average(val, decay, weight)
+ variables.global_variables_initializer().run()
+
+ # Get the first weighted moving average.
+ val_1 = 3.0
+ weight_1 = 4.0
+ wma_array = sess.run(wma, feed_dict={val: val_1, weight: weight_1})
+ numerator_1 = val_1 * weight_1 * (1.0 - decay)
+ denominator_1 = weight_1 * (1.0 - decay)
+ self.assertAllClose(numerator_1 / denominator_1, wma_array)
+
+ # Get the second weighted moving average.
+ val_2 = 11.0
+ weight_2 = 22.0
+ wma_array = sess.run(wma, feed_dict={val: val_2, weight: weight_2})
+ numerator_2 = numerator_1 * decay + val_2 * weight_2 * (1.0 - decay)
+ denominator_2 = denominator_1 * decay + weight_2 * (1.0 - decay)
+ self.assertAllClose(bfloat16(numerator_2 / denominator_2), wma_array)
+
def _Repeat(value, dim):
if dim == 1:
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f004f3944a..47034919e1 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -471,7 +471,10 @@ class Optimizer(
if var_list is None:
var_list = tape.watched_variables()
- grads = tape.gradient(loss_value, var_list, grad_loss)
+ # TODO(jhseu): Figure out why GradientTape's gradients don't require loss
+ # to be executed.
+ with ops.control_dependencies([loss_value]):
+ grads = tape.gradient(loss_value, var_list, grad_loss)
return list(zip(grads, var_list))
# Non-callable/Tensor loss case
@@ -689,7 +692,7 @@ class Optimizer(
update_ops = [
op
for grad, var in grads_and_vars
- for op in distribution.unwrap(distribution.update(var, update, grad))
+ for op in distribution.update(var, update, grad, grouped=False)
]
def finish(self, update_ops):
@@ -697,13 +700,13 @@ class Optimizer(
non_slot_devices = distribution.non_slot_devices(var_list)
finish_updates = distribution.update_non_slot(
- non_slot_devices, finish, self, update_ops)
+ non_slot_devices, finish, self, update_ops, grouped=False)
if global_step is None:
apply_updates = distribution.group(finish_updates, name=name)
else:
- with ops.control_dependencies(distribution.unwrap(finish_updates)):
- apply_updates = distribution.group(distribution.update(
- global_step, state_ops.assign_add, 1, name=name))
+ with ops.control_dependencies(finish_updates):
+ apply_updates = distribution.update(
+ global_step, state_ops.assign_add, 1, name=name)
if not context.executing_eagerly():
if isinstance(apply_updates, ops.Tensor):
diff --git a/tensorflow/python/training/session_manager.py b/tensorflow/python/training/session_manager.py
index a2e0645ba8..cd313c2ce0 100644
--- a/tensorflow/python/training/session_manager.py
+++ b/tensorflow/python/training/session_manager.py
@@ -25,6 +25,7 @@ from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpoint_management
+from tensorflow.python.training import distribution_strategy_context
from tensorflow.python.util.tf_export import tf_export
@@ -182,6 +183,12 @@ class SessionManager(object):
"""
self._target = master
sess = session.Session(self._target, graph=self._graph, config=config)
+ # TODO(jhseu): Delete once tpu.initialize_system() goes away.
+ initialize_ops = (
+ distribution_strategy_context.get_distribution_strategy().initialize()
+ )
+ if initialize_ops:
+ sess.run(initialize_ops)
if checkpoint_dir and checkpoint_filename_with_path:
raise ValueError("Can not provide both checkpoint_dir and "
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 758cba7487..d67dbde304 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -819,5 +819,5 @@ def flatten_with_joined_string_paths(structure, separator="/"):
return list(zip(flat_string_paths, flatten(structure)))
-_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
-_pywrap_tensorflow.RegisterMappingClass(_collections.Mapping)
+_pywrap_tensorflow.RegisterType("Mapping", _collections.Mapping)
+_pywrap_tensorflow.RegisterType("Sequence", _collections.Sequence)
diff --git a/tensorflow/python/util/tf_inspect.py b/tensorflow/python/util/tf_inspect.py
index 967c872c2a..444e44eaf1 100644
--- a/tensorflow/python/util/tf_inspect.py
+++ b/tensorflow/python/util/tf_inspect.py
@@ -36,6 +36,55 @@ else:
'annotations'
])
+if hasattr(_inspect, 'getfullargspec'):
+ _getfullargspec = _inspect.getfullargspec # pylint: disable=invalid-name
+
+ def _getargspec(target):
+ """A python3 version of getargspec.
+
+ Calls `getfullargspec` and assigns args, varargs,
+ varkw, and defaults to a python 2/3 compatible `ArgSpec`.
+
+ The parameter name 'varkw' is changed to 'keywords' to fit the
+ `ArgSpec` struct.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ An ArgSpec with args, varargs, keywords, and defaults parameters
+ from FullArgSpec.
+ """
+ fullargspecs = getfullargspec(target)
+ argspecs = ArgSpec(
+ args=fullargspecs.args,
+ varargs=fullargspecs.varargs,
+ keywords=fullargspecs.varkw,
+ defaults=fullargspecs.defaults)
+ return argspecs
+else:
+ _getargspec = _inspect.getargspec
+
+ def _getfullargspec(target):
+ """A python2 version of getfullargspec.
+
+ Args:
+ target: the target object to inspect.
+
+ Returns:
+ A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
+ """
+ argspecs = getargspec(target)
+ fullargspecs = FullArgSpec(
+ args=argspecs.args,
+ varargs=argspecs.varargs,
+ varkw=argspecs.keywords,
+ defaults=argspecs.defaults,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ return fullargspecs
+
def currentframe():
"""TFDecorator-aware replacement for inspect.currentframe."""
@@ -43,16 +92,18 @@ def currentframe():
def getargspec(obj):
- """TFDecorator-aware replacement for inspect.getargspec.
+ """TFDecorator-aware replacement for `inspect.getargspec`.
+
+ Note: `getfullargspec` is recommended as the python 2/3 compatible
+ replacement for this function.
Args:
- obj: A function, partial function, or callable object, possibly
- decorated.
+ obj: A function, partial function, or callable object, possibly decorated.
Returns:
The `ArgSpec` that describes the signature of the outermost decorator that
- changes the callable's signature. If the callable is not decorated,
- `inspect.getargspec()` will be called directly on the object.
+ changes the callable's signature, or the `ArgSpec` that describes
+ the object if not decorated.
Raises:
ValueError: When callable's signature can not be expressed with
@@ -72,24 +123,24 @@ def getargspec(obj):
try:
# Python3 will handle most callables here (not partial).
- return _inspect.getargspec(target)
+ return _getargspec(target)
except TypeError:
pass
if isinstance(target, type):
try:
- return _inspect.getargspec(target.__init__)
+ return _getargspec(target.__init__)
except TypeError:
pass
try:
- return _inspect.getargspec(target.__new__)
+ return _getargspec(target.__new__)
except TypeError:
pass
# The `type(target)` ensures that if a class is received we don't return
# the signature of it's __call__ method.
- return _inspect.getargspec(type(target).__call__)
+ return _getargspec(type(target).__call__)
def _get_argspec_for_partial(obj):
@@ -172,30 +223,6 @@ def _get_argspec_for_partial(obj):
return ArgSpec(args, varargs, keywords, tuple(all_defaults[first_default:]))
-if hasattr(_inspect, 'getfullargspec'):
- _getfullargspec = _inspect.getfullargspec
-else:
-
- def _getfullargspec(target):
- """A python2 version of getfullargspec.
-
- Args:
- target: the target object to inspect.
- Returns:
- A FullArgSpec with empty kwonlyargs, kwonlydefaults and annotations.
- """
- argspecs = getargspec(target)
- fullargspecs = FullArgSpec(
- args=argspecs.args,
- varargs=argspecs.varargs,
- varkw=argspecs.keywords,
- defaults=argspecs.defaults,
- kwonlyargs=[],
- kwonlydefaults=None,
- annotations={})
- return fullargspecs
-
-
def getfullargspec(obj):
"""TFDecorator-aware replacement for `inspect.getfullargspec`.
diff --git a/tensorflow/python/util/tf_inspect_test.py b/tensorflow/python/util/tf_inspect_test.py
index d3b7e4b969..02d075cdff 100644
--- a/tensorflow/python/util/tf_inspect_test.py
+++ b/tensorflow/python/util/tf_inspect_test.py
@@ -122,18 +122,6 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(partial_func))
- def testGetFullArgsSpecForPartial(self):
-
- def func(a, b):
- del a, b
-
- partial_function = functools.partial(func, 1)
- argspec = tf_inspect.FullArgSpec(
- args=['b'], varargs=None, varkw=None, defaults=None,
- kwonlyargs=[], kwonlydefaults=None, annotations={})
-
- self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
-
def testGetArgSpecOnPartialInvalidArgspec(self):
"""Tests getargspec on partial function that doesn't have valid argspec."""
@@ -303,6 +291,193 @@ class TfInspectTest(test.TestCase):
self.assertEqual(argspec, tf_inspect.getargspec(NewClass))
+ def testGetFullArgSpecOnDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.getfullargspec(test_decorated_function_with_defaults)
+ self.assertEqual(['a', 'b', 'c'], argspec.args)
+ self.assertEqual((2, 'Hello'), argspec.defaults)
+
+ def testGetFullArgSpecOnDecoratorThatChangesFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ decorator = tf_decorator.TFDecorator('', test_undecorated_function, '',
+ argspec)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(decorator))
+
+ def testGetFullArgSpecIgnoresDecoratorsThatDontProvideFullArgSpec(self):
+ argspec = tf_inspect.FullArgSpec(
+ args=['a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator)
+ self.assertEqual(argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgSpecReturnsOutermostDecoratorThatChangesFullArgSpec(self):
+ outer_argspec = tf_inspect.FullArgSpec(
+ args=['a'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+ inner_argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ inner_decorator = tf_decorator.TFDecorator('', test_undecorated_function,
+ '', inner_argspec)
+ outer_decorator = tf_decorator.TFDecorator('', inner_decorator, '',
+ outer_argspec)
+ self.assertEqual(outer_argspec, tf_inspect.getfullargspec(outer_decorator))
+
+ def testGetFullArgsSpecForPartial(self):
+
+ def func(a, b):
+ del a, b
+
+ partial_function = functools.partial(func, 1)
+ argspec = tf_inspect.FullArgSpec(
+ args=['b'],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_function))
+
+ def testGetFullArgSpecOnPartialNoArgumentsLeft(self):
+ """Tests getfullargspec on partial function that prunes all arguments."""
+
+ def func(m, n):
+ return 2 * m + n
+
+ partial_func = functools.partial(func, 7, 10)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs=None,
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarargs(self):
+ """Tests getfullargspec on partial function with variable arguments."""
+
+ def func(m, *arg):
+ return m + len(arg)
+
+ partial_func = functools.partial(func, 7, 8)
+ argspec = tf_inspect.FullArgSpec(
+ args=[],
+ varargs='arg',
+ varkw=None,
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnPartialWithVarkwargs(self):
+ """Tests getfullargspec.
+
+ Tests on partial function with variable keyword arguments.
+ """
+
+ def func(m, n, **kwarg):
+ return m * n + len(kwarg)
+
+ partial_func = functools.partial(func, 7)
+ argspec = tf_inspect.FullArgSpec(
+ args=['n'],
+ varargs=None,
+ varkw='kwarg',
+ defaults=None,
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(partial_func))
+
+ def testGetFullArgSpecOnCallableObject(self):
+
+ class Callable(object):
+
+ def __call__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ test_obj = Callable()
+ self.assertEqual(argspec, tf_inspect.getfullargspec(test_obj))
+
+ def testGetFullArgSpecOnInitClass(self):
+
+ class InitClass(object):
+
+ def __init__(self, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['self', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(InitClass))
+
+ def testGetFullArgSpecOnNewClass(self):
+
+ class NewClass(object):
+
+ def __new__(cls, a, b=1, c='hello'):
+ pass
+
+ argspec = tf_inspect.FullArgSpec(
+ args=['cls', 'a', 'b', 'c'],
+ varargs=None,
+ varkw=None,
+ defaults=(1, 'hello'),
+ kwonlyargs=[],
+ kwonlydefaults=None,
+ annotations={})
+
+ self.assertEqual(argspec, tf_inspect.getfullargspec(NewClass))
+
def testGetDoc(self):
self.assertEqual('Test Decorated Function With Defaults Docstring.',
tf_inspect.getdoc(test_decorated_function_with_defaults))
diff --git a/tensorflow/python/util/util.cc b/tensorflow/python/util/util.cc
index 38b8491c66..7b3e618e84 100644
--- a/tensorflow/python/util/util.cc
+++ b/tensorflow/python/util/util.cc
@@ -29,14 +29,51 @@ limitations under the License.
namespace tensorflow {
namespace swig {
-namespace {
+std::unordered_map<string, PyObject*>* PythonTypesMap() {
+ static auto* m = new std::unordered_map<string, PyObject*>();
+ return m;
+}
+
+PyObject* GetRegisteredType(const string& key) {
+ auto* m = PythonTypesMap();
+ auto it = m->find(key);
+ if (it == m->end()) return nullptr;
+ return it->second;
+}
+
+PyObject* RegisterType(PyObject* type_name, PyObject* type) {
+ if (!PyType_Check(type)) {
+ PyErr_SetString(PyExc_TypeError,
+ tensorflow::strings::StrCat("Expecting a type, got ",
+ Py_TYPE(type)->tp_name)
+ .c_str());
+ return nullptr;
+ }
-// Type object for collections.Sequence. This is set by RegisterSequenceClass.
-PyObject* CollectionsSequenceType = nullptr;
-// Type object for collections.Mapping, set by RegisterMappingClass.
-PyObject* CollectionsMappingType = nullptr;
-PyTypeObject* SparseTensorValueType = nullptr;
+ string key;
+ if (PyBytes_Check(type_name)) {
+ key = PyBytes_AsString(type_name);
+ }
+#if PY_MAJOR_VERSION >= 3
+ if (PyUnicode_Check(type_name)) {
+ key = PyUnicode_AsUTF8(type_name);
+ }
+#endif
+ if (PythonTypesMap()->find(key) != PythonTypesMap()->end()) {
+ PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
+ "Type already registered for ", key)
+ .c_str());
+ return nullptr;
+ }
+
+ Py_INCREF(type);
+ PythonTypesMap()->emplace(key, type);
+
+ Py_RETURN_NONE;
+}
+
+namespace {
const int kMaxItemsInCache = 1024;
bool WarnedThatSetIsNotSequence = false;
@@ -177,46 +214,82 @@ class CachedTypeCheck {
// Returns -1 if an error occurred.
int IsMappingHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- return PyObject_IsInstance(to_check, CollectionsMappingType);
+ PyObject* collections_mapping_type = GetRegisteredType("Mapping");
+ if (TF_PREDICT_FALSE(collections_mapping_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Mapping type has not been set. "
+ "Please register the type with the identifier "
+ "\"Mapping\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, collections_mapping_type);
});
if (PyDict_Check(o)) return true;
- if (TF_PREDICT_FALSE(CollectionsMappingType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Mapping type has not been set. "
- "Please call RegisterMappingClass before using this module")
- .c_str());
- return -1;
- }
return check_cache->CachedLookup(o);
}
// Returns 1 if `o` is an instance of attrs-decorated class.
// Returns 0 otherwise.
int IsAttrsHelper(PyObject* o) {
- Safe_PyObjectPtr cls(PyObject_GetAttrString(o, "__class__"));
- if (cls) {
- return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
- } else {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
+ if (cls) {
+ return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
+ }
+
// PyObject_GetAttrString returns null on error
PyErr_Clear();
return 0;
- }
+ });
+ return check_cache->CachedLookup(o);
}
-// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 1 if `o` is an object of type IndexedSlices.
// Returns 0 otherwise.
// Returns -1 if an error occurred.
-int IsSequenceHelper(PyObject* o) {
+int IsIndexedSlicesHelper(PyObject* o) {
static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
- int is_instance = PyObject_IsInstance(to_check, CollectionsSequenceType);
-
- // Don't cache a failed is_instance check.
- if (is_instance == -1) return -1;
+ PyObject* indexed_slices_type = GetRegisteredType("IndexedSlices");
+ if (TF_PREDICT_FALSE(indexed_slices_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "IndexedSlices type has not been set. "
+ "Please register the type with the identifier "
+ "\"IndexedSlices\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, indexed_slices_type);
+ });
+ return check_cache->CachedLookup(o);
+}
- return static_cast<int>(is_instance != 0 && !IsString(to_check));
+// Returns 1 if `o` is a Tensor.
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsTensorHelper(PyObject* o) {
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* tensor_type = GetRegisteredType("Tensor");
+ if (TF_PREDICT_FALSE(tensor_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "Tensor type has not been set. "
+ "Please register the type with the identifier "
+ "\"Tensor\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ return PyObject_IsInstance(to_check, tensor_type);
});
+ return check_cache->CachedLookup(o);
+}
+
+// Returns 1 if `o` is considered a sequence for the purposes of Flatten().
+// Returns 0 otherwise.
+// Returns -1 if an error occurred.
+int IsSequenceHelper(PyObject* o) {
// We treat dicts and other mappings as special cases of sequences.
if (IsMappingHelper(o)) return true;
if (IsAttrsHelper(o)) return true;
@@ -226,15 +299,24 @@ int IsSequenceHelper(PyObject* o) {
"so consider avoiding using them.";
WarnedThatSetIsNotSequence = true;
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
- return -1;
- }
+ static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
+ return -1;
+ }
+ int is_instance = PyObject_IsInstance(to_check, collections_sequence_type);
+
+ // Don't cache a failed is_instance check.
+ if (is_instance == -1) return -1;
+
+ return static_cast<int>(is_instance != 0 && !IsString(to_check));
+ });
return check_cache->CachedLookup(o);
}
@@ -401,11 +483,13 @@ class AttrsValueIterator : public ValueIterator {
};
bool IsSparseTensorValueType(PyObject* o) {
- if (TF_PREDICT_FALSE(SparseTensorValueType == nullptr)) {
+ PyObject* sparse_tensor_value_type = GetRegisteredType("SparseTensorValue");
+ if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
return false;
}
- return PyObject_TypeCheck(o, SparseTensorValueType) == 1;
+ return PyObject_TypeCheck(
+ o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
}
int IsSequenceForDataHelper(PyObject* o) {
@@ -647,49 +731,11 @@ bool AssertSameStructureHelper(
} // namespace
-void RegisterSequenceClass(PyObject* sequence_class) {
- if (!PyType_Check(sequence_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Sequence`. Got ",
- Py_TYPE(sequence_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsSequenceType = sequence_class;
-}
-
-void RegisterMappingClass(PyObject* mapping_class) {
- if (!PyType_Check(mapping_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `collections.Mapping`. Got ",
- Py_TYPE(mapping_class)->tp_name)
- .c_str());
- return;
- }
- CollectionsMappingType = mapping_class;
-}
-
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class) {
- if (!PyType_Check(sparse_tensor_value_class)) {
- PyErr_SetString(
- PyExc_TypeError,
- tensorflow::strings::StrCat(
- "Expecting a class definition for `SparseTensorValue`. Got ",
- Py_TYPE(sparse_tensor_value_class)->tp_name)
- .c_str());
- return;
- }
- SparseTensorValueType =
- reinterpret_cast<PyTypeObject*>(sparse_tensor_value_class);
-}
-
bool IsSequence(PyObject* o) { return IsSequenceHelper(o) == 1; }
bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
+bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
+bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
PyObject* Flatten(PyObject* nested) {
PyObject* list = PyList_New(0);
@@ -737,13 +783,15 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
}
- if (TF_PREDICT_FALSE(CollectionsSequenceType == nullptr)) {
- PyErr_SetString(
- PyExc_RuntimeError,
- tensorflow::strings::StrCat(
- "collections.Sequence type has not been set. "
- "Please call RegisterSequenceClass before using this module")
- .c_str());
+ PyObject* collections_sequence_type = GetRegisteredType("Sequence");
+
+ if (TF_PREDICT_FALSE(collections_sequence_type == nullptr)) {
+ PyErr_SetString(PyExc_RuntimeError,
+ tensorflow::strings::StrCat(
+ "collections.Sequence type has not been set. "
+ "Please register the type with the identifier "
+ "\"Sequence\" using RegisterType.")
+ .c_str());
return nullptr;
}
@@ -755,7 +803,8 @@ PyObject* IsNamedtuple(PyObject* o, bool strict) {
}
Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
- int is_instance = PyObject_IsInstance(fields.get(), CollectionsSequenceType);
+ int is_instance =
+ PyObject_IsInstance(fields.get(), collections_sequence_type);
if (is_instance == 0) {
Py_RETURN_FALSE;
} else if (is_instance == -1) {
diff --git a/tensorflow/python/util/util.h b/tensorflow/python/util/util.h
index 01f85ea1dc..f37cd527d8 100644
--- a/tensorflow/python/util/util.h
+++ b/tensorflow/python/util/util.h
@@ -65,6 +65,24 @@ bool IsMapping(PyObject* o);
// True if the object is an instance of an attr.s decorated class.
bool IsAttrs(PyObject* o);
+// Returns a true if its input is an ops.Tensor.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is a tensor.
+bool IsTensor(PyObject* o);
+
+// Returns a true if its input is an ops.IndexesSlices.
+//
+// Args:
+// seq: the input to be checked.
+//
+// Returns:
+// True if the object is an ops.IndexedSlices.
+bool IsIndexedSlices(PyObject* o);
+
// Implements the same interface as tensorflow.util.nest._same_namedtuples
// Returns Py_True iff the two namedtuples have the same name and fields.
// Raises RuntimeError if `o1` or `o2` don't look like namedtuples (don't have
@@ -130,18 +148,6 @@ PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types);
// TypeError: The nest is or contains a dict with non-sortable keys.
PyObject* Flatten(PyObject* nested);
-// RegisterSequenceClass is used to pass PyTypeObject for collections.Sequence
-// (which is defined in python) into the C++ world.
-// Alternative approach could be to import the collections modules and retrieve
-// the type from the module. This approach also requires some trigger from
-// Python so that we know that Python interpreter had been initialzied.
-void RegisterSequenceClass(PyObject* sequence_class);
-// Like RegisterSequenceClass, but for collections.Mapping.
-void RegisterMappingClass(PyObject* mapping_class);
-// Similar to the above functions, except for the
-// sparse_tensor.SparseTensorValue class.
-void RegisterSparseTensorValueClass(PyObject* sparse_tensor_value_class);
-
// The tensorflow.python.data package has its own nest utility that follows very
// slightly different semantics for its functions than the tensorflow.python
// nest utility. Returns a true if its input is a collections.Sequence (except
@@ -167,6 +173,10 @@ PyObject* FlattenForData(PyObject* nested);
PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
bool check_types);
+// RegisterType is used to pass PyTypeObject (which is defined in python) for an
+// arbitrary identifier `type_name` into C++.
+PyObject* RegisterType(PyObject* type_name, PyObject* type);
+
} // namespace swig
} // namespace tensorflow
diff --git a/tensorflow/python/util/util.i b/tensorflow/python/util/util.i
index 32a6e684fa..3c0ec87fa4 100644
--- a/tensorflow/python/util/util.i
+++ b/tensorflow/python/util/util.i
@@ -28,14 +28,8 @@ limitations under the License.
// for functions in this module because they use python methods that need GIL.
// TODO(iga): Find a way not to leak such definitions across files.
-%unignore tensorflow::swig::RegisterSequenceClass;
-%noexception tensorflow::swig::RegisterSequenceClass;
-
-%unignore tensorflow::swig::RegisterMappingClass;
-%noexception tensorflow::swig::RegisterMappingClass;
-
-%unignore tensorflow::swig::RegisterSparseTensorValueClass;
-%noexception tensorflow::swig::RegisterSparseTensorValueClass;
+%unignore tensorflow::swig::RegisterType;
+%noexception tensorflow::swig::RegisterType;
%feature("docstring") tensorflow::swig::IsSequence
"""Returns a true if its input is a collections.Sequence (except strings).
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
index e30f50ea2a..5cceb8983c 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.cc
@@ -470,30 +470,59 @@ void CUDAExecutor::VlogOccupancyInfo(const KernelBase &kernel,
const DeviceDescription &device_description =
kernel.parent()->GetDeviceDescription();
- uint64 blocks_per_sm = CalculateOccupancy(
- device_description, regs_per_thread, smem_per_block, thread_dims);
+ const CUDAKernel *cuda_kernel = AsCUDAKernel(&kernel);
+ CUfunction cufunc = cuda_kernel->AsCUDAFunctionValue();
+
+ int blocks_per_sm = CalculateOccupancy(device_description, regs_per_thread,
+ smem_per_block, thread_dims, cufunc);
VLOG(2) << "Resident blocks per SM is " << blocks_per_sm;
- // To increase occupancy, there must be a sufficient number of blocks
- // available to spread across the sm's at this new improved occupancy level.
- int multiprocessor_count = device_description.core_count();
- int block_count = block_dims.x * block_dims.y * block_dims.z;
- int available_blocks_per_sm =
- port::MathUtil::CeilOfRatio(block_count, multiprocessor_count);
- if (available_blocks_per_sm <= static_cast<int64>(blocks_per_sm)) {
- VLOG(2) << "Occupancy is limited by number of blocks available per sm.";
- return;
+ int suggested_threads =
+ CompareOccupancy(&blocks_per_sm, device_description, regs_per_thread,
+ smem_per_block, thread_dims, cufunc);
+ if (suggested_threads != 0) {
+ VLOG(2) << "The cuda occupancy calculator recommends using "
+ << suggested_threads
+ << " threads per block to achieve an occupancy of " << blocks_per_sm
+ << " blocks per SM.";
}
+}
- uint64 improved_regs_per_thread = CalculateRegisterLimitForTargetOccupancy(
- device_description, smem_per_block, thread_dims, blocks_per_sm + 1);
- if (improved_regs_per_thread != 0) {
- VLOG(2) << "Reducing register usage from " << regs_per_thread
- << " to " << improved_regs_per_thread
- << " could increase resident blocks per SM by one.";
+// Compute and return maximum blocks per core (occupancy) based on the
+// device description, some kernel characteristics and the number of threads per
+// block. If unable to compute occupancy, zero is returned.
+int CUDAExecutor::CalculateOccupancy(
+ const DeviceDescription &device_description, uint64 registers_per_thread,
+ uint64 shared_memory_per_block, const ThreadDim &thread_dims,
+ CUfunction func) {
+ int suggested_blocks = 0;
+ int suggested_threads = 0;
+ CUresult err = cuOccupancyMaxPotentialBlockSize(
+ &suggested_blocks, &suggested_threads, func, nullptr,
+ shared_memory_per_block, 0);
+ CHECK_EQ(err, CUDA_SUCCESS);
+ return suggested_blocks;
+}
+
+// Compute and return the suggested thread count to achieve ideal occupancy.
+// If the provided thread dimensions match this number, zero is returned.
+int CUDAExecutor::CompareOccupancy(int *initial_blocks,
+ const DeviceDescription &device_description,
+ uint64 registers_per_thread,
+ uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims,
+ CUfunction func) {
+ int suggested_blocks = 0;
+ int suggested_threads = 0;
+ CUresult err = cuOccupancyMaxPotentialBlockSize(
+ &suggested_blocks, &suggested_threads, func, nullptr,
+ shared_memory_per_block, 0);
+ CHECK_EQ(err, CUDA_SUCCESS);
+ if (suggested_blocks > *initial_blocks) {
+ *initial_blocks = suggested_blocks;
+ return suggested_threads;
} else {
- VLOG(2) << "Resident blocks per SM cannot be increased by reducing "
- "register usage.";
+ return 0;
}
}
@@ -980,144 +1009,6 @@ static int TryToReadNumaNode(const string &pci_bus_id, int device_ordinal) {
#endif
}
-// Set of compute capability specific device parameters that cannot be
-// queried from the driver API. These values instead are baked into a
-// lookup table indexed by compute capability version.
-struct UnqueryableDeviceParams {
- int cc_major;
- int cc_minor;
- uint64 blocks_per_core_limit;
- uint64 registers_per_core_limit;
- uint64 registers_per_thread_limit;
- uint64 warp_alloc_granularity;
- uint64 register_alloc_granularity;
- uint64 shared_memory_alloc_granularity;
-};
-
-// http://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#compute-capabilities
-// https://developer.download.nvidia.com/compute/cuda/CUDA_Occupancy_calculator.xls
-static const UnqueryableDeviceParams kAllUnqueryableDeviceParams[] = {
- {
- 2, 0, // compute capability (2.0)
- 8, // blocks_per_core_limit
- 32 * 1024, // registers_per_core_limit
- 63, // registers_per_thread_limit
- 2, // warp_alloc_granularity
- 64, // register_alloc_granularity
- 128, // shared_memory_alloc_granularity
- },
- {
- 2, 1, // compute capability (2.1)
- 8, // blocks_per_core_limit
- 32 * 1024, // registers_per_core_limit
- 63, // registers_per_thread_limit
- 2, // warp_alloc_granularity
- 64, // register_alloc_granularity
- 128, // shared_memory_alloc_granularity
- },
- {
- 3, 0, // compute capability (3.0)
- 16, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 63, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 3, 2, // compute capability (3.2)
- 16, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 3, 5, // compute capability (3.5)
- 16, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 3, 7, // compute capability (3.7)
- 16, // blocks_per_core_limit
- 128 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 5, 0, // compute capability (5.0)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 5, 2, // compute capability (5.2)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 5, 3, // compute capability (5.3)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 6, 0, // compute capability (6.0)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 2, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 6, 1, // compute capability (6.1)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- {
- 6, 2, // compute capability (6.2)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 4, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
- // TODO(jlebar): Confirm the alloc granularity values for sm_70. These are
- // not published in the spreadsheet linked above. Currently we guess that
- // they're the same as sm_60.
- {
- 7, 0, // compute capability (7.0)
- 32, // blocks_per_core_limit
- 64 * 1024, // registers_per_core_limit
- 255, // registers_per_thread_limit
- 2, // warp_alloc_granularity
- 256, // register_alloc_granularity
- 256, // shared_memory_alloc_granularity
- },
-};
DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
internal::DeviceDescriptionBuilder builder;
@@ -1193,19 +1084,6 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
builder.set_name(device_name);
}
- for (size_t i = 0; i < TF_ARRAYSIZE(kAllUnqueryableDeviceParams); i++) {
- const auto &params = kAllUnqueryableDeviceParams[i];
- if (params.cc_major == cc_major_ && params.cc_minor == cc_minor_) {
- builder.set_blocks_per_core_limit(params.blocks_per_core_limit);
- builder.set_registers_per_core_limit(params.registers_per_core_limit);
- builder.set_registers_per_thread_limit(params.registers_per_thread_limit);
- builder.set_warp_alloc_granularity(params.warp_alloc_granularity);
- builder.set_register_alloc_granularity(params.register_alloc_granularity);
- builder.set_shared_memory_alloc_granularity(
- params.shared_memory_alloc_granularity);
- }
- }
-
builder.set_platform_version(
port::StrCat("Compute Capability ", cc_major_, ".", cc_minor_));
@@ -1227,6 +1105,10 @@ DeviceDescription *CUDAExecutor::PopulateDeviceDescription() const {
CUDADriver::GetMaxRegistersPerBlock(device_).ValueOrDie());
builder.set_threads_per_warp(
CUDADriver::GetThreadsPerWarp(device_).ValueOrDie());
+ builder.set_registers_per_core_limit(
+ CUDADriver::GetDeviceAttribute(
+ CU_DEVICE_ATTRIBUTE_MAX_REGISTERS_PER_MULTIPROCESSOR, device_)
+ .ValueOrDie());
auto built = builder.Build();
return built.release();
diff --git a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
index 8a954d5461..53b2a29ae7 100644
--- a/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
+++ b/tensorflow/stream_executor/cuda/cuda_gpu_executor.h
@@ -70,6 +70,17 @@ class CUDAExecutor : public internal::StreamExecutorInterface {
const BlockDim &block_dims, const KernelBase &k,
const KernelArgsArrayBase &args) override;
+ int CalculateOccupancy(const DeviceDescription &device_description,
+ uint64 registers_per_thread,
+ uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims, CUfunction func);
+
+ int CompareOccupancy(int *initial_blocks,
+ const DeviceDescription &device_description,
+ uint64 registers_per_thread,
+ uint64 shared_memory_per_block,
+ const ThreadDim &thread_dims, CUfunction func);
+
void *Allocate(uint64 size) override;
void *AllocateSubBuffer(DeviceMemoryBase *mem, uint64 offset_bytes,
diff --git a/tensorflow/stream_executor/device_description.cc b/tensorflow/stream_executor/device_description.cc
index 8ca0677f8a..726c4adf74 100644
--- a/tensorflow/stream_executor/device_description.cc
+++ b/tensorflow/stream_executor/device_description.cc
@@ -37,16 +37,11 @@ DeviceDescription::DeviceDescription()
kUninitializedUint64),
block_dim_limit_(kUninitializedUint64, kUninitializedUint64,
kUninitializedUint64),
- blocks_per_core_limit_(kUninitializedUint64),
threads_per_core_limit_(kUninitializedUint64),
threads_per_block_limit_(kUninitializedUint64),
threads_per_warp_(kUninitializedUint64),
registers_per_core_limit_(kUninitializedUint64),
registers_per_block_limit_(kUninitializedUint64),
- registers_per_thread_limit_(kUninitializedUint64),
- warp_alloc_granularity_(1),
- register_alloc_granularity_(1),
- shared_memory_alloc_granularity_(1),
device_address_bits_(kUninitializedUint64),
device_memory_size_(kUninitializedUint64),
memory_bandwidth_(kUninitializedUint64),
@@ -162,75 +157,4 @@ static uint64 RoundDown(uint64 value, uint64 n) {
return port::MathUtil::FloorOfRatio(value, n) * n;
}
-uint64 CalculateOccupancy(const DeviceDescription &device_description,
- uint64 registers_per_thread,
- uint64 shared_memory_per_block,
- const ThreadDim &thread_dims) {
- // Don't try to compute occupancy if necessary values are not initialized.
- uint64 required_fields[] = { device_description.registers_per_thread_limit(),
- device_description.threads_per_warp(),
- device_description.warp_alloc_granularity(),
- device_description.register_alloc_granularity(),
- device_description.registers_per_block_limit(),
- device_description.shared_memory_per_core(),
- device_description.blocks_per_core_limit() };
- for (auto value : required_fields) {
- if (value == kUninitializedUint64) {
- return 0;
- }
- }
-
- if (registers_per_thread > device_description.registers_per_thread_limit()) {
- return 0;
- }
-
- uint64 warps_per_block =
- port::MathUtil::CeilOfRatio(thread_dims.x * thread_dims.y * thread_dims.z,
- device_description.threads_per_warp());
-
- // Warp resources are allocated at a particular granularity. This value is
- // the effective number of warps for resource allocation purposes.
- uint64 alloc_warps_per_block =
- RoundUp(warps_per_block, device_description.warp_alloc_granularity());
-
- uint64 alloc_regs_per_warp =
- RoundUp(device_description.threads_per_warp() * registers_per_thread,
- device_description.register_alloc_granularity());
- uint64 regs_per_block = alloc_warps_per_block * alloc_regs_per_warp;
- uint64 reg_limit =
- device_description.registers_per_block_limit() / regs_per_block;
-
- uint64 alloc_smem_per_block = RoundUp(
- shared_memory_per_block,
- device_description.shared_memory_alloc_granularity());
- uint64 smem_limit = alloc_smem_per_block > 0 ?
- device_description.shared_memory_per_core() / alloc_smem_per_block :
- device_description.blocks_per_core_limit();
-
- uint64 thread_limit = device_description.threads_per_core_limit()
- / (warps_per_block * device_description.threads_per_warp());
-
- return std::min({ device_description.blocks_per_core_limit(),
- reg_limit, smem_limit, thread_limit });
-}
-
-uint64 CalculateRegisterLimitForTargetOccupancy(
- const DeviceDescription &device_description, uint64 shared_memory_per_block,
- const ThreadDim &thread_dims, uint64 target_blocks_per_core) {
- // Linear search from maximum number of registers down until the target
- // blocks per SM is found.
- // TODO(meheff): Compute this using a closed form solution.
- int reg_step = device_description.register_alloc_granularity() /
- device_description.threads_per_warp();
- for (int r = device_description.registers_per_thread_limit(); r > 0;
- r = RoundDown(r - 1, reg_step)) {
- uint64 occupancy = CalculateOccupancy(
- device_description, r, shared_memory_per_block, thread_dims);
- if (occupancy >= target_blocks_per_core) {
- return r;
- }
- }
- return 0;
-}
-
} // namespace stream_executor
diff --git a/tensorflow/stream_executor/device_description.h b/tensorflow/stream_executor/device_description.h
index a4580d6462..8ddf18629d 100644
--- a/tensorflow/stream_executor/device_description.h
+++ b/tensorflow/stream_executor/device_description.h
@@ -78,10 +78,6 @@ class DeviceDescription {
// legitimate kernel launch request.
const BlockDim &block_dim_limit() const { return block_dim_limit_; }
- // Returns the limit on the number of simultaneously resident blocks
- // on a multiprocessor.
- uint64 blocks_per_core_limit() const { return blocks_per_core_limit_; }
-
// Returns the limit on the total number of threads that can be launched in a
// single block; i.e. the limit on x * y * z dimensions of a ThreadDim.
// This limit affects what constitutes a legitimate kernel launch request.
@@ -109,27 +105,6 @@ class DeviceDescription {
return registers_per_block_limit_;
}
- // Returns the limit on the total number of registers that can be
- // allocated to a thread.
- const uint64 &registers_per_thread_limit() const {
- return registers_per_thread_limit_;
- }
-
- // Returns the granularity at which warps are allocated resources.
- const uint64 &warp_alloc_granularity() const {
- return warp_alloc_granularity_;
- }
-
- // Returns the granularity at which registers are allocated to warps.
- const uint64 &register_alloc_granularity() const {
- return register_alloc_granularity_;
- }
-
- // Returns the granularity at which shared memory is allocated to warps.
- const uint64 &shared_memory_alloc_granularity() const {
- return shared_memory_alloc_granularity_;
- }
-
// Returns the number of address bits available to kernel code running on the
// platform. This affects things like the maximum allocation size and perhaps
// types used in kernel code such as size_t.
@@ -199,19 +174,12 @@ class DeviceDescription {
ThreadDim thread_dim_limit_;
BlockDim block_dim_limit_;
- uint64 blocks_per_core_limit_;
-
uint64 threads_per_core_limit_;
uint64 threads_per_block_limit_;
uint64 threads_per_warp_;
uint64 registers_per_core_limit_;
uint64 registers_per_block_limit_;
- uint64 registers_per_thread_limit_;
-
- uint64 warp_alloc_granularity_;
- uint64 register_alloc_granularity_;
- uint64 shared_memory_alloc_granularity_;
uint64 device_address_bits_;
uint64 device_memory_size_;
@@ -269,10 +237,6 @@ class DeviceDescriptionBuilder {
device_description_->block_dim_limit_ = value;
}
- void set_blocks_per_core_limit(uint64 value) {
- device_description_->blocks_per_core_limit_ = value;
- }
-
void set_threads_per_core_limit(uint64 value) {
device_description_->threads_per_core_limit_ = value;
}
@@ -289,19 +253,6 @@ class DeviceDescriptionBuilder {
void set_registers_per_block_limit(uint64 value) {
device_description_->registers_per_block_limit_ = value;
}
- void set_registers_per_thread_limit(uint64 value) {
- device_description_->registers_per_thread_limit_ = value;
- }
-
- void set_warp_alloc_granularity(uint64 value) {
- device_description_->warp_alloc_granularity_ = value;
- }
- void set_register_alloc_granularity(uint64 value) {
- device_description_->register_alloc_granularity_ = value;
- }
- void set_shared_memory_alloc_granularity(uint64 value) {
- device_description_->shared_memory_alloc_granularity_ = value;
- }
void set_device_address_bits(uint64 value) {
device_description_->device_address_bits_ = value;
@@ -370,21 +321,6 @@ void CalculateDimensionality(const DeviceDescription &device_description,
uint64 element_count, uint64 *threads_per_block,
uint64 *block_count);
-// Compute and return maximum blocks per core (occupancy) based on the
-// device description, some kernel characteristics and the number of threads per
-// block. If unable to compute occupancy, zero is returned.
-uint64 CalculateOccupancy(const DeviceDescription &device_description,
- uint64 registers_per_thread,
- uint64 shared_memory_per_block,
- const ThreadDim &thread_dims);
-
-// Compute and return the maximum number of registers per thread which
-// achieves the target occupancy. If the target is not possible then
-// zero is returned.
-uint64 CalculateRegisterLimitForTargetOccupancy(
- const DeviceDescription &device_description, uint64 shared_memory_per_block,
- const ThreadDim &thread_dims, uint64 target_blocks_per_core);
-
} // namespace stream_executor
#endif // TENSORFLOW_STREAM_EXECUTOR_DEVICE_DESCRIPTION_H_
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
deleted file mode 100644
index eb41deee13..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.-experimental.pbtxt
+++ /dev/null
@@ -1,24 +0,0 @@
-path: "tensorflow.ConfigProto.Experimental"
-tf_proto {
- descriptor {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt b/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
deleted file mode 100644
index e565b903d2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.-config-proto.pbtxt
+++ /dev/null
@@ -1,148 +0,0 @@
-path: "tensorflow.ConfigProto"
-tf_proto {
- descriptor {
- name: "ConfigProto"
- field {
- name: "device_count"
- number: 1
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.DeviceCountEntry"
- }
- field {
- name: "intra_op_parallelism_threads"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "inter_op_parallelism_threads"
- number: 5
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "use_per_session_threads"
- number: 9
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "session_inter_op_thread_pool"
- number: 12
- label: LABEL_REPEATED
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ThreadPoolOptionProto"
- }
- field {
- name: "placement_period"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- field {
- name: "device_filters"
- number: 4
- label: LABEL_REPEATED
- type: TYPE_STRING
- }
- field {
- name: "gpu_options"
- number: 6
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GPUOptions"
- }
- field {
- name: "allow_soft_placement"
- number: 7
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "log_device_placement"
- number: 8
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "graph_options"
- number: 10
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.GraphOptions"
- }
- field {
- name: "operation_timeout_in_ms"
- number: 11
- label: LABEL_OPTIONAL
- type: TYPE_INT64
- }
- field {
- name: "rpc_options"
- number: 13
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.RPCOptions"
- }
- field {
- name: "cluster_def"
- number: 14
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ClusterDef"
- }
- field {
- name: "isolate_session_state"
- number: 15
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "experimental"
- number: 16
- label: LABEL_OPTIONAL
- type: TYPE_MESSAGE
- type_name: ".tensorflow.ConfigProto.Experimental"
- }
- nested_type {
- name: "DeviceCountEntry"
- field {
- name: "key"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "value"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_INT32
- }
- options {
- map_entry: true
- }
- }
- nested_type {
- name: "Experimental"
- field {
- name: "collective_group_leader"
- number: 1
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- field {
- name: "client_handles_error_formatting"
- number: 2
- label: LABEL_OPTIONAL
- type: TYPE_BOOL
- }
- field {
- name: "executor_type"
- number: 3
- label: LABEL_OPTIONAL
- type: TYPE_STRING
- }
- }
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt b/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
deleted file mode 100644
index 4f0147a523..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.data.-iterator.pbtxt
+++ /dev/null
@@ -1,46 +0,0 @@
-path: "tensorflow.data.Iterator"
-tf_class {
- is_instance: "<class \'tensorflow.python.data.ops.iterator_ops.Iterator\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "initializer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_classes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shapes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_types"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'iterator_resource\', \'initializer\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_string_handle"
- argspec: "args=[\'string_handle\', \'output_types\', \'output_shapes\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "from_structure"
- argspec: "args=[\'output_types\', \'output_shapes\', \'shared_name\', \'output_classes\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_next"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "make_initializer"
- argspec: "args=[\'self\', \'dataset\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "string_handle"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
deleted file mode 100644
index c23b04b4ef..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-classifier.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesClassifier"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesClassifier\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'n_classes\', \'weight_column\', \'label_vocabulary\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
deleted file mode 100644
index 6878d28fff..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-boosted-trees-regressor.pbtxt
+++ /dev/null
@@ -1,58 +0,0 @@
-path: "tensorflow.estimator.BoostedTreesRegressor"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.canned.boosted_trees.BoostedTreesRegressor\'>"
- is_instance: "<class \'tensorflow.python.estimator.estimator.Estimator\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "params"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'feature_columns\', \'n_batches_per_layer\', \'model_dir\', \'label_dimension\', \'weight_column\', \'n_trees\', \'max_depth\', \'learning_rate\', \'l1_regularization\', \'l2_regularization\', \'tree_complexity\', \'min_node_weight\', \'config\', \'center_bias\', \'pruning_mode\'], varargs=None, keywords=None, defaults=[\'None\', \'<object object instance>\', \'None\', \'100\', \'6\', \'0.1\', \'0.0\', \'0.0\', \'0.0\', \'0.0\', \'None\', \'False\', \'none\'], "
- }
- member_method {
- name: "eval_dir"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'input_fn\', \'steps\', \'hooks\', \'checkpoint_path\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "export_savedmodel"
- argspec: "args=[\'self\', \'export_dir_base\', \'serving_input_receiver_fn\', \'assets_extra\', \'as_text\', \'checkpoint_path\', \'strip_default_attrs\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\', \'False\'], "
- }
- member_method {
- name: "get_variable_names"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_variable_value"
- argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "latest_checkpoint"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'input_fn\', \'predict_keys\', \'hooks\', \'checkpoint_path\', \'yield_single_examples\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\'], "
- }
- member_method {
- name: "train"
- argspec: "args=[\'self\', \'input_fn\', \'hooks\', \'steps\', \'max_steps\', \'saving_listeners\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt b/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
deleted file mode 100644
index bf1f94b6ae..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.estimator.-run-config.pbtxt
+++ /dev/null
@@ -1,105 +0,0 @@
-path: "tensorflow.estimator.RunConfig"
-tf_class {
- is_instance: "<class \'tensorflow.python.estimator.run_config.RunConfig\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "cluster_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "device_fn"
- mtype: "<type \'property\'>"
- }
- member {
- name: "eval_distribute"
- mtype: "<type \'property\'>"
- }
- member {
- name: "evaluation_master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "global_id_in_cluster"
- mtype: "<type \'property\'>"
- }
- member {
- name: "is_chief"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_every_n_hours"
- mtype: "<type \'property\'>"
- }
- member {
- name: "keep_checkpoint_max"
- mtype: "<type \'property\'>"
- }
- member {
- name: "log_step_count_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "master"
- mtype: "<type \'property\'>"
- }
- member {
- name: "model_dir"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_ps_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "num_worker_replicas"
- mtype: "<type \'property\'>"
- }
- member {
- name: "protocol"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_secs"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_checkpoints_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "save_summary_steps"
- mtype: "<type \'property\'>"
- }
- member {
- name: "service"
- mtype: "<type \'property\'>"
- }
- member {
- name: "session_config"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_id"
- mtype: "<type \'property\'>"
- }
- member {
- name: "task_type"
- mtype: "<type \'property\'>"
- }
- member {
- name: "tf_random_seed"
- mtype: "<type \'property\'>"
- }
- member {
- name: "train_distribute"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'model_dir\', \'tf_random_seed\', \'save_summary_steps\', \'save_checkpoints_steps\', \'save_checkpoints_secs\', \'session_config\', \'keep_checkpoint_max\', \'keep_checkpoint_every_n_hours\', \'log_step_count_steps\', \'train_distribute\', \'device_fn\', \'protocol\', \'eval_distribute\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'100\', \'<object object instance>\', \'<object object instance>\', \'None\', \'5\', \'10000\', \'100\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "replace"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/tensorflow.image.pbtxt
deleted file mode 100644
index 5c46dc5ee7..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.image.pbtxt
+++ /dev/null
@@ -1,251 +0,0 @@
-path: "tensorflow.image"
-tf_module {
- member {
- name: "ResizeMethod"
- mtype: "<type \'type\'>"
- }
- member_method {
- name: "adjust_brightness"
- argspec: "args=[\'image\', \'delta\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_contrast"
- argspec: "args=[\'images\', \'contrast_factor\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "adjust_gamma"
- argspec: "args=[\'image\', \'gamma\', \'gain\'], varargs=None, keywords=None, defaults=[\'1\', \'1\'], "
- }
- member_method {
- name: "adjust_hue"
- argspec: "args=[\'image\', \'delta\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_jpeg_quality"
- argspec: "args=[\'image\', \'jpeg_quality\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "adjust_saturation"
- argspec: "args=[\'image\', \'saturation_factor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "central_crop"
- argspec: "args=[\'image\', \'central_fraction\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "convert_image_dtype"
- argspec: "args=[\'image\', \'dtype\', \'saturate\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "crop_and_resize"
- argspec: "args=[\'image\', \'boxes\', \'box_ind\', \'crop_size\', \'method\', \'extrapolation_value\', \'name\'], varargs=None, keywords=None, defaults=[\'bilinear\', \'0\', \'None\'], "
- }
- member_method {
- name: "crop_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "decode_and_crop_jpeg"
- argspec: "args=[\'contents\', \'crop_window\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_bmp"
- argspec: "args=[\'contents\', \'channels\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
- }
- member_method {
- name: "decode_gif"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "decode_image"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "decode_jpeg"
- argspec: "args=[\'contents\', \'channels\', \'ratio\', \'fancy_upscaling\', \'try_recover_truncated\', \'acceptable_fraction\', \'dct_method\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \'True\', \'False\', \'1\', \'\', \'None\'], "
- }
- member_method {
- name: "decode_png"
- argspec: "args=[\'contents\', \'channels\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \"<dtype: \'uint8\'>\", \'None\'], "
- }
- member_method {
- name: "draw_bounding_boxes"
- argspec: "args=[\'images\', \'boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "encode_jpeg"
- argspec: "args=[\'image\', \'format\', \'quality\', \'progressive\', \'optimize_size\', \'chroma_downsampling\', \'density_unit\', \'x_density\', \'y_density\', \'xmp_metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'95\', \'False\', \'False\', \'True\', \'in\', \'300\', \'300\', \'\', \'None\'], "
- }
- member_method {
- name: "encode_png"
- argspec: "args=[\'image\', \'compression\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'None\'], "
- }
- member_method {
- name: "extract_glimpse"
- argspec: "args=[\'input\', \'size\', \'offsets\', \'centered\', \'normalized\', \'uniform_noise\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'True\', \'True\', \'None\'], "
- }
- member_method {
- name: "extract_image_patches"
- argspec: "args=[\'images\', \'ksizes\', \'strides\', \'rates\', \'padding\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "extract_jpeg_shape"
- argspec: "args=[\'contents\', \'output_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'int32\'>\", \'None\'], "
- }
- member_method {
- name: "flip_left_right"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "flip_up_down"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "grayscale_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "hsv_to_rgb"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "image_gradients"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "is_jpeg"
- argspec: "args=[\'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "non_max_suppression"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_overlaps"
- argspec: "args=[\'overlaps\', \'scores\', \'max_output_size\', \'overlap_threshold\', \'score_threshold\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'None\'], "
- }
- member_method {
- name: "non_max_suppression_padded"
- argspec: "args=[\'boxes\', \'scores\', \'max_output_size\', \'iou_threshold\', \'score_threshold\', \'pad_to_max_output_size\', \'name\'], varargs=None, keywords=None, defaults=[\'0.5\', \'-inf\', \'False\', \'None\'], "
- }
- member_method {
- name: "pad_to_bounding_box"
- argspec: "args=[\'image\', \'offset_height\', \'offset_width\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "per_image_standardization"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "psnr"
- argspec: "args=[\'a\', \'b\', \'max_val\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_brightness"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_contrast"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_left_right"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_flip_up_down"
- argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_hue"
- argspec: "args=[\'image\', \'max_delta\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_jpeg_quality"
- argspec: "args=[\'image\', \'min_jpeg_quality\', \'max_jpeg_quality\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "random_saturation"
- argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "resize_area"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bicubic"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_bilinear"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "resize_image_with_crop_or_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "resize_image_with_pad"
- argspec: "args=[\'image\', \'target_height\', \'target_width\', \'method\'], varargs=None, keywords=None, defaults=[\'0\'], "
- }
- member_method {
- name: "resize_images"
- argspec: "args=[\'images\', \'size\', \'method\', \'align_corners\', \'preserve_aspect_ratio\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\'], "
- }
- member_method {
- name: "resize_nearest_neighbor"
- argspec: "args=[\'images\', \'size\', \'align_corners\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
- }
- member_method {
- name: "rgb_to_grayscale"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_hsv"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "rgb_to_yiq"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rgb_to_yuv"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "rot90"
- argspec: "args=[\'image\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
- }
- member_method {
- name: "sample_distorted_bounding_box"
- argspec: "args=[\'image_size\', \'bounding_boxes\', \'seed\', \'seed2\', \'min_object_covered\', \'aspect_ratio_range\', \'area_range\', \'max_attempts\', \'use_image_if_no_bounding_boxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'0.1\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "sobel_edges"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim"
- argspec: "args=[\'img1\', \'img2\', \'max_val\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "ssim_multiscale"
- argspec: "args=[\'img1\', \'img2\', \'max_val\', \'power_factors\'], varargs=None, keywords=None, defaults=[\'(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)\'], "
- }
- member_method {
- name: "total_variation"
- argspec: "args=[\'images\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "transpose_image"
- argspec: "args=[\'image\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yiq_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "yuv_to_rgb"
- argspec: "args=[\'images\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
deleted file mode 100644
index e579fe6a1a..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
deleted file mode 100644
index 6f05cdd093..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
deleted file mode 100644
index 2e9de9ebb2..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.activations.pbtxt
+++ /dev/null
@@ -1,55 +0,0 @@
-path: "tensorflow.keras.activations"
-tf_module {
- member_method {
- name: "deserialize"
- argspec: "args=[\'name\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "elu"
- argspec: "args=[\'x\', \'alpha\'], varargs=None, keywords=None, defaults=[\'1.0\'], "
- }
- member_method {
- name: "get"
- argspec: "args=[\'identifier\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "hard_sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "linear"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "relu"
- argspec: "args=[\'x\', \'alpha\', \'max_value\', \'threshold\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\', \'0\'], "
- }
- member_method {
- name: "selu"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "serialize"
- argspec: "args=[\'activation\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "sigmoid"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softmax"
- argspec: "args=[\'x\', \'axis\'], varargs=None, keywords=None, defaults=[\'-1\'], "
- }
- member_method {
- name: "softplus"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "softsign"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "tanh"
- argspec: "args=[\'x\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
deleted file mode 100644
index 56914e1746..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ /dev/null
@@ -1,268 +0,0 @@
-path: "tensorflow.keras.models.Model"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
deleted file mode 100644
index 4c1c54001d..0000000000
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ /dev/null
@@ -1,289 +0,0 @@
-path: "tensorflow.keras.models.Sequential"
-tf_class {
- is_instance: "<class \'tensorflow.python.keras.engine.sequential.Sequential\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.training.Model\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.network.Network\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_spec"
- mtype: "<type \'property\'>"
- }
- member {
- name: "layers"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "stateful"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "uses_learning_phase"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'layers\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "add"
- argspec: "args=[\'self\', \'layer\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'True\', \'None\'], "
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'partitioner\', \'use_resource\', \'synchronization\', \'aggregation\', \'getter\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'training\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "compile"
- argspec: "args=[\'self\', \'optimizer\', \'loss\', \'metrics\', \'loss_weights\', \'sample_weight_mode\', \'weighted_metrics\', \'target_tensors\', \'distribute\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "evaluate"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'verbose\', \'sample_weight\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'1\', \'None\', \'None\'], "
- }
- member_method {
- name: "evaluate_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "fit"
- argspec: "args=[\'self\', \'x\', \'y\', \'batch_size\', \'epochs\', \'verbose\', \'callbacks\', \'validation_split\', \'validation_data\', \'shuffle\', \'class_weight\', \'sample_weight\', \'initial_epoch\', \'steps_per_epoch\', \'validation_steps\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'1\', \'1\', \'None\', \'0.0\', \'None\', \'True\', \'None\', \'None\', \'0\', \'None\', \'None\'], "
- }
- member_method {
- name: "fit_generator"
- argspec: "args=[\'self\', \'generator\', \'steps_per_epoch\', \'epochs\', \'verbose\', \'callbacks\', \'validation_data\', \'validation_steps\', \'class_weight\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'shuffle\', \'initial_epoch\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'1\', \'None\', \'None\', \'None\', \'None\', \'10\', \'1\', \'False\', \'True\', \'0\'], "
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\', \'custom_objects\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_layer"
- argspec: "args=[\'self\', \'name\', \'index\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "load_weights"
- argspec: "args=[\'self\', \'filepath\', \'by_name\'], varargs=None, keywords=None, defaults=[\'False\'], "
- }
- member_method {
- name: "pop"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\', \'steps\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
- }
- member_method {
- name: "predict_classes"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "predict_generator"
- argspec: "args=[\'self\', \'generator\', \'steps\', \'max_queue_size\', \'workers\', \'use_multiprocessing\', \'verbose\'], varargs=None, keywords=None, defaults=[\'None\', \'10\', \'1\', \'False\', \'0\'], "
- }
- member_method {
- name: "predict_on_batch"
- argspec: "args=[\'self\', \'x\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "predict_proba"
- argspec: "args=[\'self\', \'x\', \'batch_size\', \'verbose\'], varargs=None, keywords=None, defaults=[\'32\', \'0\'], "
- }
- member_method {
- name: "reset_states"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "save"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'include_optimizer\'], varargs=None, keywords=None, defaults=[\'True\', \'True\'], "
- }
- member_method {
- name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "summary"
- argspec: "args=[\'self\', \'line_length\', \'positions\', \'print_fn\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "symbolic_set_inputs"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "test_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
- }
- member_method {
- name: "to_json"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "to_yaml"
- argspec: "args=[\'self\'], varargs=None, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "train_on_batch"
- argspec: "args=[\'self\', \'x\', \'y\', \'sample_weight\', \'class_weight\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
-}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
index 825afb622f..8b7f63e43e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-dataset.pbtxt
@@ -79,6 +79,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -119,6 +123,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
index cdad5f6360..a7bfa82c65 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
new file mode 100644
index 0000000000..d15dccc173
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-options.pbtxt
@@ -0,0 +1,57 @@
+path: "tensorflow.data.Options"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "experimental_autotune"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_filter_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_hoist_random_uniform"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_latency_all_edges"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_and_batch_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_and_filter_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_parallelization"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_vectorization"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_noop_elimination"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_shuffle_and_repeat_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
index df41bff1b5..7b7a9ebaf0 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
index 028bcc2ce9..2817f900e1 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.-text-line-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
new file mode 100644
index 0000000000..03c16cda8b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.data.experimental.CheckpointInputPipelineHook"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.iterator_ops.CheckpointInputPipelineHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..3eeaa1b185
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.CsvDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
new file mode 100644
index 0000000000..2520e28a3c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.CsvDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt
new file mode 100644
index 0000000000..b4c9459098
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-optional.pbtxt
@@ -0,0 +1,28 @@
+path: "tensorflow.data.experimental.Optional"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "value_structure"
+ mtype: "<class \'abc.abstractproperty\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "has_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "none_from_structure"
+ argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..2991b12f64
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.RandomDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
new file mode 100644
index 0000000000..1dd53b1eab
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.RandomDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt
new file mode 100644
index 0000000000..6b477a8a72
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-reducer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.data.experimental.Reducer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.grouping.Reducer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "finalize_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "init_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "reduce_func"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'init_func\', \'reduce_func\', \'finalize_func\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..948e99ef86
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.SqlDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
new file mode 100644
index 0000000000..8fdd9dc52e
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.SqlDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt
new file mode 100644
index 0000000000..0bcc8cf3e8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-stats-aggregator.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.StatsAggregator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_summary"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6f9d18a701
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.-t-f-record-writer.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.writers.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filename\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
new file mode 100644
index 0000000000..b14585f8d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.experimental.pbtxt
@@ -0,0 +1,139 @@
+path: "tensorflow.data.experimental"
+tf_module {
+ member {
+ name: "CheckpointInputPipelineHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CsvDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Optional"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Reducer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SqlDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "StatsAggregator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "Counter"
+ argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "bucket_by_sequence_length"
+ argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "choose_from_datasets"
+ argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "copy_to_device"
+ argspec: "args=[\'target_device\', \'source_device\'], varargs=None, keywords=None, defaults=[\'/cpu:0\'], "
+ }
+ member_method {
+ name: "dense_to_sparse_batch"
+ argspec: "args=[\'batch_size\', \'row_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "enumerate_dataset"
+ argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "get_next_as_optional"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_single_element"
+ argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_reducer"
+ argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_window"
+ argspec: "args=[\'key_func\', \'reduce_func\', \'window_size\', \'window_size_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "ignore_errors"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latency_stats"
+ argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "make_batched_features_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDataset\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'2\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "make_csv_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'column_names\', \'column_defaults\', \'label_name\', \'select_columns\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'header\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'num_parallel_reads\', \'sloppy\', \'num_rows_for_inference\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \',\', \'True\', \'\', \'True\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'False\', \'100\', \'None\'], "
+ }
+ member_method {
+ name: "make_saveable_from_iterator"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map_and_batch"
+ argspec: "args=[\'map_func\', \'batch_size\', \'num_parallel_batches\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "parallel_interleave"
+ argspec: "args=[\'map_func\', \'cycle_length\', \'block_length\', \'sloppy\', \'buffer_output_elements\', \'prefetch_input_elements\'], varargs=None, keywords=None, defaults=[\'1\', \'False\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_example_dataset"
+ argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ }
+ member_method {
+ name: "prefetch_to_device"
+ argspec: "args=[\'device\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rejection_resample"
+ argspec: "args=[\'class_func\', \'target_dist\', \'initial_dist\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "sample_from_datasets"
+ argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "scan"
+ argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_stats_aggregator"
+ argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle_and_repeat"
+ argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "unbatch"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "unique"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
index 56fb270a49..3023276a1d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.data.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Options"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "TFRecordDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
@@ -20,4 +24,8 @@ tf_module {
name: "TextLineDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
+ member {
+ name: "experimental"
+ mtype: "<type \'module\'>"
+ }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt
index d9efe97821..ab6287f8cd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.debugging.pbtxt
@@ -1,6 +1,90 @@
path: "tensorflow.debugging"
tf_module {
member_method {
+ name: "Assert"
+ argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_all_finite"
+ argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "assert_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_greater"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_greater_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_integer"
+ argspec: "args=[\'x\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_less"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_less_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_near"
+ argspec: "args=[\'x\', \'y\', \'rtol\', \'atol\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_negative"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_non_negative"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_non_positive"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_none_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_positive"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_proper_iterable"
+ argspec: "args=[\'values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "assert_rank"
+ argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_rank_at_least"
+ argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_rank_in"
+ argspec: "args=[\'x\', \'ranks\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_same_float_dtype"
+ argspec: "args=[\'tensors\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_scalar"
+ argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "assert_type"
+ argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "check_numerics"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -16,4 +100,16 @@ tf_module {
name: "is_nan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "is_non_decreasing"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_numeric_tensor"
+ argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_strictly_increasing"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt
new file mode 100644
index 0000000000..423eca32a2
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.-d-type.pbtxt
@@ -0,0 +1,77 @@
+path: "tensorflow.dtypes.DType"
+tf_class {
+ is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "as_datatype_enum"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "as_numpy_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "base_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_bool"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_complex"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_floating"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_integer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_numpy_compatible"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_quantized"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_unsigned"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "limits"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "max"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "min"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "real_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "size"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_compatible_with"
+ argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt
index 98e1feed00..ea23feca84 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.dtypes.pbtxt
@@ -1,7 +1,27 @@
path: "tensorflow.dtypes"
tf_module {
+ member {
+ name: "DType"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "as_dtype"
+ argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None"
+ }
member_method {
name: "as_string"
argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], "
}
+ member_method {
+ name: "cast"
+ argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "complex"
+ argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "saturate_cast"
+ argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt
index eeabf845dc..162ee76ee7 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.graph_util.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'graph_def\', \'dest_nodes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "import_graph_def"
+ argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "must_run_on_cpu"
argspec: "args=[\'node\', \'pin_variables_on_cpu\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
index 5c46dc5ee7..0a231f1b65 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.image.pbtxt
@@ -149,6 +149,10 @@ tf_module {
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "random_crop"
+ argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "random_flip_left_right"
argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
index d499c67d89..19ca62122e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.initializers.pbtxt
@@ -73,6 +73,10 @@ tf_module {
argspec: "args=[], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "tables_initializer"
+ argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
+ }
+ member_method {
name: "variables"
argspec: "args=[\'var_list\', \'name\'], varargs=None, keywords=None, defaults=[\'init\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-feature.pbtxt
new file mode 100644
index 0000000000..cd0e51c8c7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-feature.pbtxt
@@ -0,0 +1,27 @@
+path: "tensorflow.io.FixedLenFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "default_value"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-sequence-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-sequence-feature.pbtxt
new file mode 100644
index 0000000000..8a38f25fdf
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-fixed-len-sequence-feature.pbtxt
@@ -0,0 +1,31 @@
+path: "tensorflow.io.FixedLenSequenceFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "allow_missing"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "default_value"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-padding-f-i-f-o-queue.pbtxt
new file mode 100644
index 0000000000..85306fdcac
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-padding-f-i-f-o-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.PaddingFIFOQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.PaddingFIFOQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'padding_fifo_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-priority-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-priority-queue.pbtxt
new file mode 100644
index 0000000000..02d8037b34
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-priority-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.PriorityQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.PriorityQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'types\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'priority_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-queue-base.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-queue-base.pbtxt
new file mode 100644
index 0000000000..a30481a0ea
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-queue-base.pbtxt
@@ -0,0 +1,65 @@
+path: "tensorflow.io.QueueBase"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'dtypes\', \'shapes\', \'names\', \'queue_ref\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-random-shuffle-queue.pbtxt
new file mode 100644
index 0000000000..82cbf9884f
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-random-shuffle-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.RandomShuffleQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.RandomShuffleQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'min_after_dequeue\', \'dtypes\', \'shapes\', \'names\', \'seed\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'random_shuffle_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-sparse-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-sparse-feature.pbtxt
new file mode 100644
index 0000000000..216947b4ed
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-sparse-feature.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.io.SparseFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "already_sorted"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "index_key"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "value_key"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-compression-type.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-compression-type.pbtxt
new file mode 100644
index 0000000000..b598f73d7e
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-compression-type.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.io.TFRecordCompressionType"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordCompressionType\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "GZIP"
+ mtype: "<type \'int\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<type \'int\'>"
+ }
+ member {
+ name: "ZLIB"
+ mtype: "<type \'int\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-options.pbtxt
new file mode 100644
index 0000000000..bfbf37ccf4
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-options.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.io.TFRecordOptions"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordOptions\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "compression_type_map"
+ mtype: "<type \'dict\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "get_compression_type_string"
+ argspec: "args=[\'cls\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6fd443f6d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-t-f-record-writer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.io.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flush"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.-var-len-feature.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.-var-len-feature.pbtxt
new file mode 100644
index 0000000000..fd835dbfbb
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.-var-len-feature.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.io.VarLenFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
index 8938cf217b..dccf136788 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.io.pbtxt
@@ -1,5 +1,49 @@
path: "tensorflow.io"
tf_module {
+ member {
+ name: "FixedLenFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FixedLenSequenceFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "PaddingFIFOQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "PriorityQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "QueueBase"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomShuffleQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordCompressionType"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordOptions"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "VarLenFeature"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "decode_base64"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -9,6 +53,10 @@ tf_module {
argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
+ name: "decode_csv"
+ argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\', \'None\'], "
+ }
+ member_method {
name: "decode_json_example"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -17,18 +65,38 @@ tf_module {
argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
+ name: "deserialize_many_sparse"
+ argspec: "args=[\'serialized_sparse\', \'dtype\', \'rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "encode_base64"
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "match_filenames_once"
+ argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "matching_files"
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_example"
+ argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "parse_sequence_example"
argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "parse_single_example"
+ argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_single_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -37,7 +105,23 @@ tf_module {
argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "serialize_many_sparse"
+ argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ }
+ member_method {
+ name: "serialize_sparse"
+ argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ }
+ member_method {
+ name: "tf_record_iterator"
+ argspec: "args=[\'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "write_file"
argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "write_graph"
+ argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
index df74c32e1f..0c24e9c7dd 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.layers.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "flatten"
- argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], "
}
member_method {
name: "max_pooling1d"
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
index d979116887..6ac95d96da 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.linalg.pbtxt
@@ -109,10 +109,18 @@ tf_module {
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
+ name: "global_norm"
+ argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "inv"
argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "l2_normalize"
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ }
+ member_method {
name: "logdet"
argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -125,6 +133,10 @@ tf_module {
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
}
member_method {
+ name: "matmul"
+ argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
index 72856466ec..459b9e3684 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.math.pbtxt
@@ -1,6 +1,14 @@
path: "tensorflow.math"
tf_module {
member_method {
+ name: "abs"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "accumulate_n"
+ argspec: "args=[\'inputs\', \'shape\', \'tensor_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "acos"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -13,6 +21,22 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "add_n"
+ argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "angle"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "argmax"
+ argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "argmin"
+ argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
name: "asin"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -53,10 +77,18 @@ tf_module {
argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "bincount"
+ argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\"], "
+ }
+ member_method {
name: "ceil"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "conj"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "cos"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -65,14 +97,34 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "count_nonzero"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'dtype\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'int64\'>\", \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "cumprod"
+ argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "cumsum"
+ argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
name: "digamma"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "divide"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "erf"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "erfc"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -89,6 +141,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "floordiv"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "greater"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -105,10 +161,26 @@ tf_module {
argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "imag"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "in_top_k"
+ argspec: "args=[\'predictions\', \'targets\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "invert_permutation"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "l2_normalize"
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "lbeta"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -129,6 +201,14 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "log_sigmoid"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log_softmax"
+ argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "logical_and"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -141,6 +221,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "logical_xor"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], "
+ }
+ member_method {
name: "maximum"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -149,6 +233,14 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "multiply"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "negative"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "not_equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -161,18 +253,66 @@ tf_module {
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "pow"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "real"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "reciprocal"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "reduce_all"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_any"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_logsumexp"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_mean"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_min"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_prod"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "rint"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "round"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "rsqrt"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "scalar_mul"
+ argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -193,6 +333,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sigmoid"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sign"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "sin"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -201,6 +349,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "softmax"
+ argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "softplus"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -209,18 +361,46 @@ tf_module {
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sqrt"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "square"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "squared_difference"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "subtract"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "tan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "tanh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "top_k"
+ argspec: "args=[\'input\', \'k\', \'sorted\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "truediv"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "unsorted_segment_mean"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_min"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -229,6 +409,10 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "unsorted_segment_sqrt_n"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_sum"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -241,6 +425,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "zero_fraction"
+ argspec: "args=[\'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
index d9e5b0d0fc..9b28ce5746 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.nn.pbtxt
@@ -101,6 +101,10 @@ tf_module {
argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], "
}
member_method {
+ name: "depth_to_space"
+ argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], "
+ }
+ member_method {
name: "depthwise_conv2d"
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
@@ -305,6 +309,14 @@ tf_module {
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "space_to_batch"
+ argspec: "args=[\'input\', \'paddings\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "space_to_depth"
+ argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], "
+ }
+ member_method {
name: "sparse_softmax_cross_entropy_with_logits"
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
index 509ceff9df..a268529c1f 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt
@@ -497,6 +497,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "random"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "random_normal_initializer"
mtype: "<type \'type\'>"
}
@@ -1745,6 +1749,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "roll"
+ argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "round"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
index 6d865efed0..77c92aeb0d 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.quantization.pbtxt
@@ -29,6 +29,10 @@ tf_module {
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
}
member_method {
+ name: "quantize"
+ argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'round_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'HALF_AWAY_FROM_ZERO\', \'None\'], "
+ }
+ member_method {
name: "quantized_concat"
argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt
new file mode 100644
index 0000000000..a568dd4cd8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.random.pbtxt
@@ -0,0 +1,47 @@
+path: "tensorflow.random"
+tf_module {
+ member_method {
+ name: "gamma"
+ argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "get_seed"
+ argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "log_uniform_candidate_sampler"
+ argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "multinomial"
+ argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "normal"
+ argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "poisson"
+ argspec: "args=[\'lam\', \'shape\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "set_random_seed"
+ argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "truncated_normal"
+ argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "uniform"
+ argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "uniform_candidate_sampler"
+ argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt
new file mode 100644
index 0000000000..67457de070
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.-builder.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.saved_model.Builder"
+tf_class {
+ is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "add_meta_graph"
+ argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "add_meta_graph_and_variables"
+ argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "save"
+ argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt
index e1a0385092..3f4965fc69 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.saved_model.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.saved_model"
tf_module {
member {
+ name: "Builder"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "builder"
mtype: "<type \'module\'>"
}
@@ -33,6 +37,46 @@ tf_module {
mtype: "<type \'module\'>"
}
member_method {
+ name: "build_signature_def"
+ argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "build_tensor_info"
+ argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "classification_signature_def"
+ argspec: "args=[\'examples\', \'classes\', \'scores\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_tensor_from_tensor_info"
+ argspec: "args=[\'tensor_info\', \'graph\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "is_valid_signature"
+ argspec: "args=[\'signature_def\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "load"
+ argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "main_op_with_restore"
+ argspec: "args=[\'restore_op_name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "maybe_saved_model_directory"
+ argspec: "args=[\'export_dir\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict_signature_def"
+ argspec: "args=[\'inputs\', \'outputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "regression_signature_def"
+ argspec: "args=[\'examples\', \'predictions\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "simple_save"
argspec: "args=[\'session\', \'export_dir\', \'inputs\', \'outputs\', \'legacy_init_op\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt
new file mode 100644
index 0000000000..cd97716c9d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt
@@ -0,0 +1,46 @@
+path: "tensorflow.sparse.SparseConditionalAccumulator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.SparseConditionalAccumulator\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.ConditionalAccumulatorBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "accumulator_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
+ }
+ member_method {
+ name: "apply_grad"
+ argspec: "args=[\'self\', \'grad_indices\', \'grad_values\', \'grad_shape\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ }
+ member_method {
+ name: "apply_indexed_slices_grad"
+ argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+ }
+ member_method {
+ name: "num_accumulated"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "set_global_step"
+ argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "take_grad"
+ argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "take_indexed_slices_grad"
+ argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt
new file mode 100644
index 0000000000..02e59a63e1
--- /dev/null
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.-sparse-tensor.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.sparse.SparseTensor"
+tf_class {
+ is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
+ is_instance: "<class \'tensorflow.python.framework.ops._TensorLike\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dense_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "indices"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "op"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "values"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "consumers"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "eval"
+ argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_shape"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
index ba9e651b34..32bd8d5f8e 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.sparse.pbtxt
@@ -1,5 +1,21 @@
path: "tensorflow.sparse"
tf_module {
+ member {
+ name: "SparseConditionalAccumulator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseTensor"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "add"
+ argspec: "args=[\'a\', \'b\', \'thresh\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "concat"
+ argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
member_method {
name: "cross"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -16,4 +32,100 @@ tf_module {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "fill_empty_rows"
+ argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "mask"
+ argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'sp_a\', \'b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "maximum"
+ argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'sp_ids\', \'sp_values\', \'vocab_size\', \'name\', \'already_sorted\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "minimum"
+ argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "placeholder"
+ argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max_sparse"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum_sparse"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reorder"
+ argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reset_shape"
+ argspec: "args=[\'sp_input\', \'new_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reshape"
+ argspec: "args=[\'sp_input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "retain"
+ argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "segment_mean"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "segment_sqrt_n"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "segment_sum"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "slice"
+ argspec: "args=[\'sp_input\', \'start\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softmax"
+ argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "split"
+ argspec: "args=[\'keyword_required\', \'sp_input\', \'num_split\', \'axis\', \'name\', \'split_dim\'], varargs=None, keywords=None, defaults=[\'KeywordRequired()\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "to_indicator"
+ argspec: "args=[\'sp_input\', \'vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "transpose"
+ argspec: "args=[\'sp_input\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
index 312e94b41d..ebdaf57231 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.strings.pbtxt
@@ -13,6 +13,10 @@ tf_module {
argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
+ name: "reduce_join"
+ argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
index 9f35395284..45c81fdd3b 100644
--- a/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v1/tensorflow.train.pbtxt
@@ -273,6 +273,10 @@ tf_module {
argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "confusion_matrix"
+ argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'dtype\', \'name\', \'weights\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
name: "cosine_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
index 537e73aa89..47b5b56faf 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.-experimental.pbtxt
@@ -8,5 +8,11 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
index cec04a2bf0..c0c2e7b9f8 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.-run-options.pbtxt
@@ -55,6 +55,12 @@ tf_proto {
label: LABEL_OPTIONAL
type: TYPE_INT64
}
+ field {
+ name: "use_run_handler_pool"
+ number: 2
+ label: LABEL_OPTIONAL
+ type: TYPE_BOOL
+ }
}
enum_type {
name: "TraceLevel"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
index 825afb622f..8b7f63e43e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-dataset.pbtxt
@@ -79,6 +79,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -119,6 +123,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
index cdad5f6360..a7bfa82c65 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-fixed-length-record-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
new file mode 100644
index 0000000000..d15dccc173
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-options.pbtxt
@@ -0,0 +1,57 @@
+path: "tensorflow.data.Options"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Options\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "experimental_autotune"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_filter_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_hoist_random_uniform"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_latency_all_edges"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_and_batch_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_and_filter_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_parallelization"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_map_vectorization"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_noop_elimination"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "experimental_shuffle_and_repeat_fusion"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
index df41bff1b5..7b7a9ebaf0 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-t-f-record-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
index 028bcc2ce9..2817f900e1 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.-text-line-dataset.pbtxt
@@ -80,6 +80,10 @@ tf_class {
argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "padded_batch"
argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
}
@@ -120,6 +124,10 @@ tf_class {
argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
}
member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "zip"
argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
new file mode 100644
index 0000000000..03c16cda8b
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-checkpoint-input-pipeline-hook.pbtxt
@@ -0,0 +1,30 @@
+path: "tensorflow.data.experimental.CheckpointInputPipelineHook"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.iterator_ops.CheckpointInputPipelineHook\'>"
+ is_instance: "<class \'tensorflow.python.training.session_run_hook.SessionRunHook\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'estimator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_create_session"
+ argspec: "args=[\'self\', \'session\', \'coord\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "after_run"
+ argspec: "args=[\'self\', \'run_context\', \'run_values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "before_run"
+ argspec: "args=[\'self\', \'run_context\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "begin"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "end"
+ argspec: "args=[\'self\', \'session\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..3eeaa1b185
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.CsvDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
new file mode 100644
index 0000000000..2520e28a3c
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-csv-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.CsvDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.CsvDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filenames\', \'record_defaults\', \'compression_type\', \'buffer_size\', \'header\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \',\', \'True\', \'\', \'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt
new file mode 100644
index 0000000000..b4c9459098
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-optional.pbtxt
@@ -0,0 +1,28 @@
+path: "tensorflow.data.experimental.Optional"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.ops.optional_ops.Optional\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "value_structure"
+ mtype: "<class \'abc.abstractproperty\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "has_value"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "none_from_structure"
+ argspec: "args=[\'value_structure\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..2991b12f64
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.RandomDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
new file mode 100644
index 0000000000..1dd53b1eab
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-random-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.RandomDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.random_ops.RandomDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt
new file mode 100644
index 0000000000..6b477a8a72
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-reducer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.data.experimental.Reducer"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.grouping.Reducer\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "finalize_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "init_func"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "reduce_func"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'init_func\', \'reduce_func\', \'finalize_func\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
new file mode 100644
index 0000000000..948e99ef86
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.data.experimental.SqlDataset.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
new file mode 100644
index 0000000000..8fdd9dc52e
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-sql-dataset.pbtxt
@@ -0,0 +1,135 @@
+path: "tensorflow.data.experimental.SqlDataset"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.readers.SqlDataset\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.DatasetSource\'>"
+ is_instance: "<class \'tensorflow.python.data.ops.dataset_ops.Dataset\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "output_classes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_shapes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "output_types"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'driver_name\', \'data_source_name\', \'query\', \'output_types\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "apply"
+ argspec: "args=[\'self\', \'transformation_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "batch"
+ argspec: "args=[\'self\', \'batch_size\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+ member_method {
+ name: "cache"
+ argspec: "args=[\'self\', \'filename\'], varargs=None, keywords=None, defaults=[\'\'], "
+ }
+ member_method {
+ name: "concatenate"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "filter"
+ argspec: "args=[\'self\', \'predicate\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flat_map"
+ argspec: "args=[\'self\', \'map_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_generator"
+ argspec: "args=[\'generator\', \'output_types\', \'output_shapes\', \'args\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_sparse_tensor_slices"
+ argspec: "args=[\'sparse_tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensor_slices"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "from_tensors"
+ argspec: "args=[\'tensors\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "interleave"
+ argspec: "args=[\'self\', \'map_func\', \'cycle_length\', \'block_length\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\', \'None\'], "
+ }
+ member_method {
+ name: "list_files"
+ argspec: "args=[\'file_pattern\', \'shuffle\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "make_initializable_iterator"
+ argspec: "args=[\'self\', \'shared_name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "make_one_shot_iterator"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map"
+ argspec: "args=[\'self\', \'map_func\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "options"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "padded_batch"
+ argspec: "args=[\'self\', \'batch_size\', \'padded_shapes\', \'padding_values\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "prefetch"
+ argspec: "args=[\'self\', \'buffer_size\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "range"
+ argspec: "args=[], varargs=args, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "reduce"
+ argspec: "args=[\'self\', \'initial_state\', \'reduce_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "repeat"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "shard"
+ argspec: "args=[\'self\', \'num_shards\', \'index\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'self\', \'buffer_size\', \'seed\', \'reshuffle_each_iteration\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "skip"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "take"
+ argspec: "args=[\'self\', \'count\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "window"
+ argspec: "args=[\'self\', \'size\', \'shift\', \'stride\', \'drop_remainder\'], varargs=None, keywords=None, defaults=[\'None\', \'1\', \'False\'], "
+ }
+ member_method {
+ name: "with_options"
+ argspec: "args=[\'self\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "zip"
+ argspec: "args=[\'datasets\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt
new file mode 100644
index 0000000000..0bcc8cf3e8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-stats-aggregator.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.StatsAggregator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.stats_ops.StatsAggregator\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_summary"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6f9d18a701
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.-t-f-record-writer.pbtxt
@@ -0,0 +1,13 @@
+path: "tensorflow.data.experimental.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.data.experimental.ops.writers.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'filename\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
new file mode 100644
index 0000000000..b14585f8d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.experimental.pbtxt
@@ -0,0 +1,139 @@
+path: "tensorflow.data.experimental"
+tf_module {
+ member {
+ name: "CheckpointInputPipelineHook"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "CsvDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Optional"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "Reducer"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SqlDataset"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "StatsAggregator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "Counter"
+ argspec: "args=[\'start\', \'step\', \'dtype\'], varargs=None, keywords=None, defaults=[\'0\', \'1\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "bucket_by_sequence_length"
+ argspec: "args=[\'element_length_func\', \'bucket_boundaries\', \'bucket_batch_sizes\', \'padded_shapes\', \'padding_values\', \'pad_to_bucket_boundary\', \'no_padding\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "choose_from_datasets"
+ argspec: "args=[\'datasets\', \'choice_dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "copy_to_device"
+ argspec: "args=[\'target_device\', \'source_device\'], varargs=None, keywords=None, defaults=[\'/cpu:0\'], "
+ }
+ member_method {
+ name: "dense_to_sparse_batch"
+ argspec: "args=[\'batch_size\', \'row_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "enumerate_dataset"
+ argspec: "args=[\'start\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "get_next_as_optional"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_single_element"
+ argspec: "args=[\'dataset\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_reducer"
+ argspec: "args=[\'key_func\', \'reducer\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "group_by_window"
+ argspec: "args=[\'key_func\', \'reduce_func\', \'window_size\', \'window_size_func\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "ignore_errors"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "latency_stats"
+ argspec: "args=[\'tag\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "make_batched_features_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'features\', \'reader\', \'label_key\', \'reader_args\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'reader_num_threads\', \'parser_num_threads\', \'sloppy_ordering\', \'drop_final_batch\'], varargs=None, keywords=None, defaults=[\"<class \'tensorflow.python.data.ops.readers.TFRecordDataset\'>\", \'None\', \'None\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'2\', \'False\', \'False\'], "
+ }
+ member_method {
+ name: "make_csv_dataset"
+ argspec: "args=[\'file_pattern\', \'batch_size\', \'column_names\', \'column_defaults\', \'label_name\', \'select_columns\', \'field_delim\', \'use_quote_delim\', \'na_value\', \'header\', \'num_epochs\', \'shuffle\', \'shuffle_buffer_size\', \'shuffle_seed\', \'prefetch_buffer_size\', \'num_parallel_reads\', \'sloppy\', \'num_rows_for_inference\', \'compression_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \',\', \'True\', \'\', \'True\', \'None\', \'True\', \'10000\', \'None\', \'-1\', \'1\', \'False\', \'100\', \'None\'], "
+ }
+ member_method {
+ name: "make_saveable_from_iterator"
+ argspec: "args=[\'iterator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "map_and_batch"
+ argspec: "args=[\'map_func\', \'batch_size\', \'num_parallel_batches\', \'drop_remainder\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "parallel_interleave"
+ argspec: "args=[\'map_func\', \'cycle_length\', \'block_length\', \'sloppy\', \'buffer_output_elements\', \'prefetch_input_elements\'], varargs=None, keywords=None, defaults=[\'1\', \'False\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_example_dataset"
+ argspec: "args=[\'features\', \'num_parallel_calls\'], varargs=None, keywords=None, defaults=[\'1\'], "
+ }
+ member_method {
+ name: "prefetch_to_device"
+ argspec: "args=[\'device\', \'buffer_size\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "rejection_resample"
+ argspec: "args=[\'class_func\', \'target_dist\', \'initial_dist\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "sample_from_datasets"
+ argspec: "args=[\'datasets\', \'weights\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "scan"
+ argspec: "args=[\'initial_state\', \'scan_func\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "set_stats_aggregator"
+ argspec: "args=[\'stats_aggregator\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle_and_repeat"
+ argspec: "args=[\'buffer_size\', \'count\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "unbatch"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "unique"
+ argspec: "args=[], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
index 56fb270a49..3023276a1d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.data.pbtxt
@@ -13,6 +13,10 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
+ name: "Options"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "TFRecordDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
@@ -20,4 +24,8 @@ tf_module {
name: "TextLineDataset"
mtype: "<class \'abc.ABCMeta\'>"
}
+ member {
+ name: "experimental"
+ mtype: "<type \'module\'>"
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt
index d9efe97821..ab6287f8cd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.debugging.pbtxt
@@ -1,6 +1,90 @@
path: "tensorflow.debugging"
tf_module {
member_method {
+ name: "Assert"
+ argspec: "args=[\'condition\', \'data\', \'summarize\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_all_finite"
+ argspec: "args=[\'t\', \'msg\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "assert_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_greater"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_greater_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_integer"
+ argspec: "args=[\'x\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_less"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_less_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_near"
+ argspec: "args=[\'x\', \'y\', \'rtol\', \'atol\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_negative"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_non_negative"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_non_positive"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_none_equal"
+ argspec: "args=[\'x\', \'y\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_positive"
+ argspec: "args=[\'x\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_proper_iterable"
+ argspec: "args=[\'values\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "assert_rank"
+ argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_rank_at_least"
+ argspec: "args=[\'x\', \'rank\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_rank_in"
+ argspec: "args=[\'x\', \'ranks\', \'data\', \'summarize\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_same_float_dtype"
+ argspec: "args=[\'tensors\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "assert_scalar"
+ argspec: "args=[\'tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "assert_type"
+ argspec: "args=[\'tensor\', \'tf_type\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "check_numerics"
argspec: "args=[\'tensor\', \'message\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -16,4 +100,16 @@ tf_module {
name: "is_nan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "is_non_decreasing"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "is_numeric_tensor"
+ argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_strictly_increasing"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt
new file mode 100644
index 0000000000..423eca32a2
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.-d-type.pbtxt
@@ -0,0 +1,77 @@
+path: "tensorflow.dtypes.DType"
+tf_class {
+ is_instance: "<class \'tensorflow.python.framework.dtypes.DType\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "as_datatype_enum"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "as_numpy_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "base_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_bool"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_complex"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_floating"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_integer"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_numpy_compatible"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_quantized"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_unsigned"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "limits"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "max"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "min"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "real_dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "size"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'type_enum\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_compatible_with"
+ argspec: "args=[\'self\', \'other\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt
index 98e1feed00..ea23feca84 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.dtypes.pbtxt
@@ -1,7 +1,27 @@
path: "tensorflow.dtypes"
tf_module {
+ member {
+ name: "DType"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "as_dtype"
+ argspec: "args=[\'type_value\'], varargs=None, keywords=None, defaults=None"
+ }
member_method {
name: "as_string"
argspec: "args=[\'input\', \'precision\', \'scientific\', \'shortest\', \'width\', \'fill\', \'name\'], varargs=None, keywords=None, defaults=[\'-1\', \'False\', \'False\', \'-1\', \'\', \'None\'], "
}
+ member_method {
+ name: "cast"
+ argspec: "args=[\'x\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "complex"
+ argspec: "args=[\'real\', \'imag\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "saturate_cast"
+ argspec: "args=[\'value\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt
index eeabf845dc..162ee76ee7 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.graph_util.pbtxt
@@ -9,6 +9,10 @@ tf_module {
argspec: "args=[\'graph_def\', \'dest_nodes\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "import_graph_def"
+ argspec: "args=[\'graph_def\', \'input_map\', \'return_elements\', \'name\', \'op_dict\', \'producer_op_list\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "must_run_on_cpu"
argspec: "args=[\'node\', \'pin_variables_on_cpu\'], varargs=None, keywords=None, defaults=[\'False\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
index 5c46dc5ee7..0a231f1b65 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.image.pbtxt
@@ -149,6 +149,10 @@ tf_module {
argspec: "args=[\'image\', \'lower\', \'upper\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "random_crop"
+ argspec: "args=[\'value\', \'size\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "random_flip_left_right"
argspec: "args=[\'image\', \'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
index e3c63fe737..d49181714f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.initializers.pbtxt
@@ -64,4 +64,8 @@ tf_module {
name: "lecun_uniform"
argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "tables_initializer"
+ argspec: "args=[\'name\'], varargs=None, keywords=None, defaults=[\'init_all_tables\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-feature.pbtxt
new file mode 100644
index 0000000000..cd0e51c8c7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-feature.pbtxt
@@ -0,0 +1,27 @@
+path: "tensorflow.io.FixedLenFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "default_value"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-sequence-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-sequence-feature.pbtxt
new file mode 100644
index 0000000000..8a38f25fdf
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-fixed-len-sequence-feature.pbtxt
@@ -0,0 +1,31 @@
+path: "tensorflow.io.FixedLenSequenceFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.FixedLenSequenceFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "allow_missing"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "default_value"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-padding-f-i-f-o-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-padding-f-i-f-o-queue.pbtxt
new file mode 100644
index 0000000000..85306fdcac
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-padding-f-i-f-o-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.PaddingFIFOQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.PaddingFIFOQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'dtypes\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'padding_fifo_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-priority-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-priority-queue.pbtxt
new file mode 100644
index 0000000000..02d8037b34
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-priority-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.PriorityQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.PriorityQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'types\', \'shapes\', \'names\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'priority_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-queue-base.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-queue-base.pbtxt
new file mode 100644
index 0000000000..a30481a0ea
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-queue-base.pbtxt
@@ -0,0 +1,65 @@
+path: "tensorflow.io.QueueBase"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'dtypes\', \'shapes\', \'names\', \'queue_ref\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-random-shuffle-queue.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-random-shuffle-queue.pbtxt
new file mode 100644
index 0000000000..82cbf9884f
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-random-shuffle-queue.pbtxt
@@ -0,0 +1,66 @@
+path: "tensorflow.io.RandomShuffleQueue"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.RandomShuffleQueue\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.QueueBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dtypes"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "names"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "queue_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shapes"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'capacity\', \'min_after_dequeue\', \'dtypes\', \'shapes\', \'names\', \'seed\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'random_shuffle_queue\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\', \'cancel_pending_enqueues\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
+ }
+ member_method {
+ name: "dequeue"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_many"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "dequeue_up_to"
+ argspec: "args=[\'self\', \'n\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "enqueue_many"
+ argspec: "args=[\'self\', \'vals\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "from_list"
+ argspec: "args=[\'index\', \'queues\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "is_closed"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "size"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-sparse-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-sparse-feature.pbtxt
new file mode 100644
index 0000000000..216947b4ed
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-sparse-feature.pbtxt
@@ -0,0 +1,35 @@
+path: "tensorflow.io.SparseFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.SparseFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "already_sorted"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "index_key"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "size"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "value_key"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-compression-type.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-compression-type.pbtxt
new file mode 100644
index 0000000000..b598f73d7e
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-compression-type.pbtxt
@@ -0,0 +1,20 @@
+path: "tensorflow.io.TFRecordCompressionType"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordCompressionType\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "GZIP"
+ mtype: "<type \'int\'>"
+ }
+ member {
+ name: "NONE"
+ mtype: "<type \'int\'>"
+ }
+ member {
+ name: "ZLIB"
+ mtype: "<type \'int\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-options.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-options.pbtxt
new file mode 100644
index 0000000000..bfbf37ccf4
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-options.pbtxt
@@ -0,0 +1,17 @@
+path: "tensorflow.io.TFRecordOptions"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordOptions\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "compression_type_map"
+ mtype: "<type \'dict\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'compression_type\', \'flush_mode\', \'input_buffer_size\', \'output_buffer_size\', \'window_bits\', \'compression_level\', \'compression_method\', \'mem_level\', \'compression_strategy\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "get_compression_type_string"
+ argspec: "args=[\'cls\', \'options\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt
new file mode 100644
index 0000000000..6fd443f6d7
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-t-f-record-writer.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.io.TFRecordWriter"
+tf_class {
+ is_instance: "<class \'tensorflow.python.lib.io.tf_record.TFRecordWriter\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "close"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "flush"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "write"
+ argspec: "args=[\'self\', \'record\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.-var-len-feature.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.-var-len-feature.pbtxt
new file mode 100644
index 0000000000..fd835dbfbb
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.-var-len-feature.pbtxt
@@ -0,0 +1,19 @@
+path: "tensorflow.io.VarLenFeature"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
+ is_instance: "<class \'tensorflow.python.ops.parsing_ops.VarLenFeature\'>"
+ is_instance: "<type \'tuple\'>"
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "count"
+ }
+ member_method {
+ name: "index"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
index 8938cf217b..dccf136788 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.io.pbtxt
@@ -1,5 +1,49 @@
path: "tensorflow.io"
tf_module {
+ member {
+ name: "FixedLenFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "FixedLenSequenceFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "PaddingFIFOQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "PriorityQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "QueueBase"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "RandomShuffleQueue"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseFeature"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordCompressionType"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordOptions"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "TFRecordWriter"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "VarLenFeature"
+ mtype: "<type \'type\'>"
+ }
member_method {
name: "decode_base64"
argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -9,6 +53,10 @@ tf_module {
argspec: "args=[\'bytes\', \'compression_type\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'None\'], "
}
member_method {
+ name: "decode_csv"
+ argspec: "args=[\'records\', \'record_defaults\', \'field_delim\', \'use_quote_delim\', \'name\', \'na_value\', \'select_cols\'], varargs=None, keywords=None, defaults=[\',\', \'True\', \'None\', \'\', \'None\'], "
+ }
+ member_method {
name: "decode_json_example"
argspec: "args=[\'json_examples\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -17,18 +65,38 @@ tf_module {
argspec: "args=[\'bytes\', \'out_type\', \'little_endian\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
+ name: "deserialize_many_sparse"
+ argspec: "args=[\'serialized_sparse\', \'dtype\', \'rank\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "encode_base64"
argspec: "args=[\'input\', \'pad\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "match_filenames_once"
+ argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "matching_files"
argspec: "args=[\'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "parse_example"
+ argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
name: "parse_sequence_example"
argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_names\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
member_method {
+ name: "parse_single_example"
+ argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "parse_single_sequence_example"
+ argspec: "args=[\'serialized\', \'context_features\', \'sequence_features\', \'example_name\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "parse_tensor"
argspec: "args=[\'serialized\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -37,7 +105,23 @@ tf_module {
argspec: "args=[\'filename\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "serialize_many_sparse"
+ argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ }
+ member_method {
+ name: "serialize_sparse"
+ argspec: "args=[\'sp_input\', \'name\', \'out_type\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'string\'>\"], "
+ }
+ member_method {
+ name: "tf_record_iterator"
+ argspec: "args=[\'path\', \'options\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "write_file"
argspec: "args=[\'filename\', \'contents\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
+ member_method {
+ name: "write_graph"
+ argspec: "args=[\'graph_or_graph_def\', \'logdir\', \'name\', \'as_text\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
index 126ce8db6a..a71a59e269 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.backend.pbtxt
@@ -398,7 +398,7 @@ tf_module {
}
member_method {
name: "rnn"
- argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\'], "
+ argspec: "args=[\'step_function\', \'inputs\', \'initial_states\', \'go_backwards\', \'mask\', \'constants\', \'unroll\', \'input_length\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'None\', \'None\', \'False\', \'None\', \'False\'], "
}
member_method {
name: "round"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
index 2b6e8af11d..68b6678d48 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.keras.layers.-r-n-n.pbtxt
@@ -86,7 +86,7 @@ tf_class {
}
member_method {
name: "__init__"
- argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\'], "
+ argspec: "args=[\'self\', \'cell\', \'return_sequences\', \'return_state\', \'go_backwards\', \'stateful\', \'unroll\', \'time_major\'], varargs=None, keywords=kwargs, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\'], "
}
member_method {
name: "add_loss"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
index df74c32e1f..0c24e9c7dd 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.layers.pbtxt
@@ -122,7 +122,7 @@ tf_module {
}
member_method {
name: "flatten"
- argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ argspec: "args=[\'inputs\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'channels_last\'], "
}
member_method {
name: "max_pooling1d"
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
index d979116887..6ac95d96da 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.linalg.pbtxt
@@ -109,10 +109,18 @@ tf_module {
argspec: "args=[\'num_rows\', \'num_columns\', \'batch_shape\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
member_method {
+ name: "global_norm"
+ argspec: "args=[\'t_list\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "inv"
argspec: "args=[\'input\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'None\'], "
}
member_method {
+ name: "l2_normalize"
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ }
+ member_method {
name: "logdet"
argspec: "args=[\'matrix\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -125,6 +133,10 @@ tf_module {
argspec: "args=[\'matrix\', \'rhs\', \'l2_regularizer\', \'fast\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'True\', \'None\'], "
}
member_method {
+ name: "matmul"
+ argspec: "args=[\'a\', \'b\', \'transpose_a\', \'transpose_b\', \'adjoint_a\', \'adjoint_b\', \'a_is_sparse\', \'b_is_sparse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'False\', \'False\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
name: "norm"
argspec: "args=[\'tensor\', \'ord\', \'axis\', \'keepdims\', \'name\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'euclidean\', \'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
index 72856466ec..459b9e3684 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.math.pbtxt
@@ -1,6 +1,14 @@
path: "tensorflow.math"
tf_module {
member_method {
+ name: "abs"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "accumulate_n"
+ argspec: "args=[\'inputs\', \'shape\', \'tensor_dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "acos"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -13,6 +21,22 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "add_n"
+ argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "angle"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "argmax"
+ argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
+ name: "argmin"
+ argspec: "args=[\'input\', \'axis\', \'name\', \'dimension\', \'output_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int64\'>\"], "
+ }
+ member_method {
name: "asin"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -53,10 +77,18 @@ tf_module {
argspec: "args=[\'a\', \'b\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "bincount"
+ argspec: "args=[\'arr\', \'weights\', \'minlength\', \'maxlength\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \"<dtype: \'int32\'>\"], "
+ }
+ member_method {
name: "ceil"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "conj"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "cos"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -65,14 +97,34 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "count_nonzero"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'dtype\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \"<dtype: \'int64\'>\", \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "cumprod"
+ argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "cumsum"
+ argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'False\', \'False\', \'None\'], "
+ }
+ member_method {
name: "digamma"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "divide"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "erf"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "erfc"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -89,6 +141,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "floordiv"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "greater"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -105,10 +161,26 @@ tf_module {
argspec: "args=[\'a\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "imag"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "in_top_k"
+ argspec: "args=[\'predictions\', \'targets\', \'k\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "invert_permutation"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "l2_normalize"
+ argspec: "args=[\'x\', \'axis\', \'epsilon\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'1e-12\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "lbeta"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "less"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -129,6 +201,14 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "log_sigmoid"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "log_softmax"
+ argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "logical_and"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -141,6 +221,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "logical_xor"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'LogicalXor\'], "
+ }
+ member_method {
name: "maximum"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -149,6 +233,14 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "multiply"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "negative"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "not_equal"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -161,18 +253,66 @@ tf_module {
argspec: "args=[\'coeffs\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "pow"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "real"
+ argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "reciprocal"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "reduce_all"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_any"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_logsumexp"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_mean"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_min"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_prod"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum"
+ argspec: "args=[\'input_tensor\', \'axis\', \'keepdims\', \'name\', \'reduction_indices\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "rint"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "round"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "rsqrt"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "scalar_mul"
+ argspec: "args=[\'scalar\', \'x\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -193,6 +333,14 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sigmoid"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "sign"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "sin"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -201,6 +349,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "softmax"
+ argspec: "args=[\'logits\', \'axis\', \'name\', \'dim\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
name: "softplus"
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -209,18 +361,46 @@ tf_module {
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "sqrt"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "square"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "squared_difference"
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "subtract"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "tan"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "tanh"
+ argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "top_k"
+ argspec: "args=[\'input\', \'k\', \'sorted\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "truediv"
+ argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_max"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "unsorted_segment_mean"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_min"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -229,6 +409,10 @@ tf_module {
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "unsorted_segment_sqrt_n"
+ argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "unsorted_segment_sum"
argspec: "args=[\'data\', \'segment_ids\', \'num_segments\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
@@ -241,6 +425,10 @@ tf_module {
argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "zero_fraction"
+ argspec: "args=[\'value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
name: "zeta"
argspec: "args=[\'x\', \'q\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
index d9e5b0d0fc..9b28ce5746 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.pbtxt
@@ -101,6 +101,10 @@ tf_module {
argspec: "args=[\'labels\', \'inputs\', \'sequence_length\', \'preprocess_collapse_repeated\', \'ctc_merge_repeated\', \'ignore_longer_outputs_than_inputs\', \'time_major\'], varargs=None, keywords=None, defaults=[\'False\', \'True\', \'False\', \'True\'], "
}
member_method {
+ name: "depth_to_space"
+ argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], "
+ }
+ member_method {
name: "depthwise_conv2d"
argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'rate\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
@@ -305,6 +309,14 @@ tf_module {
argspec: "args=[\'features\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "space_to_batch"
+ argspec: "args=[\'input\', \'paddings\', \'block_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "space_to_depth"
+ argspec: "args=[\'input\', \'block_size\', \'name\', \'data_format\'], varargs=None, keywords=None, defaults=[\'None\', \'NHWC\'], "
+ }
+ member_method {
name: "sparse_softmax_cross_entropy_with_logits"
argspec: "args=[\'_sentinel\', \'labels\', \'logits\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
deleted file mode 100644
index a4483fefa2..0000000000
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.-basic-r-n-n-cell.pbtxt
+++ /dev/null
@@ -1,202 +0,0 @@
-path: "tensorflow.nn.rnn_cell.BasicRNNCell"
-tf_class {
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.BasicRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.LayerRNNCell\'>"
- is_instance: "<class \'tensorflow.python.ops.rnn_cell_impl.RNNCell\'>"
- is_instance: "<class \'tensorflow.python.layers.base.Layer\'>"
- is_instance: "<class \'tensorflow.python.keras.engine.base_layer.Layer\'>"
- is_instance: "<class \'tensorflow.python.training.checkpointable.base.CheckpointableBase\'>"
- is_instance: "<type \'object\'>"
- member {
- name: "activity_regularizer"
- mtype: "<type \'property\'>"
- }
- member {
- name: "dtype"
- mtype: "<type \'property\'>"
- }
- member {
- name: "graph"
- mtype: "<type \'property\'>"
- }
- member {
- name: "inbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "input_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "losses"
- mtype: "<type \'property\'>"
- }
- member {
- name: "name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "non_trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "outbound_nodes"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_mask"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_shape"
- mtype: "<type \'property\'>"
- }
- member {
- name: "output_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "scope_name"
- mtype: "<type \'property\'>"
- }
- member {
- name: "state_size"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "trainable_weights"
- mtype: "<type \'property\'>"
- }
- member {
- name: "updates"
- mtype: "<type \'property\'>"
- }
- member {
- name: "variables"
- mtype: "<type \'property\'>"
- }
- member {
- name: "weights"
- mtype: "<type \'property\'>"
- }
- member_method {
- name: "__init__"
- argspec: "args=[\'self\', \'num_units\', \'activation\', \'reuse\', \'name\', \'dtype\'], varargs=None, keywords=kwargs, defaults=[\'None\', \'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "add_loss"
- argspec: "args=[\'self\', \'losses\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_update"
- argspec: "args=[\'self\', \'updates\', \'inputs\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "add_variable"
- argspec: "args=[\'self\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "add_weight"
- argspec: "args=[\'self\', \'name\', \'shape\', \'dtype\', \'initializer\', \'regularizer\', \'trainable\', \'constraint\', \'use_resource\', \'synchronization\', \'aggregation\', \'partitioner\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\', \'None\', \'None\', \'VariableSynchronization.AUTO\', \'VariableAggregation.NONE\', \'None\'], "
- }
- member_method {
- name: "apply"
- argspec: "args=[\'self\', \'inputs\'], varargs=args, keywords=kwargs, defaults=None"
- }
- member_method {
- name: "build"
- argspec: "args=[\'instance\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "call"
- argspec: "args=[\'self\', \'inputs\', \'state\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "compute_mask"
- argspec: "args=[\'self\', \'inputs\', \'mask\'], varargs=None, keywords=None, defaults=[\'None\'], "
- }
- member_method {
- name: "compute_output_shape"
- argspec: "args=[\'self\', \'input_shape\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "count_params"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "from_config"
- argspec: "args=[\'cls\', \'config\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_config"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_initial_state"
- argspec: "args=[\'self\', \'inputs\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
- }
- member_method {
- name: "get_input_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_input_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_losses_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_mask_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_output_shape_at"
- argspec: "args=[\'self\', \'node_index\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_updates_for"
- argspec: "args=[\'self\', \'inputs\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "get_weights"
- argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "set_weights"
- argspec: "args=[\'self\', \'weights\'], varargs=None, keywords=None, defaults=None"
- }
- member_method {
- name: "zero_state"
- argspec: "args=[\'self\', \'batch_size\', \'dtype\'], varargs=None, keywords=None, defaults=None"
- }
-}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
index 64697e8a02..24767e250f 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.nn.rnn_cell.pbtxt
@@ -5,10 +5,6 @@ tf_module {
mtype: "<type \'type\'>"
}
member {
- name: "BasicRNNCell"
- mtype: "<type \'type\'>"
- }
- member {
name: "DeviceWrapper"
mtype: "<type \'type\'>"
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
index d2dc8bc85f..5b3ea75bce 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt
@@ -457,6 +457,10 @@ tf_module {
mtype: "<class \'tensorflow.python.framework.dtypes.DType\'>"
}
member {
+ name: "random"
+ mtype: "<type \'module\'>"
+ }
+ member {
name: "random_normal_initializer"
mtype: "<type \'type\'>"
}
@@ -1609,6 +1613,10 @@ tf_module {
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
member_method {
+ name: "roll"
+ argspec: "args=[\'input\', \'shift\', \'axis\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "round"
argspec: "args=[\'x\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
index 6d865efed0..77c92aeb0d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.quantization.pbtxt
@@ -29,6 +29,10 @@ tf_module {
argspec: "args=[\'gradients\', \'inputs\', \'min\', \'max\', \'num_bits\', \'narrow_range\', \'name\'], varargs=None, keywords=None, defaults=[\'8\', \'False\', \'None\'], "
}
member_method {
+ name: "quantize"
+ argspec: "args=[\'input\', \'min_range\', \'max_range\', \'T\', \'mode\', \'round_mode\', \'name\'], varargs=None, keywords=None, defaults=[\'MIN_COMBINED\', \'HALF_AWAY_FROM_ZERO\', \'None\'], "
+ }
+ member_method {
name: "quantized_concat"
argspec: "args=[\'concat_dim\', \'values\', \'input_mins\', \'input_maxes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
new file mode 100644
index 0000000000..a568dd4cd8
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.random.pbtxt
@@ -0,0 +1,47 @@
+path: "tensorflow.random"
+tf_module {
+ member_method {
+ name: "gamma"
+ argspec: "args=[\'shape\', \'alpha\', \'beta\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "get_seed"
+ argspec: "args=[\'op_seed\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "log_uniform_candidate_sampler"
+ argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "multinomial"
+ argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "normal"
+ argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "poisson"
+ argspec: "args=[\'lam\', \'shape\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "set_random_seed"
+ argspec: "args=[\'seed\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "shuffle"
+ argspec: "args=[\'value\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "truncated_normal"
+ argspec: "args=[\'shape\', \'mean\', \'stddev\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'1.0\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "uniform"
+ argspec: "args=[\'shape\', \'minval\', \'maxval\', \'dtype\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\', \"<dtype: \'float32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
+ name: "uniform_candidate_sampler"
+ argspec: "args=[\'true_classes\', \'num_true\', \'num_sampled\', \'unique\', \'range_max\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt
new file mode 100644
index 0000000000..67457de070
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.-builder.pbtxt
@@ -0,0 +1,21 @@
+path: "tensorflow.saved_model.Builder"
+tf_class {
+ is_instance: "<class \'tensorflow.python.saved_model.builder_impl.SavedModelBuilder\'>"
+ is_instance: "<type \'object\'>"
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'export_dir\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "add_meta_graph"
+ argspec: "args=[\'self\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "add_meta_graph_and_variables"
+ argspec: "args=[\'self\', \'sess\', \'tags\', \'signature_def_map\', \'assets_collection\', \'legacy_init_op\', \'clear_devices\', \'main_op\', \'strip_default_attrs\', \'saver\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'False\', \'None\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "save"
+ argspec: "args=[\'self\', \'as_text\'], varargs=None, keywords=None, defaults=[\'False\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt
index e1a0385092..3f4965fc69 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.saved_model.pbtxt
@@ -1,6 +1,10 @@
path: "tensorflow.saved_model"
tf_module {
member {
+ name: "Builder"
+ mtype: "<type \'type\'>"
+ }
+ member {
name: "builder"
mtype: "<type \'module\'>"
}
@@ -33,6 +37,46 @@ tf_module {
mtype: "<type \'module\'>"
}
member_method {
+ name: "build_signature_def"
+ argspec: "args=[\'inputs\', \'outputs\', \'method_name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "build_tensor_info"
+ argspec: "args=[\'tensor\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "classification_signature_def"
+ argspec: "args=[\'examples\', \'classes\', \'scores\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_tensor_from_tensor_info"
+ argspec: "args=[\'tensor_info\', \'graph\', \'import_scope\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "is_valid_signature"
+ argspec: "args=[\'signature_def\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "load"
+ argspec: "args=[\'sess\', \'tags\', \'export_dir\', \'import_scope\'], varargs=None, keywords=saver_kwargs, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "main_op_with_restore"
+ argspec: "args=[\'restore_op_name\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "maybe_saved_model_directory"
+ argspec: "args=[\'export_dir\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "predict_signature_def"
+ argspec: "args=[\'inputs\', \'outputs\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "regression_signature_def"
+ argspec: "args=[\'examples\', \'predictions\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
name: "simple_save"
argspec: "args=[\'session\', \'export_dir\', \'inputs\', \'outputs\', \'legacy_init_op\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt
new file mode 100644
index 0000000000..cd97716c9d
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-conditional-accumulator.pbtxt
@@ -0,0 +1,46 @@
+path: "tensorflow.sparse.SparseConditionalAccumulator"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.SparseConditionalAccumulator\'>"
+ is_instance: "<class \'tensorflow.python.ops.data_flow_ops.ConditionalAccumulatorBase\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "accumulator_ref"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'dtype\', \'shape\', \'shared_name\', \'name\', \'reduction_type\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'sparse_conditional_accumulator\', \'MEAN\'], "
+ }
+ member_method {
+ name: "apply_grad"
+ argspec: "args=[\'self\', \'grad_indices\', \'grad_values\', \'grad_shape\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'0\', \'None\'], "
+ }
+ member_method {
+ name: "apply_indexed_slices_grad"
+ argspec: "args=[\'self\', \'grad\', \'local_step\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'None\'], "
+ }
+ member_method {
+ name: "num_accumulated"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "set_global_step"
+ argspec: "args=[\'self\', \'new_global_step\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "take_grad"
+ argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "take_indexed_slices_grad"
+ argspec: "args=[\'self\', \'num_required\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt
new file mode 100644
index 0000000000..02e59a63e1
--- /dev/null
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.-sparse-tensor.pbtxt
@@ -0,0 +1,54 @@
+path: "tensorflow.sparse.SparseTensor"
+tf_class {
+ is_instance: "<class \'tensorflow.python.framework.sparse_tensor.SparseTensor\'>"
+ is_instance: "<class \'tensorflow.python.framework.ops._TensorLike\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "dense_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "indices"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "op"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "values"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'indices\', \'values\', \'dense_shape\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "consumers"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "eval"
+ argspec: "args=[\'self\', \'feed_dict\', \'session\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "from_value"
+ argspec: "args=[\'cls\', \'sparse_tensor_value\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "get_shape"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
index ba9e651b34..32bd8d5f8e 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.sparse.pbtxt
@@ -1,5 +1,21 @@
path: "tensorflow.sparse"
tf_module {
+ member {
+ name: "SparseConditionalAccumulator"
+ mtype: "<type \'type\'>"
+ }
+ member {
+ name: "SparseTensor"
+ mtype: "<type \'type\'>"
+ }
+ member_method {
+ name: "add"
+ argspec: "args=[\'a\', \'b\', \'thresh\'], varargs=None, keywords=None, defaults=[\'0\'], "
+ }
+ member_method {
+ name: "concat"
+ argspec: "args=[\'axis\', \'sp_inputs\', \'name\', \'expand_nonconcat_dim\', \'concat_dim\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'None\'], "
+ }
member_method {
name: "cross"
argspec: "args=[\'inputs\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
@@ -16,4 +32,100 @@ tf_module {
name: "eye"
argspec: "args=[\'num_rows\', \'num_columns\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'float32\'>\", \'None\'], "
}
+ member_method {
+ name: "fill_empty_rows"
+ argspec: "args=[\'sp_input\', \'default_value\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "mask"
+ argspec: "args=[\'a\', \'mask_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'sp_a\', \'b\', \'adjoint_a\', \'adjoint_b\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], "
+ }
+ member_method {
+ name: "maximum"
+ argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "merge"
+ argspec: "args=[\'sp_ids\', \'sp_values\', \'vocab_size\', \'name\', \'already_sorted\'], varargs=None, keywords=None, defaults=[\'None\', \'False\'], "
+ }
+ member_method {
+ name: "minimum"
+ argspec: "args=[\'sp_a\', \'sp_b\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "placeholder"
+ argspec: "args=[\'dtype\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_max_sparse"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reduce_sum_sparse"
+ argspec: "args=[\'sp_input\', \'axis\', \'keepdims\', \'reduction_axes\', \'keep_dims\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "reorder"
+ argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reset_shape"
+ argspec: "args=[\'sp_input\', \'new_shape\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "reshape"
+ argspec: "args=[\'sp_input\', \'shape\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "retain"
+ argspec: "args=[\'sp_input\', \'to_retain\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "segment_mean"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "segment_sqrt_n"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "segment_sum"
+ argspec: "args=[\'data\', \'indices\', \'segment_ids\', \'name\', \'num_segments\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
+ member_method {
+ name: "slice"
+ argspec: "args=[\'sp_input\', \'start\', \'size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "softmax"
+ argspec: "args=[\'sp_input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "split"
+ argspec: "args=[\'keyword_required\', \'sp_input\', \'num_split\', \'axis\', \'name\', \'split_dim\'], varargs=None, keywords=None, defaults=[\'KeywordRequired()\', \'None\', \'None\', \'None\', \'None\', \'None\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'sp_input\', \'default_value\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'True\', \'None\'], "
+ }
+ member_method {
+ name: "to_indicator"
+ argspec: "args=[\'sp_input\', \'vocab_size\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
+ }
+ member_method {
+ name: "transpose"
+ argspec: "args=[\'sp_input\', \'perm\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ }
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
index 312e94b41d..ebdaf57231 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.strings.pbtxt
@@ -13,6 +13,10 @@ tf_module {
argspec: "args=[\'input\', \'name\', \'unit\'], varargs=None, keywords=None, defaults=[\'None\', \'BYTE\'], "
}
member_method {
+ name: "reduce_join"
+ argspec: "args=[\'inputs\', \'axis\', \'keep_dims\', \'separator\', \'name\', \'reduction_indices\'], varargs=None, keywords=None, defaults=[\'None\', \'False\', \'\', \'None\', \'None\'], "
+ }
+ member_method {
name: "regex_full_match"
argspec: "args=[\'input\', \'pattern\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], "
}
diff --git a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
index cb6da5088b..7e980fe44d 100644
--- a/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
+++ b/tensorflow/tools/api/golden/v2/tensorflow.train.pbtxt
@@ -253,6 +253,10 @@ tf_module {
argspec: "args=[\'checkpoint_prefix\'], varargs=None, keywords=None, defaults=None"
}
member_method {
+ name: "confusion_matrix"
+ argspec: "args=[\'labels\', \'predictions\', \'num_classes\', \'dtype\', \'name\', \'weights\'], varargs=None, keywords=None, defaults=[\'None\', \"<dtype: \'int32\'>\", \'None\', \'None\'], "
+ }
+ member_method {
name: "cosine_decay"
argspec: "args=[\'learning_rate\', \'global_step\', \'decay_steps\', \'alpha\', \'name\'], varargs=None, keywords=None, defaults=[\'0.0\', \'None\'], "
}
diff --git a/tensorflow/tools/ci_build/Dockerfile.cmake b/tensorflow/tools/ci_build/Dockerfile.cmake
index b7450c83de..ef0024fdb4 100644
--- a/tensorflow/tools/ci_build/Dockerfile.cmake
+++ b/tensorflow/tools/ci_build/Dockerfile.cmake
@@ -28,8 +28,8 @@ RUN pip install --upgrade astor
RUN pip install --upgrade gast
RUN pip install --upgrade numpy
RUN pip install --upgrade termcolor
-RUN pip install keras_applications==1.0.5
-RUN pip install keras_preprocessing==1.0.3
+RUN pip install --upgrade keras_applications
+RUN pip install --upgrade keras_preprocessing
# Install golang
RUN apt-get install -t xenial-backports -y golang-1.9
diff --git a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04 b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
index a30858db82..dd8d705331 100644
--- a/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
+++ b/tensorflow/tools/ci_build/Dockerfile.rbe.cuda9.0-cudnn7-ubuntu14.04
@@ -26,7 +26,7 @@ ENV NVIDIA_VISIBLE_DEVICES all
ENV NVIDIA_DRIVER_CAPABILITIES compute,utility
ENV NVIDIA_REQUIRE_CUDA "cuda>=9.0"
ENV NCCL_VERSION 2.2.13
-ENV CUDNN_VERSION 7.2.1.38
+ENV CUDNN_VERSION 7.1.4.18
# TODO(b/110903506): /usr/loca/cuda/lib64/stubs should not be needed in
# LD_LIBRARY_PATH. The stubs/libcuda.so is not meant to used at runtime. The
diff --git a/tensorflow/tools/ci_build/builds/android.sh b/tensorflow/tools/ci_build/builds/android.sh
index 7c3e308229..ec5ec9993a 100755
--- a/tensorflow/tools/ci_build/builds/android.sh
+++ b/tensorflow/tools/ci_build/builds/android.sh
@@ -38,6 +38,7 @@ TARGETS+=" //tensorflow/core/common_runtime/eager:execute"
bazel --bazelrc=/dev/null build \
--compilation_mode=opt --cxxopt=-std=c++11 --fat_apk_cpu=x86_64 \
--spawn_strategy=sandboxed --genrule_strategy=sandboxed \
+ --define=grpc_no_ares=true \
${TARGETS}
echo "========== Makefile Build Test =========="
diff --git a/tensorflow/tools/ci_build/builds/run_pip_tests.sh b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
index 17198a6560..7d5cf3f843 100755
--- a/tensorflow/tools/ci_build/builds/run_pip_tests.sh
+++ b/tensorflow/tools/ci_build/builds/run_pip_tests.sh
@@ -111,7 +111,6 @@ bazel clean
# virtualenv.
export TF_NEED_GCP=0
export TF_NEED_HDFS=0
-export TF_ENABLE_XLA=0
# Obtain the path to Python binary
if [[ ${IS_VIRTUALENV} == "1" ]]; then
diff --git a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
index cd7206baf8..9c6390070c 100755
--- a/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
+++ b/tensorflow/tools/ci_build/gpu_build/parallel_gpu_execute.sh
@@ -29,7 +29,7 @@ TF_TESTS_PER_GPU=${TF_TESTS_PER_GPU:-8}
# p100 has minimum 12G memory. Therefore, we should limit each test to 1.5G.
# To leave some room in case we want to run more tests in parallel in the
# future and to use a rounder number, we set it to 1G.
-export TF_PER_DEVICE_MEMORY_LIMIT_MB=1024
+export TF_PER_DEVICE_MEMORY_LIMIT_MB=${TF_PER_DEVICE_MEMORY_LIMIT_MB:-1024}
# *******************************************************************
# This section of the script is needed to
diff --git a/tensorflow/tools/ci_build/install/install_pip_packages.sh b/tensorflow/tools/ci_build/install/install_pip_packages.sh
index 4ced96f90b..b90f3f3b97 100755
--- a/tensorflow/tools/ci_build/install/install_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_pip_packages.sh
@@ -115,10 +115,10 @@ pip2 install --upgrade setuptools==39.1.0
pip3 install --upgrade setuptools==39.1.0
# Keras
-pip2 install keras_applications==1.0.5 --no-deps
-pip3 install keras_applications==1.0.5 --no-deps
-pip2 install keras_preprocessing==1.0.3 --no-deps
-pip3 install keras_preprocessing==1.0.3 --no-deps
+pip2 install keras_applications==1.0.6 --no-deps
+pip3 install keras_applications==1.0.6 --no-deps
+pip2 install keras_preprocessing==1.0.5 --no-deps
+pip3 install keras_preprocessing==1.0.5 --no-deps
pip2 install --upgrade h5py==2.8.0
pip3 install --upgrade h5py==2.8.0
diff --git a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
index 37e6b51f66..61d4fe3fe8 100755
--- a/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh
@@ -85,8 +85,8 @@ pip3.5 install --upgrade termcolor
pip3.5 install --upgrade setuptools==39.1.0
# Keras
-pip3.5 install keras_applications==1.0.5
-pip3.5 install keras_preprocessing==1.0.3
+pip3.5 install keras_applications==1.0.6
+pip3.5 install keras_preprocessing==1.0.5
pip3.5 install --upgrade h5py==2.8.0
# Install last working version of setuptools.
diff --git a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
index 7520ff74cb..8949af8a88 100755
--- a/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
+++ b/tensorflow/tools/ci_build/install/install_python3.6_pip_packages.sh
@@ -102,7 +102,7 @@ pip3 install --upgrade setuptools==39.1.0
pip3 install --upgrade h5py==2.8.0
# Keras
-pip3 install keras_applications==1.0.5
-pip3 install keras_preprocessing==1.0.3
+pip3 install keras_applications==1.0.6
+pip3 install keras_preprocessing==1.0.5
# LINT.ThenChange(//tensorflow/tools/ci_build/install/install_python3.5_pip_packages.sh)
diff --git a/tensorflow/tools/docker/Dockerfile b/tensorflow/tools/docker/Dockerfile
index b5a6c05193..205128ad58 100644
--- a/tensorflow/tools/docker/Dockerfile
+++ b/tensorflow/tools/docker/Dockerfile
@@ -29,8 +29,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.devel b/tensorflow/tools/docker/Dockerfile.devel
index c741e8ad0c..6f8e91fccf 100644
--- a/tensorflow/tools/docker/Dockerfile.devel
+++ b/tensorflow/tools/docker/Dockerfile.devel
@@ -33,8 +33,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-gpu b/tensorflow/tools/docker/Dockerfile.devel-gpu
index f544725af4..69a117fda6 100644
--- a/tensorflow/tools/docker/Dockerfile.devel-gpu
+++ b/tensorflow/tools/docker/Dockerfile.devel-gpu
@@ -55,8 +55,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl b/tensorflow/tools/docker/Dockerfile.devel-mkl
index db7c701289..e433e9ebb2 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl
@@ -52,8 +52,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
index 987b582d10..48f2400569 100755
--- a/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.devel-mkl-horovod
@@ -45,8 +45,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
mock \
numpy \
diff --git a/tensorflow/tools/docker/Dockerfile.gpu b/tensorflow/tools/docker/Dockerfile.gpu
index 781bf9e851..7dc92a888b 100644
--- a/tensorflow/tools/docker/Dockerfile.gpu
+++ b/tensorflow/tools/docker/Dockerfile.gpu
@@ -42,8 +42,8 @@ RUN pip --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl b/tensorflow/tools/docker/Dockerfile.mkl
index 641c9e3b16..ac41cffe4b 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl
+++ b/tensorflow/tools/docker/Dockerfile.mkl
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docker/Dockerfile.mkl-horovod b/tensorflow/tools/docker/Dockerfile.mkl-horovod
index 2b11679f54..4daf4fefff 100755
--- a/tensorflow/tools/docker/Dockerfile.mkl-horovod
+++ b/tensorflow/tools/docker/Dockerfile.mkl-horovod
@@ -38,8 +38,8 @@ RUN ${PIP} --no-cache-dir install \
h5py \
ipykernel \
jupyter \
- keras_applications==1.0.5 \
- keras_preprocessing==1.0.3 \
+ keras_applications \
+ keras_preprocessing \
matplotlib \
numpy \
pandas \
diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD
index 2a858b4fd6..1a53f24177 100644
--- a/tensorflow/tools/docs/BUILD
+++ b/tensorflow/tools/docs/BUILD
@@ -127,7 +127,6 @@ py_test(
name = "build_docs_test",
size = "small",
srcs = ["build_docs_test.py"],
- data = ["//tensorflow/docs_src"],
srcs_version = "PY2AND3",
tags = [
# No reason to run sanitizers or fastbuild for this test.
diff --git a/tensorflow/tools/docs/build_docs_test.py b/tensorflow/tools/docs/build_docs_test.py
index 0cbf8b478f..4d3bedda2d 100644
--- a/tensorflow/tools/docs/build_docs_test.py
+++ b/tensorflow/tools/docs/build_docs_test.py
@@ -30,9 +30,11 @@ from tensorflow.tools.docs import generate_lib
class Flags(object):
resource_root = resource_loader.get_root_dir_with_all_resources()
- src_dir = os.path.join(resource_root, 'tensorflow/docs_src')
+ src_dir = os.path.join(googletest.GetTempDir(), 'input')
+ os.mkdir(src_dir)
base_dir = os.path.join(resource_root, 'tensorflow/')
- output_dir = googletest.GetTempDir()
+ output_dir = os.path.join(googletest.GetTempDir(), 'output')
+ os.mkdir(output_dir)
class BuildDocsTest(googletest.TestCase):
diff --git a/tensorflow/tools/lib_package/BUILD b/tensorflow/tools/lib_package/BUILD
index b9f4902639..85514b8629 100644
--- a/tensorflow/tools/lib_package/BUILD
+++ b/tensorflow/tools/lib_package/BUILD
@@ -137,14 +137,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_jemalloc_linux_x86_64": [
- "@jemalloc//:COPYING",
- ],
- "//tensorflow:with_jemalloc_linux_ppc64le": [
- "@jemalloc//:COPYING",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
@@ -202,14 +194,6 @@ genrule(
"@snappy//:COPYING",
"@zlib_archive//:zlib.h",
] + select({
- "//tensorflow:with_jemalloc_linux_x86_64": [
- "@jemalloc//:COPYING",
- ],
- "//tensorflow:with_jemalloc_linux_ppc64le": [
- "@jemalloc//:COPYING",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index f1de22300b..164b3d8303 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -64,10 +64,6 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/compiler:xla",
"//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
- "//tensorflow/contrib/data/python/kernel_tests/serialization:dataset_serialization_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:stats_dataset_test_base",
- "//tensorflow/contrib/data/python/kernel_tests:test_utils",
- "//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
"//tensorflow/contrib/eager/python:evaluator",
"//tensorflow/contrib/gan:gan",
@@ -108,6 +104,8 @@ COMMON_PIP_DEPS = [
"//tensorflow/python:meta_graph_testdata",
"//tensorflow/python:spectral_ops_test_util",
"//tensorflow/python:util_example_parser_configuration",
+ "//tensorflow/python/data/experimental/kernel_tests/serialization:dataset_serialization_test_base",
+ "//tensorflow/python/data/experimental/kernel_tests:stats_dataset_test_base",
"//tensorflow/python/data/kernel_tests:test_base",
"//tensorflow/python/debug:debug_pip",
"//tensorflow/python/eager:eager_pip",
@@ -169,14 +167,6 @@ filegroup(
"@zlib_archive//:zlib.h",
"@org_python_pypi_backports_weakref//:LICENSE",
] + select({
- "//tensorflow:with_jemalloc_linux_x86_64": [
- "@jemalloc//:COPYING",
- ],
- "//tensorflow:with_jemalloc_linux_ppc64le": [
- "@jemalloc//:COPYING",
- ],
- "//conditions:default": [],
- }) + select({
"//tensorflow/core/kernels:xsmm": [
"@libxsmm_archive//:LICENSE.md",
],
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index b95e1f5c87..d864a7a039 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -51,12 +51,11 @@ REQUIRED_PACKAGES = [
'absl-py >= 0.1.6',
'astor >= 0.6.0',
'gast >= 0.2.0',
- 'keras_applications >= 1.0.5',
- 'keras_preprocessing >= 1.0.3',
+ 'keras_applications >= 1.0.6',
+ 'keras_preprocessing >= 1.0.5',
'numpy >= 1.13.3',
'six >= 1.10.0',
- 'protobuf >= 3.6.0',
- 'setuptools <= 39.1.0',
+ 'protobuf >= 3.6.1',
'tensorboard >= 1.11.0, < 1.12.0',
'termcolor >= 1.1.0',
]
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 70bade060e..bcc89ef729 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -110,11 +110,11 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
tf_http_archive(
name = "com_google_absl",
build_file = clean_dep("//third_party:com_google_absl.BUILD"),
- sha256 = "278a1af58b633be886fe81bf7061dca6b5fea99566850d1319fffdaa1a061792",
- strip_prefix = "abseil-cpp-e291c279e458761e77a69b09b129d3d1e81f1e80",
+ sha256 = "7dd09690ae7ca4551de3111d4a86b75b23ec17445f273d3c42bdcdc1c7b02e4e",
+ strip_prefix = "abseil-cpp-48cd2c3f351ff188bc85684b84a91b6e6d17d896",
urls = [
- "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
- "https://github.com/abseil/abseil-cpp/archive/e291c279e458761e77a69b09b129d3d1e81f1e80.tar.gz",
+ "https://mirror.bazel.build/github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
+ "https://github.com/abseil/abseil-cpp/archive/48cd2c3f351ff188bc85684b84a91b6e6d17d896.tar.gz",
],
)
@@ -642,18 +642,6 @@ def tf_workspace(path_prefix = "", tf_repo_name = ""):
testonly_ = True,
)
- tf_http_archive(
- name = "jemalloc",
- build_file = clean_dep("//third_party:jemalloc.BUILD"),
- sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
- strip_prefix = "jemalloc-4.4.0",
- system_build_file = clean_dep("//third_party/systemlibs:jemalloc.BUILD"),
- urls = [
- "https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
- "https://github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
- ],
- )
-
java_import_external(
name = "com_google_testing_compile",
jar_sha256 = "edc180fdcd9f740240da1a7a45673f46f59c5578d8cd3fbc912161f74b5aebb8",
diff --git a/third_party/gpus/crosstool/BUILD.tpl b/third_party/gpus/crosstool/BUILD.tpl
index f638756d23..c8812fab33 100644
--- a/third_party/gpus/crosstool/BUILD.tpl
+++ b/third_party/gpus/crosstool/BUILD.tpl
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {
diff --git a/third_party/gpus/cuda_configure.bzl b/third_party/gpus/cuda_configure.bzl
index f5fdd3a75e..69f4599c16 100644
--- a/third_party/gpus/cuda_configure.bzl
+++ b/third_party/gpus/cuda_configure.bzl
@@ -1107,8 +1107,8 @@ def symlink_genrule_for_dir(
# $(@D) will include the full path to the file.
dest = "$(@D)/" + dest_dir + dest_files[i] if len(dest_files) != 1 else "$(@D)/" + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = "cp -f" if _is_windows(repository_ctx) else "ln -s"
+ # Copy the headers to create a sandboxable setup.
+ cmd = "cp -f"
command.append(cmd + ' "%s" "%s"' % (src_files[i], dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(
@@ -1334,27 +1334,14 @@ def _create_local_cuda_repository(repository_ctx):
cuda_defines["%{host_compiler_path}"] = "clang/bin/crosstool_wrapper_driver_is_not_gcc"
cuda_defines["%{host_compiler_warnings}"] = ""
- # TODO(klimek): We currently need to inject "/" as builtin directory path
- # to disable bazel's dependency checks.
- # The problem is that:
- # - the python rules symlink the python headers into the bazel root
- # - the rules use 'includes' in the BUILD file to redirect includes of the
- # python headers through those paths
- # - bazel currently uses -isystem for include paths specified via 'includes'
- # - gcc follows symlinks when resolving files via -isystem paths, and puts
- # the resolved paths into the .d file, which makes the dependency check
- # fail for bazel
- # There are multiple possible ways to solve this:
- # 1. make bazel not use -isystem for paths specified via 'includes'
- # 2. cp the headers instead of symlinking them
- #
- # Once this is fixed, the right builtin directory path is:
- # (host_compiler_includes +
- # "\n cxx_builtin_include_directory: \"%s\"" % cuda_include_path)
- # The cuda directory needs to be passed, as there is currently no rule
- # providing the cuda headers in the same way the python headers are
- # provided.
- cuda_defines["%{host_compiler_includes}"] = "\n cxx_builtin_include_directory: \"/\""
+ # nvcc has the system include paths built in and will automatically
+ # search them; we cannot work around that, so we add the relevant cuda
+ # system paths to the allowed compiler specific include paths.
+ cuda_defines["%{host_compiler_includes}"] = (
+ host_compiler_includes + "\n" +
+ _cuda_include_path(repository_ctx, cuda_config) +
+ "\n cxx_builtin_include_directory: \"%s\"" % cupti_header_dir +
+ "\n cxx_builtin_include_directory: \"%s\"" % cudnn_header_dir)
nvcc_path = str(repository_ctx.path("%s/bin/nvcc%s" %
(
cuda_config.cuda_toolkit_path,
diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD
deleted file mode 100644
index 1b0829b8fe..0000000000
--- a/third_party/jemalloc.BUILD
+++ /dev/null
@@ -1,356 +0,0 @@
-# Description:
-# jemalloc - a general-purpose scalable concurrent malloc implementation
-
-licenses(["notice"]) # BSD
-
-exports_files(["COPYING"])
-
-load("@org_tensorflow//third_party:common.bzl", "template_rule")
-
-cc_library(
- name = "jemalloc_headers",
- hdrs = ["include/jemalloc/jemalloc.h"],
- includes = ["include"],
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "jemalloc_impl",
- srcs = [
- "src/arena.c",
- "src/atomic.c",
- "src/base.c",
- "src/bitmap.c",
- "src/chunk.c",
- "src/chunk_dss.c",
- "src/chunk_mmap.c",
- "src/ckh.c",
- "src/ctl.c",
- "src/extent.c",
- "src/hash.c",
- "src/huge.c",
- "src/jemalloc.c",
- "src/mb.c",
- "src/mutex.c",
- "src/nstime.c",
- "src/pages.c",
- "src/prng.c",
- "src/prof.c",
- "src/quarantine.c",
- "src/rtree.c",
- "src/spin.c",
- "src/stats.c",
- "src/tcache.c",
- "src/tsd.c",
- "src/util.c",
- "src/witness.c",
- ],
- hdrs = [
- "include/jemalloc/internal/arena.h",
- "include/jemalloc/internal/assert.h",
- "include/jemalloc/internal/atomic.h",
- "include/jemalloc/internal/base.h",
- "include/jemalloc/internal/bitmap.h",
- "include/jemalloc/internal/chunk.h",
- "include/jemalloc/internal/chunk_dss.h",
- "include/jemalloc/internal/chunk_mmap.h",
- "include/jemalloc/internal/ckh.h",
- "include/jemalloc/internal/ctl.h",
- "include/jemalloc/internal/extent.h",
- "include/jemalloc/internal/hash.h",
- "include/jemalloc/internal/huge.h",
- "include/jemalloc/internal/jemalloc_internal.h",
- "include/jemalloc/internal/jemalloc_internal_decls.h",
- "include/jemalloc/internal/jemalloc_internal_defs.h",
- "include/jemalloc/internal/jemalloc_internal_macros.h",
- "include/jemalloc/internal/mb.h",
- "include/jemalloc/internal/mutex.h",
- "include/jemalloc/internal/nstime.h",
- "include/jemalloc/internal/pages.h",
- "include/jemalloc/internal/ph.h",
- "include/jemalloc/internal/private_namespace.h",
- "include/jemalloc/internal/prng.h",
- "include/jemalloc/internal/prof.h",
- "include/jemalloc/internal/ql.h",
- "include/jemalloc/internal/qr.h",
- "include/jemalloc/internal/quarantine.h",
- "include/jemalloc/internal/rb.h",
- "include/jemalloc/internal/rtree.h",
- "include/jemalloc/internal/size_classes.h",
- "include/jemalloc/internal/smoothstep.h",
- "include/jemalloc/internal/spin.h",
- "include/jemalloc/internal/stats.h",
- "include/jemalloc/internal/tcache.h",
- "include/jemalloc/internal/ticker.h",
- "include/jemalloc/internal/tsd.h",
- "include/jemalloc/internal/util.h",
- "include/jemalloc/internal/valgrind.h",
- "include/jemalloc/internal/witness.h",
- ],
- # Same flags that jemalloc uses to build.
- copts = [
- "-O3",
- "-funroll-loops",
- "-D_GNU_SOURCE",
- "-D_REENTRANT",
- ],
- includes = ["include"],
- # pthread_atfork() is called for PPC.
- linkopts = select({
- "@org_tensorflow//tensorflow:linux_ppc64le": [
- "-lpthread",
- ],
- "@org_tensorflow//tensorflow:linux_x86_64": [
- "-lpthread",
- ],
- "//conditions:default": [
- ],
- }),
- visibility = ["//visibility:public"],
- deps = [":jemalloc_headers"],
-)
-
-sh_binary(
- name = "jemalloc_sh",
- srcs = ["include/jemalloc/jemalloc.sh"],
-)
-
-genrule(
- name = "jemalloc_h",
- srcs = [
- ":jemalloc_defs_h",
- ":jemalloc_macros_h",
- ":jemalloc_mangle_h",
- ":jemalloc_protos_h",
- ":jemalloc_rename_h",
- ":jemalloc_typedefs_h",
- ],
- outs = ["include/jemalloc/jemalloc.h"],
- cmd = "$(location :jemalloc_sh) $$(dirname $(location :jemalloc_defs_h))/../../ >$@",
- tools = [":jemalloc_sh"],
-)
-
-# Add to this list if you want to export more symbols from jemalloc.
-genrule(
- name = "public_symbols_txt",
- outs = ["include/jemalloc/internal/public_symbols.txt"],
- cmd = "\n".join([
- "cat <<'EOF' > $@",
- "free:jemalloc_free",
- "malloc:jemalloc_malloc",
- "posix_memalign:jemalloc_posix_memalign",
- "realloc:jemalloc_realloc",
- "EOF",
- ]),
-)
-
-sh_binary(
- name = "jemalloc_mangle_sh",
- srcs = ["include/jemalloc/jemalloc_mangle.sh"],
-)
-
-genrule(
- name = "jemalloc_mangle_h",
- srcs = [":public_symbols_txt"],
- outs = ["include/jemalloc/jemalloc_mangle.h"],
- cmd = "$(location :jemalloc_mangle_sh) $(location :public_symbols_txt) je_ >$@",
- tools = [":jemalloc_mangle_sh"],
-)
-
-sh_binary(
- name = "jemalloc_rename_sh",
- srcs = ["include/jemalloc/jemalloc_rename.sh"],
-)
-
-genrule(
- name = "jemalloc_rename_h",
- srcs = [":public_symbols_txt"],
- outs = ["include/jemalloc/jemalloc_rename.h"],
- cmd = "$(location :jemalloc_rename_sh) $(location :public_symbols_txt) >$@",
- tools = [":jemalloc_rename_sh"],
-)
-
-sh_binary(
- name = "private_namespace_sh",
- srcs = ["include/jemalloc/internal/private_namespace.sh"],
-)
-
-genrule(
- name = "private_namespace_h",
- srcs = ["include/jemalloc/internal/private_symbols.txt"],
- outs = ["include/jemalloc/internal/private_namespace.h"],
- cmd = "$(location :private_namespace_sh) $(location include/jemalloc/internal/private_symbols.txt) >$@",
- tools = [":private_namespace_sh"],
-)
-
-sh_binary(
- name = "public_namespace_sh",
- srcs = ["include/jemalloc/internal/public_namespace.sh"],
-)
-
-genrule(
- name = "public_namespace_h",
- srcs = [":public_symbols_txt"],
- outs = ["include/jemalloc/internal/public_namespace.h"],
- cmd = "$(location :public_namespace_sh) $(location :public_symbols_txt) >$@",
- tools = [":public_namespace_sh"],
-)
-
-sh_binary(
- name = "size_classes_sh",
- srcs = ["include/jemalloc/internal/size_classes.sh"],
-)
-
-# Size classes for Linux x86_64 and ppc64le. Update if adding builds for other
-# architectures. See size_classes.sh for details on the arguments.
-# For default case, kept the arguments same as that of x86_64 for now.
-genrule(
- name = "size_classes_h",
- outs = ["include/jemalloc/internal/size_classes.h"],
- cmd = select({
- "@org_tensorflow//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@",
- "@org_tensorflow//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
- "//conditions:default": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
- }),
- tools = [":size_classes_sh"],
-)
-
-template_rule(
- name = "jemalloc_internal_h",
- src = "include/jemalloc/internal/jemalloc_internal.h.in",
- out = "include/jemalloc/internal/jemalloc_internal.h",
- substitutions = {
- "@private_namespace@": "je_",
- "@install_suffix@": "",
- },
-)
-
-template_rule(
- name = "jemalloc_internal_defs_h",
- src = "include/jemalloc/internal/jemalloc_internal_defs.h.in",
- out = "include/jemalloc/internal/jemalloc_internal_defs.h",
- substitutions = {
- "#undef JEMALLOC_PREFIX": "#define JEMALLOC_PREFIX \"jemalloc_\"",
- "#undef JEMALLOC_CPREFIX": "#define JEMALLOC_CPREFIX \"JEMALLOC_\"",
- "#undef JEMALLOC_PRIVATE_NAMESPACE": "#define JEMALLOC_PRIVATE_NAMESPACE je_",
- "#undef CPU_SPINWAIT": "\n".join([
- "#if defined(__powerpc64__) || defined(__powerpc__)",
- "#define CPU_SPINWAIT __asm__ volatile(\"or 27,27,27\")",
- "#else",
- "#define CPU_SPINWAIT __asm__ volatile(\"pause\")",
- "#endif",
- ]),
- "#undef JEMALLOC_HAVE_BUILTIN_CLZ": "#define JEMALLOC_HAVE_BUILTIN_CLZ",
- "#undef JEMALLOC_USE_SYSCALL": "#define JEMALLOC_USE_SYSCALL",
- "#undef JEMALLOC_HAVE_SECURE_GETENV": "#define JEMALLOC_HAVE_SECURE_GETENV",
- "#undef JEMALLOC_HAVE_PTHREAD_ATFORK": "#define JEMALLOC_HAVE_PTHREAD_ATFORK",
- "#undef JEMALLOC_HAVE_CLOCK_MONOTONIC_COARSE": "#define JEMALLOC_HAVE_CLOCK_MONOTONIC_COARSE 1",
- # Newline required because of substitution conflicts.
- "#undef JEMALLOC_HAVE_CLOCK_MONOTONIC\n": "#define JEMALLOC_HAVE_CLOCK_MONOTONIC 1\n",
- "#undef JEMALLOC_THREADED_INIT": "#define JEMALLOC_THREADED_INIT",
- "#undef JEMALLOC_TLS_MODEL": "#define JEMALLOC_TLS_MODEL __attribute__((tls_model(\"initial-exec\")))",
- "#undef JEMALLOC_CC_SILENCE": "#define JEMALLOC_CC_SILENCE",
- "#undef JEMALLOC_STATS": "#define JEMALLOC_STATS",
- "#undef JEMALLOC_TCACHE": "#define JEMALLOC_TCACHE",
- "#undef JEMALLOC_DSS": "#define JEMALLOC_DSS",
- "#undef JEMALLOC_FILL": "#define JEMALLOC_FILL",
- "#undef LG_TINY_MIN": "#define LG_TINY_MIN 3",
- "#undef LG_PAGE": "\n".join([
- "#if defined(__powerpc64__) || defined(__powerpc__)",
- "#define LG_PAGE 16",
- "#else",
- "#define LG_PAGE 12",
- "#endif",
- ]),
- "#undef JEMALLOC_MAPS_COALESCE": "#define JEMALLOC_MAPS_COALESCE",
- "#undef JEMALLOC_TLS": "#define JEMALLOC_TLS",
- "#undef JEMALLOC_INTERNAL_UNREACHABLE": "#define JEMALLOC_INTERNAL_UNREACHABLE __builtin_unreachable",
- "#undef JEMALLOC_INTERNAL_FFSLL": "#define JEMALLOC_INTERNAL_FFSLL __builtin_ffsll",
- # Newline required because of substitution conflicts.
- "#undef JEMALLOC_INTERNAL_FFSL\n": "#define JEMALLOC_INTERNAL_FFSL __builtin_ffsl\n",
- "#undef JEMALLOC_INTERNAL_FFS\n": "#define JEMALLOC_INTERNAL_FFS __builtin_ffs\n",
- "#undef JEMALLOC_CACHE_OBLIVIOUS": "#define JEMALLOC_CACHE_OBLIVIOUS",
- "#undef JEMALLOC_PROC_SYS_VM_OVERCOMMIT_MEMORY": "#define JEMALLOC_PROC_SYS_VM_OVERCOMMIT_MEMORY",
- "#undef JEMALLOC_HAVE_MADVISE": "#define JEMALLOC_HAVE_MADVISE",
- "#undef JEMALLOC_PURGE_MADVISE_DONTNEED": "#define JEMALLOC_PURGE_MADVISE_DONTNEED",
- "#undef JEMALLOC_THP": "#define JEMALLOC_THP",
- "#undef JEMALLOC_HAS_ALLOCA_H": "#define JEMALLOC_HAS_ALLOCA_H 1",
- # Newline required because of substitution conflicts.
- "#undef LG_SIZEOF_INT\n": "#define LG_SIZEOF_INT 2\n",
- "#undef LG_SIZEOF_LONG\n": "#define LG_SIZEOF_LONG 3\n",
- "#undef LG_SIZEOF_LONG_LONG": "#define LG_SIZEOF_LONG_LONG 3",
- "#undef LG_SIZEOF_INTMAX_T": "#define LG_SIZEOF_INTMAX_T 3",
- "#undef JEMALLOC_GLIBC_MALLOC_HOOK": "#define JEMALLOC_GLIBC_MALLOC_HOOK",
- "#undef JEMALLOC_GLIBC_MEMALIGN_HOOK": "#define JEMALLOC_GLIBC_MEMALIGN_HOOK",
- "#undef JEMALLOC_HAVE_PTHREAD_MUTEX_ADAPTIVE_NP": "#define JEMALLOC_HAVE_PTHREAD_MUTEX_ADAPTIVE_NP",
- "#undef JEMALLOC_CONFIG_MALLOC_CONF": "#define JEMALLOC_CONFIG_MALLOC_CONF \"\"",
- },
-)
-
-template_rule(
- name = "jemalloc_defs_h",
- src = "include/jemalloc/jemalloc_defs.h.in",
- out = "include/jemalloc/jemalloc_defs.h",
- substitutions = {
- "#undef JEMALLOC_HAVE_ATTR": "#define JEMALLOC_HAVE_ATTR",
- "#undef JEMALLOC_HAVE_ATTR_ALLOC_SIZE": "#define JEMALLOC_HAVE_ATTR_ALLOC_SIZE",
- "#undef JEMALLOC_HAVE_ATTR_FORMAT_GNU_PRINTF": "#define JEMALLOC_HAVE_ATTR_FORMAT_GNU_PRINTF",
- "#undef JEMALLOC_HAVE_ATTR_FORMAT_PRINTF": "#define JEMALLOC_HAVE_ATTR_FORMAT_PRINTF",
- "#undef JEMALLOC_OVERRIDE_MEMALIGN": "#define JEMALLOC_OVERRIDE_MEMALIGN",
- "#undef JEMALLOC_OVERRIDE_VALLOC": "#define JEMALLOC_OVERRIDE_VALLOC",
- "#undef JEMALLOC_USABLE_SIZE_CONST": "#define JEMALLOC_USABLE_SIZE_CONST",
- "#undef JEMALLOC_USE_CXX_THROW": "#define JEMALLOC_USE_CXX_THROW",
- "#undef LG_SIZEOF_PTR": "#define LG_SIZEOF_PTR 3",
- },
-)
-
-template_rule(
- name = "jemalloc_macros_h",
- src = "include/jemalloc/jemalloc_macros.h.in",
- out = "include/jemalloc/jemalloc_macros.h",
- substitutions = {
- "@jemalloc_version@": "0.0.0",
- "@jemalloc_version_major@": "0",
- "@jemalloc_version_minor@": "0",
- "@jemalloc_version_bugfix@": "0",
- "@jemalloc_version_nrev@": "0",
- "@jemalloc_version_gid@": "0000000000000000000000000000000000000000",
- },
-)
-
-template_rule(
- name = "jemalloc_protos_h",
- src = "include/jemalloc/jemalloc_protos.h.in",
- out = "include/jemalloc/jemalloc_protos.h",
- substitutions = {
- "@aligned_alloc": "aligned_alloc",
- "@calloc": "calloc",
- "@cbopaque": "cbopaque",
- "@dallocx": "dallocx",
- "@free": "free",
- "@je": "je",
- "@mallctl": "mallctl",
- "@mallctlnametomib": "mallctlnametomib",
- "@mallctlbymib": "mallctlbymib",
- "@malloc_stats_print": "malloc_stats_print",
- "@malloc_usable_size": "malloc_usable_size",
- "@malloc": "malloc",
- "@mallocx": "mallocx",
- "@memalign": "memalign",
- "@nallocx": "nallocx",
- "@posix_memalign": "posix_memalign",
- "@rallocx": "rallocx",
- "@realloc": "realloc",
- "@sallocx": "sallocx",
- "@sdallocx": "sdallocx",
- "@valloc": "valloc",
- "@xallocx": "xallocx",
- },
-)
-
-template_rule(
- name = "jemalloc_typedefs_h",
- src = "include/jemalloc/jemalloc_typedefs.h.in",
- out = "include/jemalloc/jemalloc_typedefs.h",
- substitutions = {},
-)
diff --git a/third_party/nccl/nccl_configure.bzl b/third_party/nccl/nccl_configure.bzl
index ce9447096e..d78fe8f3aa 100644
--- a/third_party/nccl/nccl_configure.bzl
+++ b/third_party/nccl/nccl_configure.bzl
@@ -5,6 +5,7 @@
* `TF_NCCL_VERSION`: The NCCL version.
* `NCCL_INSTALL_PATH`: The installation path of the NCCL library.
+ * `NCCL_HDR_PATH`: The installation path of the NCCL header files.
"""
load(
@@ -15,6 +16,7 @@ load(
)
_NCCL_INSTALL_PATH = "NCCL_INSTALL_PATH"
+_NCCL_HDR_PATH = "NCCL_HDR_PATH"
_TF_NCCL_VERSION = "TF_NCCL_VERSION"
_TF_NCCL_CONFIG_REPO = "TF_NCCL_CONFIG_REPO"
@@ -68,7 +70,7 @@ def _find_nccl_header(repository_ctx, nccl_install_path):
return header_path
-def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
+def _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version):
"""Checks whether the header file matches the specified version of NCCL.
Args:
@@ -79,7 +81,9 @@ def _check_nccl_version(repository_ctx, nccl_install_path, nccl_version):
Returns:
A string containing the library version of NCCL.
"""
- header_path = _find_nccl_header(repository_ctx, nccl_install_path)
+ header_path = repository_ctx.path("%s/nccl.h" % nccl_hdr_path)
+ if not header_path.exists:
+ header_path = _find_nccl_header(repository_ctx, nccl_install_path)
header_dir = str(header_path.realpath.dirname)
major_version = find_cuda_define(repository_ctx, header_dir, "nccl.h",
_DEFINE_NCCL_MAJOR)
@@ -138,10 +142,12 @@ def _nccl_configure_impl(repository_ctx):
else:
# Create target for locally installed NCCL.
nccl_install_path = repository_ctx.os.environ[_NCCL_INSTALL_PATH].strip()
- _check_nccl_version(repository_ctx, nccl_install_path, nccl_version)
+ nccl_hdr_path = repository_ctx.os.environ[_NCCL_HDR_PATH].strip()
+ _check_nccl_version(repository_ctx, nccl_install_path, nccl_hdr_path, nccl_version)
repository_ctx.template("BUILD", _NCCL_LOCAL_BUILD_TEMPLATE, {
"%{version}": nccl_version,
"%{install_path}": nccl_install_path,
+ "%{hdr_path}": nccl_hdr_path,
})
@@ -149,6 +155,7 @@ nccl_configure = repository_rule(
implementation=_nccl_configure_impl,
environ=[
_NCCL_INSTALL_PATH,
+ _NCCL_HDR_PATH,
_TF_NCCL_VERSION,
],
)
diff --git a/third_party/nccl/system.BUILD.tpl b/third_party/nccl/system.BUILD.tpl
index 7ca835dedf..a07f54955f 100644
--- a/third_party/nccl/system.BUILD.tpl
+++ b/third_party/nccl/system.BUILD.tpl
@@ -20,7 +20,7 @@ genrule(
"libnccl.so.%{version}",
"nccl.h",
],
- cmd = """cp "%{install_path}/include/nccl.h" "$(@D)/nccl.h" &&
- cp "%{install_path}/lib/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
+ cmd = """cp "%{hdr_path}/nccl.h" "$(@D)/nccl.h" &&
+ cp "%{install_path}/libnccl.so.%{version}" "$(@D)/libnccl.so.%{version}" """,
)
diff --git a/third_party/py/python_configure.bzl b/third_party/py/python_configure.bzl
index 3c7e5c8469..53264630a1 100644
--- a/third_party/py/python_configure.bzl
+++ b/third_party/py/python_configure.bzl
@@ -130,8 +130,8 @@ def _symlink_genrule_for_dir(repository_ctx, src_dir, dest_dir, genrule_name,
# If we have only one file to link we do not want to use the dest_dir, as
# $(@D) will include the full path to the file.
dest = '$(@D)/' + dest_dir + dest_files[i] if len(dest_files) != 1 else '$(@D)/' + dest_files[i]
- # On Windows, symlink is not supported, so we just copy all the files.
- cmd = 'cp -f' if _is_windows(repository_ctx) else 'ln -s'
+ # Copy the headers to create a sandboxable setup.
+ cmd = 'cp -f'
command.append(cmd + ' "%s" "%s"' % (src_files[i] , dest))
outs.append(' "' + dest_dir + dest_files[i] + '",')
genrule = _genrule(src_dir, genrule_name, " && ".join(command),
diff --git a/third_party/systemlibs/jemalloc.BUILD b/third_party/systemlibs/jemalloc.BUILD
deleted file mode 100644
index 6a48d582ba..0000000000
--- a/third_party/systemlibs/jemalloc.BUILD
+++ /dev/null
@@ -1,30 +0,0 @@
-licenses(["notice"]) # BSD
-
-filegroup(
- name = "COPYING",
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "jemalloc_headers",
- defines = [
- "jemalloc_posix_memalign=posix_memalign",
- "jemalloc_malloc=malloc",
- "jemalloc_realloc=realloc",
- "jemalloc_free=free",
- ],
- visibility = ["//visibility:public"],
-)
-
-cc_library(
- name = "jemalloc_impl",
- linkopts = ["-ljemalloc"],
- defines = [
- "jemalloc_posix_memalign=posix_memalign",
- "jemalloc_malloc=malloc",
- "jemalloc_realloc=realloc",
- "jemalloc_free=free",
- ],
- visibility = ["//visibility:public"],
- deps = [":jemalloc_headers"],
-)
diff --git a/third_party/systemlibs/syslibs_configure.bzl b/third_party/systemlibs/syslibs_configure.bzl
index 8b0ab39eaf..b03d3380d7 100644
--- a/third_party/systemlibs/syslibs_configure.bzl
+++ b/third_party/systemlibs/syslibs_configure.bzl
@@ -23,7 +23,6 @@ VALID_LIBS = [
"gast_archive",
"gif_archive",
"grpc",
- "jemalloc",
"jpeg",
"jsoncpp_git",
"lmdb",
diff --git a/third_party/toolchains/BUILD b/third_party/toolchains/BUILD
index 7256a7d96e..bcbc4dda11 100644
--- a/third_party/toolchains/BUILD
+++ b/third_party/toolchains/BUILD
@@ -26,12 +26,10 @@ platform(
constraint_values = [
"@bazel_tools//platforms:x86_64",
"@bazel_tools//platforms:linux",
- "@bazel_tools//tools/cpp:clang",
- "@bazel_toolchains//constraints:xenial",
],
remote_execution_properties = """
properties: {
name: "container-image"
- value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:06b585f42eed3b2030e9566b8f88f48d7472fa0f47e59765bc115376c8801bdf"
+ value:"docker://gcr.io/asci-toolchain/nosla-cuda9.0-cudnn7-ubuntu14.04@sha256:e5099ff15650986e268a43ee99e2d2b7ffe2459b8b6935385078d1d3b2ed4d02"
}""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
index 2d3e41127d..05abcb56d8 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/cuda9.0-cudnn7/cuda/BUILD
@@ -1253,7 +1253,7 @@ genrule(
"cuda/lib/libcupti.so.9.0",
],
cmd = """
-if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.2.1" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
+if [ -d "$(@D)/extras" ]; then rm $(@D)/extras -drf; fi && if [ -d "$(@D)/include" ]; then rm $(@D)/include -drf; fi && if [ -d "$(@D)/lib" ]; then rm $(@D)/lib -drf; fi && if [ -d "$(@D)/nvvm" ]; then rm $(@D)/nvvm -drf; fi && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/stubs/libcuda.so" "$(@D)/cuda/lib/libcuda.so" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart.so.9.0.176" "$(@D)/cuda/lib/libcudart.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcudart_static.a" "$(@D)/cuda/lib/libcudart_static.a" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcublas.so.9.0.480" "$(@D)/cuda/lib/libcublas.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcusolver.so.9.0.176" "$(@D)/cuda/lib/libcusolver.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcurand.so.9.0.176" "$(@D)/cuda/lib/libcurand.so.9.0" && cp "/usr/local/cuda-9.0/targets/x86_64-linux/lib/libcufft.so.9.0.176" "$(@D)/cuda/lib/libcufft.so.9.0" && cp "/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4" "$(@D)/cuda/lib/libcudnn.so.7" && cp "/usr/local/cuda-9.0/extras/CUPTI/lib64/libcupti.so.9.0.176" "$(@D)/cuda/lib/libcupti.so.9.0"
""",
)
diff --git a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
index a56b4513fb..6442e7628a 100755
--- a/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
+++ b/third_party/toolchains/preconfig/ubuntu14.04/gcc-nvcc/BUILD
@@ -2,6 +2,20 @@ licenses(["restricted"])
package(default_visibility = ["//visibility:public"])
+toolchain(
+ name = "toolchain-linux-x86_64",
+ exec_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ target_compatible_with = [
+ "@bazel_tools//platforms:linux",
+ "@bazel_tools//platforms:x86_64",
+ ],
+ toolchain = ":cc-compiler-local",
+ toolchain_type = "@bazel_tools//tools/cpp:toolchain_type",
+)
+
cc_toolchain_suite(
name = "toolchain",
toolchains = {