author    Patrick Nguyen <drpng@google.com>  2018-04-25 20:30:00 -0700
committer Patrick Nguyen <drpng@google.com>  2018-04-25 20:30:00 -0700
commit    f8d7553d0f34b621d6123e153e69b0babb09d22c (patch)
tree      62355f3d9d9f7a4f81f31910779ab399e4850ded
parent    adf045607cc4126366ebb84ee2109f88c6ab25fc (diff)
parent    43a7072882196c7ac2d9429050a3140b1ecb52db (diff)
Merge commit for internal changes.
-rw-r--r--tensorflow/c/eager/BUILD2
-rw-r--r--tensorflow/c/eager/c_api.cc57
-rw-r--r--tensorflow/c/eager/c_api.h14
-rw-r--r--tensorflow/c/python_api.cc28
-rw-r--r--tensorflow/c/python_api.h12
-rw-r--r--tensorflow/compiler/aot/compile.cc5
-rw-r--r--tensorflow/compiler/aot/test.cc1
-rw-r--r--tensorflow/compiler/jit/BUILD38
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc2
-rw-r--r--tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc2
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.cc14
-rw-r--r--tensorflow/compiler/jit/kernels/xla_launch_op.h2
-rw-r--r--tensorflow/compiler/jit/xla_compile_on_demand_op.cc2
-rw-r--r--tensorflow/compiler/jit/xla_device.cc43
-rw-r--r--tensorflow/compiler/jit/xla_device.h21
-rw-r--r--tensorflow/compiler/jit/xla_device_context.cc2
-rw-r--r--tensorflow/compiler/jit/xla_device_context.h15
-rw-r--r--tensorflow/compiler/jit/xla_gpu_device.cc9
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.cc26
-rw-r--r--tensorflow/compiler/jit/xla_launch_util.h13
-rw-r--r--tensorflow/compiler/jit/xla_tensor.cc9
-rw-r--r--tensorflow/compiler/jit/xla_tensor.h3
-rw-r--r--tensorflow/compiler/tests/BUILD31
-rw-r--r--tensorflow/compiler/tests/eager_test.py137
-rw-r--r--tensorflow/compiler/tests/placeholder_test.py48
-rw-r--r--tensorflow/compiler/tests/ternary_ops_test.py61
-rw-r--r--tensorflow/compiler/tf2xla/BUILD1
-rw-r--r--tensorflow/compiler/tf2xla/functionalize_control_flow.cc2
-rw-r--r--tensorflow/compiler/tf2xla/kernels/identity_op.cc1
-rw-r--r--tensorflow/compiler/tf2xla/lib/triangular_solve.cc82
-rw-r--r--tensorflow/compiler/xla/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/client.cc5
-rw-r--r--tensorflow/compiler/xla/client/client.h3
-rw-r--r--tensorflow/compiler/xla/client/local_client.cc12
-rw-r--r--tensorflow/compiler/xla/client/xla_client/BUILD2
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_computation.cc11
-rw-r--r--tensorflow/compiler/xla/client/xla_client/xla_computation.h4
-rw-r--r--tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc2
-rw-r--r--tensorflow/compiler/xla/ptr_util.h22
-rw-r--r--tensorflow/compiler/xla/reference_util.cc6
-rw-r--r--tensorflow/compiler/xla/service/BUILD4
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.cc45
-rw-r--r--tensorflow/compiler/xla/service/allocation_tracker.h32
-rw-r--r--tensorflow/compiler/xla/service/backend.cc1
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.cc7
-rw-r--r--tensorflow/compiler/xla/service/computation_layout.h5
-rw-r--r--tensorflow/compiler/xla/service/cpu/BUILD50
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_compiler.cc258
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.cc14
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_executable.h8
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.cc7
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_options.h1
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc192
-rw-r--r--tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h80
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.cc21
-rw-r--r--tensorflow/compiler/xla/service/cpu/ir_emitter.h2
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc528
-rw-r--r--tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h137
-rw-r--r--tensorflow/compiler/xla/service/elemental_ir_emitter.cc18
-rw-r--r--tensorflow/compiler/xla/service/executable.cc14
-rw-r--r--tensorflow/compiler/xla/service/executable.h12
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.cc8
-rw-r--r--tensorflow/compiler/xla/service/gpu/gpu_executable.h4
-rw-r--r--tensorflow/compiler/xla/service/hlo.proto2
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.cc4
-rw-r--r--tensorflow/compiler/xla/service/hlo_cse.h11
-rw-r--r--tensorflow/compiler/xla/service/hlo_instruction.h8
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.cc63
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers.h69
-rw-r--r--tensorflow/compiler/xla/service/hlo_matchers_test.cc58
-rw-r--r--tensorflow/compiler/xla/service/hlo_runner.cc14
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding.cc39
-rw-r--r--tensorflow/compiler/xla/service/hlo_sharding_test.cc15
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.cc8
-rw-r--r--tensorflow/compiler/xla/service/interpreter/executable.h4
-rw-r--r--tensorflow/compiler/xla/service/interpreter/platform.cc4
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.cc328
-rw-r--r--tensorflow/compiler/xla/service/layout_assignment.h65
-rw-r--r--tensorflow/compiler/xla/service/service.cc19
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.cc4
-rw-r--r--tensorflow/compiler/xla/service/shaped_buffer.h6
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.cc15
-rw-r--r--tensorflow/compiler/xla/service/transfer_manager.h5
-rw-r--r--tensorflow/compiler/xla/service/tuple_simplifier.cc25
-rw-r--r--tensorflow/compiler/xla/shape_layout.h3
-rw-r--r--tensorflow/compiler/xla/shape_util.h1
-rw-r--r--tensorflow/compiler/xla/statusor.h26
-rw-r--r--tensorflow/compiler/xla/tests/BUILD99
-rw-r--r--tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc27
-rw-r--r--tensorflow/compiler/xla/tests/axpy_simple_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc12
-rw-r--r--tensorflow/compiler/xla/tests/bfloat16_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/binop_scaling_test.cc14
-rw-r--r--tensorflow/compiler/xla/tests/broadcast_simple_test.cc82
-rw-r--r--tensorflow/compiler/xla/tests/build_defs.bzl24
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.cc24
-rw-r--r--tensorflow/compiler/xla/tests/client_library_test_base.h9
-rw-r--r--tensorflow/compiler/xla/tests/client_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc38
-rw-r--r--tensorflow/compiler/xla/tests/convolution_variants_test.cc167
-rw-r--r--tensorflow/compiler/xla/tests/dynamic_ops_test.cc23
-rw-r--r--tensorflow/compiler/xla/tests/execution_profile_test.cc3
-rw-r--r--tensorflow/compiler/xla/tests/fusion_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/gather_operation_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/local_client_allocation_test.cc7
-rw-r--r--tensorflow/compiler/xla/tests/local_client_execute_test.cc74
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.cc10
-rw-r--r--tensorflow/compiler/xla/tests/local_client_test_base.h10
-rw-r--r--tensorflow/compiler/xla/tests/map_test.cc44
-rw-r--r--tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc18
-rw-r--r--tensorflow/compiler/xla/tests/multidimensional_slice_test.cc6
-rw-r--r--tensorflow/compiler/xla/tests/params_test.cc85
-rw-r--r--tensorflow/compiler/xla/tests/pred_test.cc49
-rw-r--r--tensorflow/compiler/xla/tests/prng_test.cc37
-rw-r--r--tensorflow/compiler/xla/tests/query_inferred_shape_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/reduce_test.cc104
-rw-r--r--tensorflow/compiler/xla/tests/reduce_window_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/replay_test.cc38
-rw-r--r--tensorflow/compiler/xla/tests/reshape_motion_test.cc5
-rw-r--r--tensorflow/compiler/xla/tests/reverse_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/test_macros.h8
-rw-r--r--tensorflow/compiler/xla/tests/test_utils_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/tuple_test.cc13
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc8
-rw-r--r--tensorflow/compiler/xla/tests/vector_ops_simple_test.cc80
-rw-r--r--tensorflow/compiler/xla/tests/while_test.cc4
-rw-r--r--tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc8
-rw-r--r--tensorflow/compiler/xla/tools/parser/hlo_parser.cc16
-rw-r--r--tensorflow/compiler/xla/window_util.cc5
-rw-r--r--tensorflow/contrib/BUILD3
-rw-r--r--tensorflow/contrib/__init__.py1
-rw-r--r--tensorflow/contrib/all_reduce/python/all_reduce.py28
-rw-r--r--tensorflow/contrib/autograph/README.md119
-rw-r--r--tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py32
-rw-r--r--tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py6
-rw-r--r--tensorflow/contrib/cmake/python_modules.txt7
-rw-r--r--tensorflow/contrib/cmake/tf_core_kernels.cmake1
-rw-r--r--tensorflow/contrib/coder/BUILD72
-rw-r--r--tensorflow/contrib/coder/__init__.py3
-rw-r--r--tensorflow/contrib/coder/python/layers/entropybottleneck.py697
-rw-r--r--tensorflow/contrib/coder/python/layers/entropybottleneck_test.py315
-rw-r--r--tensorflow/contrib/constrained_optimization/BUILD91
-rw-r--r--tensorflow/contrib/constrained_optimization/README.md345
-rw-r--r--tensorflow/contrib/constrained_optimization/__init__.py41
-rw-r--r--tensorflow/contrib/constrained_optimization/python/candidates.py319
-rw-r--r--tensorflow/contrib/constrained_optimization/python/candidates_test.py95
-rw-r--r--tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py123
-rw-r--r--tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py208
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py375
-rw-r--r--tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py136
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py595
-rw-r--r--tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py212
-rw-r--r--tensorflow/contrib/constrained_optimization/python/test_util.py58
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py32
-rw-r--r--tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py35
-rw-r--r--tensorflow/contrib/distribute/python/estimator_integration_test.py3
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/BUILD11
-rw-r--r--tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py43
-rw-r--r--tensorflow/contrib/estimator/BUILD1
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py7
-rw-r--r--tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py53
-rw-r--r--tensorflow/contrib/factorization/BUILD1
-rw-r--r--tensorflow/contrib/factorization/kernels/clustering_ops.cc1
-rw-r--r--tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc2
-rw-r--r--tensorflow/contrib/framework/BUILD3
-rw-r--r--tensorflow/contrib/framework/__init__.py18
-rw-r--r--tensorflow/contrib/framework/kernels/zero_initializer_op.cc71
-rw-r--r--tensorflow/contrib/framework/ops/variable_ops.cc29
-rw-r--r--tensorflow/contrib/framework/python/ops/critical_section_test.py21
-rw-r--r--tensorflow/contrib/framework/python/ops/variables.py8
-rw-r--r--tensorflow/contrib/framework/python/ops/variables_test.py26
-rw-r--r--tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc2
-rw-r--r--tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py30
-rw-r--r--tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc2
-rw-r--r--tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py15
-rw-r--r--tensorflow/contrib/kfac/python/ops/BUILD3
-rw-r--r--tensorflow/contrib/kfac/python/ops/fisher_factors.py109
-rw-r--r--tensorflow/contrib/learn/python/learn/estimators/estimator_test.py4
-rw-r--r--tensorflow/contrib/linalg/__init__.py4
-rw-r--r--tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py4
-rw-r--r--tensorflow/contrib/lite/context.h6
-rw-r--r--tensorflow/contrib/lite/interpreter.cc13
-rw-r--r--tensorflow/contrib/lite/interpreter.h12
-rw-r--r--tensorflow/contrib/lite/interpreter_test.cc8
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java16
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java48
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.pngbin3136 -> 3696 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.pngbin1915 -> 1847 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.pngbin4294 -> 5666 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.pngbin7279 -> 10264 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.pngbin0 -> 23476 bytes
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml59
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml88
-rw-r--r--tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml67
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h503
-rw-r--r--tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h6
-rw-r--r--tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h158
-rw-r--r--tensorflow/contrib/lite/kernels/test_util.h4
-rw-r--r--tensorflow/contrib/lite/model.cc6
-rw-r--r--tensorflow/contrib/lite/model.h1
-rw-r--r--tensorflow/contrib/lite/optional_debug_tools.cc2
-rw-r--r--tensorflow/contrib/lite/profiling/profile_buffer.h10
-rw-r--r--tensorflow/contrib/lite/profiling/profile_buffer_test.cc4
-rw-r--r--tensorflow/contrib/lite/profiling/profiler_test.cc2
-rw-r--r--tensorflow/contrib/lite/python/BUILD45
-rw-r--r--tensorflow/contrib/lite/python/convert.py187
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model.py387
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_test.py172
-rw-r--r--tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py106
-rw-r--r--tensorflow/contrib/lite/python/convert_test.py (renamed from tensorflow/contrib/lite/python/lite_test.py)41
-rw-r--r--tensorflow/contrib/lite/python/lite.py204
-rw-r--r--tensorflow/contrib/lite/python/lite_constants.py53
-rw-r--r--tensorflow/contrib/lite/schema/schema.fbs19
-rw-r--r--tensorflow/contrib/lite/string_util.cc45
-rw-r--r--tensorflow/contrib/lite/string_util.h8
-rw-r--r--tensorflow/contrib/lite/toco/dump_graphviz.cc13
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc27
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc58
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc7
-rw-r--r--tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc5
-rw-r--r--tensorflow/contrib/lite/toco/import_tensorflow.cc31
-rw-r--r--tensorflow/contrib/lite/toco/tflite/BUILD1
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types.cc33
-rw-r--r--tensorflow/contrib/lite/toco/tflite/types_test.cc7
-rw-r--r--tensorflow/contrib/lite/toco/tooling_util.cc15
-rw-r--r--tensorflow/contrib/lookup/lookup_ops_test.py6
-rw-r--r--tensorflow/contrib/meta_graph_transform/meta_graph_transform.py5
-rw-r--r--tensorflow/contrib/metrics/BUILD2
-rw-r--r--tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc2
-rw-r--r--tensorflow/contrib/mpi_collectives/mpi_ops.cc1236
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.cc56
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager.h36
-rw-r--r--tensorflow/contrib/nccl/kernels/nccl_manager_test.cc8
-rw-r--r--tensorflow/contrib/opt/BUILD20
-rw-r--r--tensorflow/contrib/opt/python/training/reg_adagrad_optimizer.py107
-rw-r--r--tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py343
-rw-r--r--tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py6
-rw-r--r--tensorflow/contrib/optimizer_v2/optimizer_v2.py15
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms.py24
-rw-r--r--tensorflow/contrib/quantize/python/fold_batch_norms_test.py79
-rw-r--r--tensorflow/contrib/quantize/python/quantize.py68
-rw-r--r--tensorflow/contrib/quantize/python/quantize_graph_test.py14
-rw-r--r--tensorflow/contrib/quantize/python/quantize_test.py57
-rw-r--r--tensorflow/contrib/rnn/kernels/blas_gemm.cc11
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/BUILD1
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py1
-rw-r--r--tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py60
-rw-r--r--tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc1
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly.py75
-rw-r--r--tensorflow/contrib/timeseries/examples/known_anomaly_test.py18
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model.py173
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py8
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators.py11
-rw-r--r--tensorflow/contrib/timeseries/python/timeseries/estimators_test.py48
-rw-r--r--tensorflow/contrib/tpu/BUILD1
-rw-r--r--tensorflow/contrib/tpu/python/tpu/keras_support.py391
-rw-r--r--tensorflow/contrib/tpu/python/tpu/tpu.py38
-rw-r--r--tensorflow/core/BUILD125
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt26
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt24
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt49
-rw-r--r--tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt40
-rw-r--r--tensorflow/core/api_def/base_api/api_def_PartitionedCall.pbtxt23
-rw-r--r--tensorflow/core/api_def/python_api/api_def_PartitionedCall.pbtxt1
-rw-r--r--tensorflow/core/common_runtime/direct_session.cc2
-rw-r--r--tensorflow/core/common_runtime/function.cc52
-rw-r--r--tensorflow/core/common_runtime/function_test.cc27
-rw-r--r--tensorflow/core/common_runtime/function_threadpool_test.cc14
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h8
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h2
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc18
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h4
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc22
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.cc71
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_device.h14
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc22
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr.h30
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc19
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_id_utils.h12
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_init.cc13
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.cc20
-rw-r--r--tensorflow/core/common_runtime/gpu/gpu_util.h9
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator.h4
-rw-r--r--tensorflow/core/common_runtime/gpu/pool_allocator_test.cc32
-rw-r--r--tensorflow/core/common_runtime/gpu/process_state.cc6
-rw-r--r--tensorflow/core/common_runtime/gpu_device_context.h35
-rw-r--r--tensorflow/core/common_runtime/kernel_benchmark_testlib.cc1
-rw-r--r--tensorflow/core/common_runtime/local_device.cc1
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.cc21
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime.h3
-rw-r--r--tensorflow/core/common_runtime/process_function_library_runtime_test.cc10
-rw-r--r--tensorflow/core/common_runtime/process_util.cc1
-rw-r--r--tensorflow/core/distributed_runtime/master_session.cc14
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc135
-rw-r--r--tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h18
-rw-r--r--tensorflow/core/framework/bfloat16.h1
-rw-r--r--tensorflow/core/framework/collective.h8
-rw-r--r--tensorflow/core/framework/graph_to_functiondef.cc (renamed from tensorflow/compiler/jit/graph_to_functiondef.cc)4
-rw-r--r--tensorflow/core/framework/graph_to_functiondef.h (renamed from tensorflow/compiler/jit/graph_to_functiondef.h)9
-rw-r--r--tensorflow/core/framework/graph_to_functiondef_test.cc (renamed from tensorflow/compiler/jit/graph_to_functiondef_test.cc)2
-rw-r--r--tensorflow/core/framework/remote_fused_graph_execute_info.proto8
-rw-r--r--tensorflow/core/framework/resource_var.h58
-rw-r--r--tensorflow/core/grappler/clusters/utils.cc1
-rw-r--r--tensorflow/core/grappler/costs/BUILD2
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.cc576
-rw-r--r--tensorflow/core/grappler/costs/graph_properties.h28
-rw-r--r--tensorflow/core/grappler/costs/graph_properties_test.cc6
-rw-r--r--tensorflow/core/grappler/costs/utils.cc2
-rw-r--r--tensorflow/core/grappler/costs/virtual_scheduler.cc16
-rw-r--r--tensorflow/core/grappler/devices.cc13
-rw-r--r--tensorflow/core/grappler/graph_view.cc49
-rw-r--r--tensorflow/core/grappler/graph_view.h36
-rw-r--r--tensorflow/core/grappler/op_types.cc113
-rw-r--r--tensorflow/core/grappler/op_types.h2
-rw-r--r--tensorflow/core/grappler/optimizers/BUILD5
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc190
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer.h10
-rw-r--r--tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc102
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding.cc1
-rw-r--r--tensorflow/core/grappler/optimizers/constant_folding_test.cc81
-rw-r--r--tensorflow/core/grappler/optimizers/function_optimizer.cc4
-rw-r--r--tensorflow/core/grappler/optimizers/loop_optimizer.cc70
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.cc97
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer.h1
-rw-r--r--tensorflow/core/grappler/optimizers/meta_optimizer_test.cc172
-rw-r--r--tensorflow/core/kernels/BUILD14
-rw-r--r--tensorflow/core/kernels/avgpooling_op.cc24
-rw-r--r--tensorflow/core/kernels/batch_matmul_op_impl.h44
-rw-r--r--tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h1
-rw-r--r--tensorflow/core/kernels/batching_util/shared_batch_scheduler.h1
-rw-r--r--tensorflow/core/kernels/bias_op.cc4
-rw-r--r--tensorflow/core/kernels/cast_op.h2
-rw-r--r--tensorflow/core/kernels/check_numerics_op.cc6
-rw-r--r--tensorflow/core/kernels/conv_grad_filter_ops.cc32
-rw-r--r--tensorflow/core/kernels/conv_grad_input_ops.cc28
-rw-r--r--tensorflow/core/kernels/conv_grad_ops_3d.cc62
-rw-r--r--tensorflow/core/kernels/conv_ops.cc24
-rw-r--r--tensorflow/core/kernels/conv_ops_3d.cc26
-rw-r--r--tensorflow/core/kernels/conv_ops_gpu.h24
-rw-r--r--tensorflow/core/kernels/crop_and_resize_op.cc8
-rw-r--r--tensorflow/core/kernels/cuda_device_array.h2
-rw-r--r--tensorflow/core/kernels/cuda_solvers.cc6
-rw-r--r--tensorflow/core/kernels/cuda_solvers.h2
-rw-r--r--tensorflow/core/kernels/cudnn_pooling_gpu.cc42
-rw-r--r--tensorflow/core/kernels/cudnn_pooling_gpu.h4
-rw-r--r--tensorflow/core/kernels/cudnn_rnn_ops.cc501
-rw-r--r--tensorflow/core/kernels/decode_raw_op.cc2
-rw-r--r--tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc3
-rw-r--r--tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc4
-rw-r--r--tensorflow/core/kernels/fft_ops.cc33
-rw-r--r--tensorflow/core/kernels/fused_batch_norm_op.cc22
-rw-r--r--tensorflow/core/kernels/fuzzing/BUILD2
-rw-r--r--tensorflow/core/kernels/fuzzing/decode_wav_fuzz.cc30
-rw-r--r--tensorflow/core/kernels/gpu_utils.h8
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc1
-rw-r--r--tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc1
-rw-r--r--tensorflow/core/kernels/i_remote_fused_graph_executor.h4
-rw-r--r--tensorflow/core/kernels/initializable_lookup_table.h2
-rw-r--r--tensorflow/core/kernels/lookup_table_op.h24
-rw-r--r--tensorflow/core/kernels/lrn_op.cc12
-rw-r--r--tensorflow/core/kernels/matmul_op.cc51
-rw-r--r--tensorflow/core/kernels/matrix_triangular_solve_op.cc31
-rw-r--r--tensorflow/core/kernels/maxpooling_op.cc20
-rw-r--r--tensorflow/core/kernels/mkl_input_conversion_op.cc1
-rw-r--r--tensorflow/core/kernels/mkl_tfconv_op.h1
-rw-r--r--tensorflow/core/kernels/partitioned_function_ops.cc279
-rw-r--r--tensorflow/core/kernels/pooling_ops_3d.cc23
-rw-r--r--tensorflow/core/kernels/pooling_ops_common.cc46
-rw-r--r--tensorflow/core/kernels/pooling_ops_common_gpu.h4
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.cc46
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils.h28
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc1
-rw-r--r--tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc1
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.cc4
-rw-r--r--tensorflow/core/kernels/segment_reduction_ops.h8
-rw-r--r--tensorflow/core/kernels/sparse_matmul_op.h1
-rw-r--r--tensorflow/core/kernels/string_split_op.cc2
-rw-r--r--tensorflow/core/kernels/summary_interface.h5
-rw-r--r--tensorflow/core/kernels/summary_kernels.cc1
-rw-r--r--tensorflow/core/kernels/training_ops.cc23
-rw-r--r--tensorflow/core/kernels/training_ops.h2
-rw-r--r--tensorflow/core/kernels/training_ops_gpu.cu.cc6
-rw-r--r--tensorflow/core/kernels/variable_ops.h34
-rw-r--r--tensorflow/core/kernels/where_op.cc5
-rw-r--r--tensorflow/core/lib/bfloat16/bfloat16.h3
-rw-r--r--tensorflow/core/lib/core/coding.cc2
-rw-r--r--tensorflow/core/lib/core/raw_coding.h2
-rw-r--r--tensorflow/core/lib/gtl/inlined_vector.h2
-rw-r--r--tensorflow/core/lib/png/png_io.cc2
-rw-r--r--tensorflow/core/lib/wav/wav_io.cc8
-rw-r--r--tensorflow/core/ops/compat/ops_history.v1.pbtxt758
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops.cc79
-rw-r--r--tensorflow/core/ops/cudnn_rnn_ops_test.cc35
-rw-r--r--tensorflow/core/ops/functional_ops.cc9
-rw-r--r--tensorflow/core/ops/ops.pbtxt510
-rw-r--r--tensorflow/core/ops/training_ops.cc4
-rw-r--r--tensorflow/core/platform/byte_order.h37
-rw-r--r--tensorflow/core/platform/cloud/expiring_lru_cache.h18
-rw-r--r--tensorflow/core/platform/cloud/expiring_lru_cache_test.cc17
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.cc19
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system.h3
-rw-r--r--tensorflow/core/platform/cloud/gcs_file_system_test.cc50
-rw-r--r--tensorflow/core/platform/cpu_feature_guard.cc1
-rw-r--r--tensorflow/core/platform/cpu_info.h7
-rw-r--r--tensorflow/core/platform/default/build_config.bzl49
-rw-r--r--tensorflow/core/platform/default/from_stream_executor_status.h35
-rw-r--r--tensorflow/core/platform/default/gpu/cupti_wrapper.cc42
-rw-r--r--tensorflow/core/platform/denormal.cc3
-rw-r--r--tensorflow/core/platform/stream_executor.h2
-rw-r--r--tensorflow/core/platform/stream_executor_no_cuda.h2
-rw-r--r--tensorflow/core/platform/types.h4
-rw-r--r--tensorflow/core/platform/windows/cpu_info.h9
-rw-r--r--tensorflow/core/protobuf/eager_service.proto158
-rw-r--r--tensorflow/core/util/rpc/call_container.h165
-rw-r--r--tensorflow/core/util/rpc/rpc_factory.h5
-rw-r--r--tensorflow/core/util/stream_executor_util.h16
-rw-r--r--tensorflow/core/util/use_cudnn.cc46
-rw-r--r--tensorflow/core/util/use_cudnn.h13
-rw-r--r--tensorflow/docs_src/get_started/feature_columns.md2
-rw-r--r--tensorflow/docs_src/install/install_java.md39
-rw-r--r--tensorflow/docs_src/install/install_sources.md9
-rw-r--r--tensorflow/docs_src/performance/xla/operation_semantics.md60
-rw-r--r--tensorflow/go/op/wrappers.go2738
-rw-r--r--tensorflow/java/maven/libtensorflow/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni/pom.xml2
-rw-r--r--tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml2
-rw-r--r--tensorflow/java/maven/pom.xml2
-rw-r--r--tensorflow/java/maven/proto/pom.xml2
-rw-r--r--tensorflow/java/maven/tensorflow/pom.xml2
-rw-r--r--tensorflow/python/BUILD1
-rw-r--r--tensorflow/python/__init__.py167
-rw-r--r--tensorflow/python/client/tf_session.i2
-rw-r--r--tensorflow/python/data/__init__.py3
-rw-r--r--tensorflow/python/data/util/nest.py14
-rw-r--r--tensorflow/python/eager/function.py23
-rw-r--r--tensorflow/python/estimator/estimator.py283
-rw-r--r--tensorflow/python/estimator/estimator_lib.py41
-rw-r--r--tensorflow/python/estimator/estimator_test.py5
-rw-r--r--tensorflow/python/estimator/export/export_lib.py13
-rw-r--r--tensorflow/python/estimator/inputs/inputs.py8
-rw-r--r--tensorflow/python/feature_column/feature_column_lib.py21
-rw-r--r--tensorflow/python/framework/errors.py46
-rw-r--r--tensorflow/python/framework/function.py67
-rw-r--r--tensorflow/python/framework/function_test.py60
-rw-r--r--tensorflow/python/framework/graph_util.py11
-rw-r--r--tensorflow/python/framework/importer.py9
-rw-r--r--tensorflow/python/framework/meta_graph.py5
-rw-r--r--tensorflow/python/framework/ops.py6
-rw-r--r--tensorflow/python/framework/test_util.py209
-rw-r--r--tensorflow/python/framework/test_util_test.py193
-rwxr-xr-xtensorflow/python/keras/BUILD1
-rw-r--r--tensorflow/python/keras/_impl/keras/applications/mobilenet.py1
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/base_layer.py142
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/network.py251
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/saving_test.py214
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/sequential.py32
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/topology_test.py8
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training.py3
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager.py2
-rw-r--r--tensorflow/python/keras/_impl/keras/engine/training_eager_test.py22
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/advanced_activations.py14
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional.py4
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/embeddings.py6
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/local.py10
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/merge.py16
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/noise.py8
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/recurrent.py26
-rw-r--r--tensorflow/python/keras/_impl/keras/layers/wrappers.py18
-rw-r--r--tensorflow/python/keras/_impl/keras/metrics_test.py129
-rw-r--r--tensorflow/python/keras/_impl/keras/model_subclassing_test.py29
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/generic_utils.py30
-rw-r--r--tensorflow/python/keras/_impl/keras/utils/tf_utils.py80
-rw-r--r--tensorflow/python/kernel_tests/BUILD8
-rw-r--r--tensorflow/python/kernel_tests/control_flow_ops_py_test.py9
-rw-r--r--tensorflow/python/kernel_tests/functional_ops_test.py146
-rw-r--r--tensorflow/python/kernel_tests/linalg/BUILD20
-rw-r--r--tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py700
-rw-r--r--tensorflow/python/layers/layers.py6
-rw-r--r--tensorflow/python/lib/io/python_io.py5
-rw-r--r--tensorflow/python/ops/array_ops.py6
-rw-r--r--tensorflow/python/ops/bitwise_ops.py3
-rw-r--r--tensorflow/python/ops/control_flow_ops.py16
-rw-r--r--tensorflow/python/ops/cudnn_rnn_grad.py28
-rw-r--r--tensorflow/python/ops/distributions/bijector.py5
-rw-r--r--tensorflow/python/ops/distributions/distributions.py26
-rw-r--r--tensorflow/python/ops/distributions/transformed_distribution.py2
-rw-r--r--tensorflow/python/ops/embedding_ops.py4
-rw-r--r--tensorflow/python/ops/functional_ops.py105
-rw-r--r--tensorflow/python/ops/gradients.py10
-rw-r--r--tensorflow/python/ops/image_ops.py10
-rw-r--r--tensorflow/python/ops/linalg/linalg.py1
-rw-r--r--tensorflow/python/ops/linalg/linear_operator_circulant.py1074
-rw-r--r--tensorflow/python/ops/lookup_ops.py20
-rw-r--r--tensorflow/python/ops/losses/losses.py9
-rw-r--r--tensorflow/python/ops/manip_ops.py5
-rw-r--r--tensorflow/python/ops/math_ops.py4
-rw-r--r--tensorflow/python/ops/metrics.py5
-rw-r--r--tensorflow/python/ops/nn.py20
-rw-r--r--tensorflow/python/ops/resource_variable_ops.py9
-rw-r--r--tensorflow/python/ops/rnn_cell.py5
-rw-r--r--tensorflow/python/ops/sdca_ops.py5
-rw-r--r--tensorflow/python/ops/sets.py5
-rw-r--r--tensorflow/python/ops/spectral_ops.py3
-rw-r--r--tensorflow/python/ops/standard_ops.py210
-rw-r--r--tensorflow/python/ops/summary_ops_v2.py42
-rw-r--r--tensorflow/python/platform/app.py9
-rw-r--r--tensorflow/python/platform/gfile.py22
-rw-r--r--tensorflow/python/platform/resource_loader.py5
-rw-r--r--tensorflow/python/platform/sysconfig.py4
-rw-r--r--tensorflow/python/platform/test.py11
-rw-r--r--tensorflow/python/platform/tf_logging.py28
-rw-r--r--tensorflow/python/profiler/profiler.py9
-rw-r--r--tensorflow/python/pywrap_tfe.i2
-rw-r--r--tensorflow/python/saved_model/builder.py7
-rw-r--r--tensorflow/python/saved_model/constants.py15
-rw-r--r--tensorflow/python/saved_model/loader.py8
-rw-r--r--tensorflow/python/saved_model/main_op.py7
-rw-r--r--tensorflow/python/saved_model/saved_model.py15
-rw-r--r--tensorflow/python/saved_model/signature_constants.py17
-rw-r--r--tensorflow/python/saved_model/tag_constants.py9
-rw-r--r--tensorflow/python/saved_model/utils.py4
-rw-r--r--tensorflow/python/summary/summary.py8
-rw-r--r--tensorflow/python/tools/BUILD2
-rw-r--r--tensorflow/python/training/basic_session_run_hooks.py36
-rw-r--r--tensorflow/python/training/basic_session_run_hooks_test.py38
-rw-r--r--tensorflow/python/training/checkpointable.py14
-rw-r--r--tensorflow/python/training/checkpointable_utils.py110
-rw-r--r--tensorflow/python/training/checkpointable_utils_test.py96
-rw-r--r--tensorflow/python/training/optimizer.py15
-rw-r--r--tensorflow/python/training/queue_runner.py10
-rw-r--r--tensorflow/python/training/training.py46
-rw-r--r--tensorflow/python/util/compat.py11
-rw-r--r--tensorflow/python/util/nest.py19
-rw-r--r--tensorflow/stream_executor/BUILD2
-rw-r--r--tensorflow/stream_executor/blas.h81
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.cc81
-rw-r--r--tensorflow/stream_executor/cuda/cuda_blas.h14
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.cc78
-rw-r--r--tensorflow/stream_executor/cuda/cuda_dnn.h2
-rw-r--r--tensorflow/stream_executor/cuda/cuda_platform.cc4
-rw-r--r--tensorflow/stream_executor/dnn.cc5
-rw-r--r--tensorflow/stream_executor/dnn.h3
-rw-r--r--tensorflow/stream_executor/host/host_platform.cc4
-rw-r--r--tensorflow/stream_executor/host_or_device_scalar.h56
-rw-r--r--tensorflow/stream_executor/lib/ptr_util.h42
-rw-r--r--tensorflow/stream_executor/multi_platform_manager.h2
-rw-r--r--tensorflow/stream_executor/stream.cc114
-rw-r--r--tensorflow/stream_executor/stream.h62
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.cc7
-rw-r--r--tensorflow/stream_executor/stream_executor_pimpl.h2
-rw-r--r--tensorflow/tensorflow.bzl33
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt2
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt155
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt155
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt14
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt155
-rw-r--r--tensorflow/tools/api/golden/tensorflow.linalg.pbtxt12
-rwxr-xr-xtensorflow/tools/ci_build/ci_sanity.sh1
-rw-r--r--tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh7
-rw-r--r--tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh30
-rwxr-xr-xtensorflow/tools/git/gen_git_source.py13
-rw-r--r--tensorflow/tools/pip_package/BUILD1
-rw-r--r--tensorflow/tools/proto_text/BUILD7
-rw-r--r--tensorflow/tools/proto_text/gen_proto_text_functions.cc6
571 files changed, 21075 insertions, 8669 deletions
diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD
index fae922ea3b..1432119162 100644
--- a/tensorflow/c/eager/BUILD
+++ b/tensorflow/c/eager/BUILD
@@ -40,8 +40,6 @@ tf_cuda_library(
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- # TODO(b/74620627): move this here
- "//tensorflow/python:cpp_shape_inference_proto_cc",
],
}) + select({
"//tensorflow:with_xla_support": [
diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc
index 975bde7c7f..3bf071f3ab 100644
--- a/tensorflow/c/eager/c_api.cc
+++ b/tensorflow/c/eager/c_api.cc
@@ -48,7 +48,6 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/public/version.h"
-#include "tensorflow/python/framework/cpp_shape_inference.pb.h"
using tensorflow::int64;
using tensorflow::string;
@@ -503,62 +502,6 @@ void TFE_ContextExportRunMetadata(TFE_Context* ctx, TF_Buffer* buf,
ctx->context.RunMetadataProto()->Clear();
}
-void TFE_GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- TF_Buffer* output_proto,
- TF_Status* status) {
- tensorflow::Node* node = &output.oper->node;
- tensorflow::CppShapeInferenceResult::HandleData handle_data;
- handle_data.set_is_set(true);
- {
- tensorflow::mutex_lock l(graph->mu);
- tensorflow::shape_inference::InferenceContext* ic =
- graph->refiner.GetContext(node);
- CHECK(ic != nullptr);
- CHECK_LT(output.index, ic->num_outputs());
- const auto* shapes_and_types =
- ic->output_handle_shapes_and_types(output.index);
- if (shapes_and_types == nullptr) {
- output_proto->data = nullptr;
- output_proto->length = 0;
- output_proto->data_deallocator = nullptr;
- return;
- }
-
- for (const auto& p : *shapes_and_types) {
- auto* out_shape_and_type = handle_data.add_shape_and_type();
- ic->ShapeHandleToProto(p.shape, out_shape_and_type->mutable_shape());
- out_shape_and_type->set_dtype(p.dtype);
- }
- }
- status->status = MessageToBuffer(handle_data, output_proto);
-}
-
-void TFE_SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
- const void* proto, size_t proto_len,
- TF_Status* status) {
- tensorflow::CppShapeInferenceResult::HandleData handle_data;
- if (!handle_data.ParseFromArray(proto, proto_len)) {
- status->status = tensorflow::errors::InvalidArgument(
- "Couldn't deserialize HandleData proto");
- return;
- }
- DCHECK(handle_data.is_set());
-
- tensorflow::mutex_lock l(graph->mu);
- tensorflow::shape_inference::InferenceContext* ic =
- graph->refiner.GetContext(&output.oper->node);
-
- std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
- for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
- tensorflow::shape_inference::ShapeHandle shape;
- status->status =
- ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
- if (status->status.ok()) return;
- shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
- }
- ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
-}
-
namespace {
TFE_Op* GetFunc(TFE_Context* ctx, const tensorflow::NameAttrList& func,
TF_Status* status) {
diff --git a/tensorflow/c/eager/c_api.h b/tensorflow/c/eager/c_api.h
index ba77f3cd07..c06ce84a8c 100644
--- a/tensorflow/c/eager/c_api.h
+++ b/tensorflow/c/eager/c_api.h
@@ -329,20 +329,6 @@ TF_CAPI_EXPORT extern void TFE_ContextExportRunMetadata(TFE_Context* ctx,
TF_Buffer* buf,
TF_Status* status);
-// Returns the serialized CppShapeInferenceResult::HandleData proto for
-// `output` if its a resource tensor, or otherwise returns an empty buffer.
-TF_CAPI_EXPORT extern void TFE_GetResourceHandleShapeAndType(
- TF_Graph* graph, TF_Output output, TF_Buffer* output_proto,
- TF_Status* status);
-
-// Sets `output` based on `proto`, which should be a serialized
-// CppShapeInferenceResult::HandleData proto.
-TF_CAPI_EXPORT extern void TFE_SetResourceHandleShapeAndType(TF_Graph* graph,
- TF_Output output,
- const void* proto,
- size_t proto_len,
- TF_Status* status);
-
#ifdef __cplusplus
} /* end extern "C" */
#endif
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index 93155998b8..e18fdf6c57 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -110,7 +110,7 @@ void ExtendSession(TF_Session* session, TF_Status* status) {
session->extend_before_run = false;
}
-std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
+std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
Node* node = &output.oper->node;
CppShapeInferenceResult::HandleData handle_data;
handle_data.set_is_set(true);
@@ -135,4 +135,30 @@ std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output) {
return result;
}
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status) {
+ tensorflow::CppShapeInferenceResult::HandleData handle_data;
+ if (!handle_data.ParseFromArray(proto, proto_len)) {
+ status->status = tensorflow::errors::InvalidArgument(
+ "Couldn't deserialize HandleData proto");
+ return;
+ }
+ DCHECK(handle_data.is_set());
+
+ tensorflow::mutex_lock l(graph->mu);
+ tensorflow::shape_inference::InferenceContext* ic =
+ graph->refiner.GetContext(&output.oper->node);
+
+ std::vector<tensorflow::shape_inference::ShapeAndType> shapes_and_types;
+ for (const auto& shape_and_type_proto : handle_data.shape_and_type()) {
+ tensorflow::shape_inference::ShapeHandle shape;
+ status->status =
+ ic->MakeShapeFromShapeProto(shape_and_type_proto.shape(), &shape);
+ if (status->status.ok()) return;
+ shapes_and_types.emplace_back(shape, shape_and_type_proto.dtype());
+ }
+ ic->set_output_handle_shapes_and_types(output.index, shapes_and_types);
+}
+
} // namespace tensorflow
diff --git a/tensorflow/c/python_api.h b/tensorflow/c/python_api.h
index 2d4c8cd9ed..4bcb5bde62 100644
--- a/tensorflow/c/python_api.h
+++ b/tensorflow/c/python_api.h
@@ -55,9 +55,15 @@ void ExtendSession(TF_Session* session, TF_Status* status);
// Returns the serialized CppShapeInferenceResult::HandleData proto for
// `output` if its a resource tensor, or otherwise returns the empty string.
-// TODO(b/74620627): remove when _USE_C_SHAPES is removed
-std::string ResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
-
+std::string GetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output);
+
+// Sets `output` based on `proto`, which should be a serialized
+// CppShapeInferenceResult::HandleData proto.
+// NOTE(skyewm): `proto` is passed a void*/size_t pair instead of a std::string
+// because I couldn't get SWIG to work otherwise.
+void SetResourceHandleShapeAndType(TF_Graph* graph, TF_Output output,
+ const void* proto, size_t proto_len,
+ TF_Status* status);
} // namespace tensorflow
#endif // TENSORFLOW_C_PYTHON_API_H_
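For orientation, here is a minimal sketch of how the two relocated helpers could be used together. Only the two signatures come from python_api.h above; the CopyHandleData wrapper, the graphs, and the outputs are hypothetical, and error handling is kept to the bare minimum.

#include <string>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/python_api.h"

// Copies resource-handle shape/type metadata from `src` in `src_graph` onto
// `dst` in `dst_graph`. Returns false if `src` carries no handle data or the
// proto could not be applied.
bool CopyHandleData(TF_Graph* src_graph, TF_Output src,
                    TF_Graph* dst_graph, TF_Output dst) {
  // Serialized CppShapeInferenceResult::HandleData, or "" for non-resources.
  std::string proto =
      tensorflow::GetResourceHandleShapeAndType(src_graph, src);
  if (proto.empty()) return false;

  TF_Status* status = TF_NewStatus();
  // The void*/size_t pair mirrors the SWIG-friendly signature added above.
  tensorflow::SetResourceHandleShapeAndType(dst_graph, dst, proto.data(),
                                            proto.size(), status);
  const bool ok = TF_GetCode(status) == TF_OK;
  TF_DeleteStatus(status);
  return ok;
}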
diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc
index 7c83387881..e17a7c4bf6 100644
--- a/tensorflow/compiler/aot/compile.cc
+++ b/tensorflow/compiler/aot/compile.cc
@@ -88,9 +88,8 @@ Status CompileGraph(const GraphDef& graph_def, const tf2xla::Config& config,
// Converts the graph into an XLA computation, and compiles the
// computation.
// TODO(toddw): Should we let the user pick the XLA cpu vs. gpu client?
- namespace gpu = perftools::gputools;
- gpu::Platform* cpu_platform =
- gpu::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
+ se::Platform* cpu_platform =
+ se::MultiPlatformManager::PlatformWithName("Host").ValueOrDie();
xla::CompileOnlyClient* client =
xla::ClientLibrary::GetOrCreateCompileOnlyClient(cpu_platform)
.ValueOrDie();
diff --git a/tensorflow/compiler/aot/test.cc b/tensorflow/compiler/aot/test.cc
index 47ef5f82cb..6b098049cb 100644
--- a/tensorflow/compiler/aot/test.cc
+++ b/tensorflow/compiler/aot/test.cc
@@ -35,6 +35,7 @@ limitations under the License.
// clang-format on
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD
index 50fa95c4f3..af2965bba5 100644
--- a/tensorflow/compiler/jit/BUILD
+++ b/tensorflow/compiler/jit/BUILD
@@ -180,6 +180,7 @@ cc_library(
"//tensorflow/core/kernels:no_op",
"//tensorflow/core/kernels:sendrecv_ops",
"//tensorflow/core/kernels:variable_ops",
+ "@com_google_absl//absl/memory",
],
)
@@ -257,19 +258,6 @@ cc_library(
)
cc_library(
- name = "graph_to_functiondef",
- srcs = ["graph_to_functiondef.cc"],
- hdrs = ["graph_to_functiondef.h"],
- visibility = [":friends"],
- deps = [
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework",
- "//tensorflow/core:lib",
- "//tensorflow/core:protos_all_cc",
- ],
-)
-
-cc_library(
name = "create_xla_launch_op",
srcs = [
"create_xla_launch_op.cc",
@@ -299,7 +287,6 @@ cc_library(
],
deps = [
":common",
- ":graph_to_functiondef",
":shape_inference_helpers",
":union_find",
"//tensorflow/compiler/jit/graphcycles",
@@ -347,28 +334,6 @@ tf_cc_test(
)
tf_cc_test(
- name = "graph_to_functiondef_test",
- size = "small",
- srcs = [
- "graph_to_functiondef_test.cc",
- ],
- deps = [
- ":graph_to_functiondef",
- "//tensorflow/cc:cc_ops",
- "//tensorflow/cc:cc_ops_internal",
- "//tensorflow/cc:function_ops",
- "//tensorflow/cc:ops",
- "//tensorflow/compiler/tf2xla:xla_compiler",
- "//tensorflow/compiler/tf2xla/kernels:xla_ops",
- "//tensorflow/core:core_cpu",
- "//tensorflow/core:framework_internal",
- "//tensorflow/core:test",
- "//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
- ],
-)
-
-tf_cc_test(
name = "compilation_passes_test",
size = "small",
srcs = [
@@ -378,7 +343,6 @@ tf_cc_test(
deps = [
":common",
":compilation_passes",
- ":graph_to_functiondef",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
"//tensorflow/cc:function_ops",
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
index 7507e193b5..f06debaf31 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include <unordered_map>
#include <vector>
-#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/graphcycles/graphcycles.h"
#include "tensorflow/compiler/jit/legacy_flags/encapsulate_subgraphs_pass_flags.h"
#include "tensorflow/compiler/jit/mark_for_compilation_pass.h"
@@ -35,6 +34,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/graph/algorithm.h"
diff --git a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
index 3502d1bb45..5ec24d39a2 100644
--- a/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
+++ b/tensorflow/compiler/jit/encapsulate_subgraphs_pass_test.cc
@@ -20,8 +20,8 @@ limitations under the License.
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/core/framework/function_testlib.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/core/status_test_util.h"
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.cc b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
index f48941fce3..049d170fa4 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.cc
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.cc
@@ -37,8 +37,6 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor_no_cuda.h"
#include "tensorflow/core/util/stream_executor_util.h"
-namespace gpu = perftools::gputools;
-
namespace tensorflow {
XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
@@ -51,9 +49,9 @@ XlaLocalLaunchOp::XlaLocalLaunchOp(OpKernelConstruction* ctx)
num_constant_args_ = constant_types.size();
OP_REQUIRES_OK(ctx, ctx->GetAttr("Nresources", &num_resource_args_));
if (device_type_ == DeviceType(DEVICE_CPU)) {
- platform_id_ = gpu::host::kHostPlatformId;
+ platform_id_ = se::host::kHostPlatformId;
} else if (device_type_ == DeviceType(DEVICE_GPU)) {
- platform_id_ = gpu::cuda::kCudaPlatformId;
+ platform_id_ = se::cuda::kCudaPlatformId;
} else {
platform_id_ = nullptr;
}
@@ -69,9 +67,9 @@ Status XlaLocalLaunchOp::BuildCompilationCache(OpKernelContext* ctx,
return Status::OK();
}
- auto platform = gpu::MultiPlatformManager::PlatformWithId(platform_id_);
+ auto platform = se::MultiPlatformManager::PlatformWithId(platform_id_);
if (!platform.ok()) {
- return StreamExecutorUtil::ConvertStatus(platform.status());
+ return platform.status();
}
xla::LocalClientOptions client_options;
client_options.set_platform(platform.ValueOrDie());
@@ -100,7 +98,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
ResourceMgr* rm = ctx->resource_manager();
OP_REQUIRES(ctx, rm, errors::Internal("No resource manager."));
- gpu::Stream* stream =
+ se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
XlaCompilationCache* cache;
@@ -153,7 +151,7 @@ void XlaLocalLaunchOp::Compute(OpKernelContext* ctx) {
options.device_type = &cache->device_type();
options.flib_def = ctx->function_library()->GetFunctionLibraryDefinition();
options.graph_def_version = ctx->function_library()->graph_def_version();
- options.allow_cpu_custom_calls = (platform_id_ == gpu::host::kHostPlatformId);
+ options.allow_cpu_custom_calls = (platform_id_ == se::host::kHostPlatformId);
options.device_allocator = xla_allocator;
// TODO(b/77671268): We don't set variable_representation_shape_fn here. This
// is restricted to Variables, but we need something like this to apply to
diff --git a/tensorflow/compiler/jit/kernels/xla_launch_op.h b/tensorflow/compiler/jit/kernels/xla_launch_op.h
index c6cc0986af..8f8e646f0f 100644
--- a/tensorflow/compiler/jit/kernels/xla_launch_op.h
+++ b/tensorflow/compiler/jit/kernels/xla_launch_op.h
@@ -53,7 +53,7 @@ class XlaLocalLaunchOp : public OpKernel {
// Number of resource variable arguments.
int num_resource_args_;
- perftools::gputools::Platform::Id platform_id_;
+ se::Platform::Id platform_id_;
TF_DISALLOW_COPY_AND_ASSIGN(XlaLocalLaunchOp);
};
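As a side note, the platform lookup that xla_launch_op.cc now performs can be summarized by the sketch below. It is not the op's exact code: the helper name is made up, the host/cuda platform-id headers are assumed to be included, and CPU vs. GPU is the only distinction handled.

// Map a TF device type to a StreamExecutor platform. With the se:: migration,
// the StatusOr's status is already a tensorflow::Status, so the old
// StreamExecutorUtil::ConvertStatus step is no longer needed.
Status PlatformForDeviceType(const DeviceType& device_type,
                             se::Platform** platform) {
  se::Platform::Id platform_id = (device_type == DeviceType(DEVICE_CPU))
                                     ? se::host::kHostPlatformId
                                     : se::cuda::kCudaPlatformId;
  auto platform_or = se::MultiPlatformManager::PlatformWithId(platform_id);
  if (!platform_or.ok()) {
    return platform_or.status();
  }
  *platform = platform_or.ValueOrDie();
  return Status::OK();
}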
diff --git a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
index 6c2782e28e..60458f6f33 100644
--- a/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
+++ b/tensorflow/compiler/jit/xla_compile_on_demand_op.cc
@@ -58,7 +58,7 @@ Status XlaCompileOnDemandOp::Run(OpKernelContext* ctx,
launch_context.PopulateInputs(ctx, result, variables);
- perftools::gputools::Stream* stream =
+ se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
TF_RET_CHECK(stream);
diff --git a/tensorflow/compiler/jit/xla_device.cc b/tensorflow/compiler/jit/xla_device.cc
index 12f471735f..c814b7eb02 100644
--- a/tensorflow/compiler/jit/xla_device.cc
+++ b/tensorflow/compiler/jit/xla_device.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <stdlib.h>
#include <unordered_set>
+#include "absl/memory/memory.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/xla_compile_on_demand_op.h"
#include "tensorflow/compiler/jit/xla_device_context.h"
@@ -50,8 +51,6 @@ limitations under the License.
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/stream_executor_util.h"
-namespace se = ::perftools::gputools;
-
namespace tensorflow {
// Caches a XlaDeviceAllocator per <backend, device ordinal> pair. A
@@ -121,7 +120,7 @@ XlaDeviceAllocator* XlaDeviceAllocatorState::GetOrCreateXlaDeviceAllocator(
auto platform = se::MultiPlatformManager::PlatformWithName(platform_name);
if (!platform.ok()) {
- return StreamExecutorUtil::ConvertStatus(platform.status());
+ return platform.status();
}
const DeviceAttributes attrs = Device::BuildDeviceAttributes(
@@ -181,9 +180,15 @@ XlaDevice::XlaDevice(const SessionOptions& options,
jit_device_name_(jit_device_name),
xla_allocator_(nullptr),
platform_(platform),
- transfer_as_literal_(transfer_as_literal) {}
+ transfer_as_literal_(transfer_as_literal) {
+ VLOG(1) << "Created XLA device " << jit_device_name;
+}
-XlaDevice::~XlaDevice() {}
+XlaDevice::~XlaDevice() {
+ if (gpu_device_info_ != nullptr) {
+ gpu_device_info_->default_context->Unref();
+ }
+}
xla::LocalClient* XlaDevice::client() const {
// We lazily create the client because the platform commits to the
@@ -191,9 +196,8 @@ xla::LocalClient* XlaDevice::client() const {
// don't want to do it until we get a chance to hook the platform up
// to a simulator.
- // For now GetOrCreateLocalClient always returns success when passed
- // a non-null platform. If that changes we may have to plumb in some
- // way to pass Status back.
+ // TODO(b/78468222): This can fail, at least when the backend is GPU and
+ // there is no GPU on the host.
return xla::ClientLibrary::GetOrCreateLocalClient(platform_).ValueOrDie();
}
@@ -218,14 +222,31 @@ xla::StatusOr<se::Stream*> XlaDevice::GetStream() {
return stream_.get();
}
+Status XlaDevice::CreateAndSetGpuDeviceInfo() {
+ if (gpu_device_info_ == nullptr) {
+ TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
+ // Call GetAllocator for the side-effect of ensuring the allocator
+ // is created.
+ GetAllocator({});
+ // XlaDevice owns both gpu_device_info_ and
+ // gpu_device_info_->default_context.
+ gpu_device_info_ = absl::make_unique<GpuDeviceInfo>();
+ gpu_device_info_->stream = stream;
+ gpu_device_info_->default_context =
+ new XlaDeviceContext(stream, client(), transfer_as_literal_);
+ set_tensorflow_gpu_device_info(gpu_device_info_.get());
+ }
+
+ return Status::OK();
+}
+
Status XlaDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) {
VLOG(1) << "XlaDevice::FillContextMap";
device_context_map->resize(graph->num_node_ids());
TF_ASSIGN_OR_RETURN(se::Stream * stream, GetStream());
- // Call GetAllocator for the side-effect of ensuring the allocator and
- // XlaTensorInfoManager is created.
- (void)GetAllocator({});
+ // Call GetAllocator for the side-effect of ensuring the allocator is created.
+ GetAllocator({});
auto ctx = new XlaDeviceContext(stream, client(), transfer_as_literal_);
for (Node* n : graph->nodes()) {
VLOG(2) << n->id() << " : " << n->type_string() << " : " << n->name();
diff --git a/tensorflow/compiler/jit/xla_device.h b/tensorflow/compiler/jit/xla_device.h
index 4fe7dd8c9f..3ae87308cc 100644
--- a/tensorflow/compiler/jit/xla_device.h
+++ b/tensorflow/compiler/jit/xla_device.h
@@ -49,20 +49,20 @@ class XlaDevice : public LocalDevice {
// retrieved e.g., when lazily creating the XlaCompilationCache device.
class Metadata {
public:
- Metadata(int device_ordinal, perftools::gputools::Platform* platform,
+ Metadata(int device_ordinal, se::Platform* platform,
const DeviceType& device_type);
// The index of the device on this host.
int device_ordinal() const;
- perftools::gputools::Platform* platform() const;
+ se::Platform* platform() const;
xla::LocalClient* client() const;
const DeviceType& jit_device_type() const;
private:
const int device_ordinal_;
const DeviceType device_type_;
- perftools::gputools::Platform* platform_; // Not owned.
+ se::Platform* platform_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(Metadata);
};
@@ -85,8 +85,7 @@ class XlaDevice : public LocalDevice {
XlaDevice(const SessionOptions& options, const DeviceAttributes& attrs,
int device_ordinal, const DeviceType& jit_device_name,
- ::perftools::gputools::Platform* platform,
- bool transfer_as_literal);
+ se::Platform* platform, bool transfer_as_literal);
~XlaDevice() override;
Allocator* GetAllocator(AllocatorAttributes attr) override;
@@ -103,7 +102,11 @@ class XlaDevice : public LocalDevice {
Tensor* tensor) override;
xla::LocalClient* client() const;
- xla::StatusOr<::perftools::gputools::Stream*> GetStream();
+ xla::StatusOr<se::Stream*> GetStream();
+
+ // If not already set, create and set GpuDeviceInfo.
+ // Not thread-safe
+ Status CreateAndSetGpuDeviceInfo();
private:
// The metadata of this XlaDevice.
@@ -114,7 +117,7 @@ class XlaDevice : public LocalDevice {
DeviceType jit_device_name_;
// Memory allocator associated with this device.
Allocator* xla_allocator_; // Not owned.
- ::perftools::gputools::Platform* platform_; // Not owned.
+ se::Platform* platform_; // Not owned.
// Stream associated with this device. Operations enqueued on this
// stream are executed on the device. Operations include data
// copying back and forth between CPU and the device, and
@@ -123,6 +126,10 @@ class XlaDevice : public LocalDevice {
// Must we use XLA's transfer manager for correct host<->device transfers? If
// false, we can use ThenMemcpy() instead.
bool transfer_as_literal_;
+
+ // If set, holds default device context (that we must Unref)
+ // and its stream.
+ std::unique_ptr<GpuDeviceInfo> gpu_device_info_;
};
// Builds OpKernel registrations on 'device' for the JIT operators
diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc
index 43eb164012..bf8c1886a0 100644
--- a/tensorflow/compiler/jit/xla_device_context.cc
+++ b/tensorflow/compiler/jit/xla_device_context.cc
@@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/platform/mem.h"
-namespace se = ::perftools::gputools;
-
namespace tensorflow {
// The allocator used for Tensors assigned to the XLA device.
diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h
index ad914a1c23..d7f5f1d208 100644
--- a/tensorflow/compiler/jit/xla_device_context.h
+++ b/tensorflow/compiler/jit/xla_device_context.h
@@ -45,8 +45,7 @@ class XlaDeviceAllocator : public Allocator {
// Helper class for managing data transfers between host and XLA devices.
class XlaTransferManager {
public:
- explicit XlaTransferManager(perftools::gputools::Stream* stream,
- xla::LocalClient* client,
+ explicit XlaTransferManager(se::Stream* stream, xla::LocalClient* client,
bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
@@ -54,7 +53,7 @@ class XlaTransferManager {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done);
- perftools::gputools::Stream* stream() const { return stream_; }
+ se::Stream* stream() const { return stream_; }
private:
Status TransferLiteralToDevice(const Tensor& host_tensor,
@@ -64,7 +63,7 @@ class XlaTransferManager {
// Stream obtained from a Device, used to transfer tensors between
// CPU and device.
- perftools::gputools::Stream* stream_;
+ se::Stream* stream_;
// For the underlying memory allocator and XLA's TransferManager.
xla::LocalClient* client_;
// Transfer manager, for marshalling data to and from the device.
@@ -78,8 +77,8 @@ class XlaTransferManager {
// wraps the methods in XlaTransferManager.
class XlaDeviceContext : public DeviceContext {
public:
- explicit XlaDeviceContext(perftools::gputools::Stream* stream,
- xla::LocalClient* client, bool transfer_as_literal);
+ explicit XlaDeviceContext(se::Stream* stream, xla::LocalClient* client,
+ bool transfer_as_literal);
void CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device,
Tensor* device_tensor,
@@ -87,9 +86,7 @@ class XlaDeviceContext : public DeviceContext {
void CopyDeviceTensorToCPU(const Tensor* device_tensor,
StringPiece tensor_name, Device* device,
Tensor* cpu_tensor, StatusCallback done) override;
- perftools::gputools::Stream* stream() const override {
- return manager_.stream();
- }
+ se::Stream* stream() const override { return manager_.stream(); }
private:
XlaTransferManager manager_;
diff --git a/tensorflow/compiler/jit/xla_gpu_device.cc b/tensorflow/compiler/jit/xla_gpu_device.cc
index ac60423d95..a8afbf9dcd 100644
--- a/tensorflow/compiler/jit/xla_gpu_device.cc
+++ b/tensorflow/compiler/jit/xla_gpu_device.cc
@@ -54,6 +54,15 @@ Status XlaGpuDeviceFactory::CreateDevices(const SessionOptions& options,
VLOG(1) << "Failed to create XLA_GPU device: " << status;
return Status::OK();
}
+
+ // TODO(b/78468222): Uncomment after fixing this bug
+ // status = device->CreateAndSetGpuDeviceInfo();
+ // if (!status.ok()) {
+ // errors::AppendToMessage(&status, "while setting up ", DEVICE_GPU_XLA_JIT,
+ // " device");
+ // return status;
+ // }
+
devices->push_back(device.release());
return Status::OK();
}
diff --git a/tensorflow/compiler/jit/xla_launch_util.cc b/tensorflow/compiler/jit/xla_launch_util.cc
index 3520501c1a..2a7f04271d 100644
--- a/tensorflow/compiler/jit/xla_launch_util.cc
+++ b/tensorflow/compiler/jit/xla_launch_util.cc
@@ -32,13 +32,12 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/util/stream_executor_util.h"
+namespace tensorflow {
namespace {
-namespace gpu = perftools::gputools;
using xla::ScopedShapedBuffer;
using xla::ShapedBuffer;
} // anonymous namespace
-namespace tensorflow {
std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
int num_variables) {
std::map<int, OptionalTensor> snapshot;
@@ -57,24 +56,23 @@ std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
return snapshot;
}
-XlaAllocator::XlaAllocator(const gpu::Platform* platform, Allocator* wrapped)
+XlaAllocator::XlaAllocator(const se::Platform* platform, Allocator* wrapped)
: xla::DeviceMemoryAllocator(platform), wrapped_(wrapped) {}
XlaAllocator::~XlaAllocator() {}
-xla::StatusOr<gpu::DeviceMemoryBase> XlaAllocator::Allocate(
+xla::StatusOr<se::DeviceMemoryBase> XlaAllocator::Allocate(
int device_ordinal, uint64 size, bool retry_on_failure) {
void* data = wrapped_->AllocateRaw(Allocator::kAllocatorAlignment, size);
if (data == nullptr) {
return errors::ResourceExhausted("Out of memory while trying to allocate ",
size, " bytes.");
} else {
- return gpu::DeviceMemoryBase(data, size);
+ return se::DeviceMemoryBase(data, size);
}
}
-Status XlaAllocator::Deallocate(int device_ordinal,
- gpu::DeviceMemoryBase* mem) {
+Status XlaAllocator::Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) {
wrapped_->DeallocateRaw(mem->opaque());
return Status::OK();
}
@@ -102,7 +100,7 @@ ScopedShapedBuffer ExtractSubShapedBuffer(
/*target_base_index=*/{});
for (auto& index_to_buffer : shape_tree) {
if (!index_to_buffer.first.empty() && index_to_buffer.first[0] == index) {
- index_to_buffer.second = gpu::DeviceMemoryBase(nullptr, 0);
+ index_to_buffer.second = se::DeviceMemoryBase(nullptr, 0);
}
}
return ScopedShapedBuffer(std::move(sub_shaped_buffer), allocator);
@@ -149,7 +147,7 @@ void XlaComputationLaunchContext::PopulateInputs(
<< xla::ShapeUtil::HumanStringWithLayout(on_device_shape)
<< " not the same as on-host shape "
<< xla::ShapeUtil::HumanStringWithLayout(shape);
- gpu::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
+ se::DeviceMemoryBase dmem = XlaTensor::DeviceMemoryFromTensor(*t);
arg_buffers_[i] = xla::MakeUnique<ShapedBuffer>(
/*on_host_shape=*/shape, /*on_device_shape=*/shape,
client_->platform(), client_->default_device_ordinal());
@@ -162,7 +160,7 @@ void XlaComputationLaunchContext::PopulateInputs(
void XlaComputationLaunchContext::PopulateOutputs(
OpKernelContext* ctx, const XlaCompiler::CompilationResult* kernel,
ScopedShapedBuffer output) {
- gpu::Stream* stream =
+ se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
// Computation output should always be a tuple.
@@ -227,7 +225,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
const TensorShape& shape = kernel->outputs[i].shape;
VLOG(2) << "Retval " << i << " shape " << shape.DebugString();
- gpu::DeviceMemoryBase buffer = output.buffer({output_num});
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
if (allocate_xla_tensors_) {
Tensor* output_tensor;
OP_REQUIRES_OK(ctx, ctx->allocate_output(i, shape, &output_tensor));
@@ -238,7 +236,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
ctx->expected_output_dtype(i), shape, buffer, allocator);
- output.set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
ctx->set_output(i, output_tensor);
}
++output_num;
@@ -258,7 +256,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
write.input_index >= 0 && write.input_index < ctx->num_inputs(),
errors::Internal("Invalid input index for variable write."));
- gpu::DeviceMemoryBase buffer = output.buffer({output_num});
+ se::DeviceMemoryBase buffer = output.buffer({output_num});
Var* variable = nullptr;
// TODO(b/35625933): tensorflow::Var should contain a PersistentTensor,
@@ -288,7 +286,7 @@ void XlaComputationLaunchContext::PopulateOutputs(
} else {
Tensor output_tensor = XlaTensorBuffer::MakeTensor(
write.type, write.shape, buffer, allocator);
- output.set_buffer(gpu::DeviceMemoryBase(nullptr, 0), {output_num});
+ output.set_buffer(se::DeviceMemoryBase(nullptr, 0), {output_num});
*variable->tensor() = output_tensor;
}
++output_num;
diff --git a/tensorflow/compiler/jit/xla_launch_util.h b/tensorflow/compiler/jit/xla_launch_util.h
index 26dcaa8a51..8a6ff3b0c7 100644
--- a/tensorflow/compiler/jit/xla_launch_util.h
+++ b/tensorflow/compiler/jit/xla_launch_util.h
@@ -46,13 +46,11 @@ std::map<int, OptionalTensor> SnapshotResourceVariables(OpKernelContext* ctx,
// see comment on `AllowsAsynchronousDeallocation()`.
class XlaAllocator : public xla::DeviceMemoryAllocator {
public:
- XlaAllocator(const perftools::gputools::Platform* platform,
- Allocator* wrapped);
+ XlaAllocator(const se::Platform* platform, Allocator* wrapped);
~XlaAllocator() override;
- xla::StatusOr<perftools::gputools::DeviceMemoryBase> Allocate(
- int device_ordinal, uint64 size, bool retry_on_failure) override;
- Status Deallocate(int device_ordinal,
- perftools::gputools::DeviceMemoryBase* mem) override;
+ xla::StatusOr<se::DeviceMemoryBase> Allocate(int device_ordinal, uint64 size,
+ bool retry_on_failure) override;
+ Status Deallocate(int device_ordinal, se::DeviceMemoryBase* mem) override;
// The Tensorflow BFC allocator used on GPU allows host-side deallocation
// before GPU execution takes place. Tensorflow uses the ordering of the main
@@ -126,8 +124,7 @@ class XlaTensorBuffer : public TensorBuffer {
}
static Tensor MakeTensor(DataType dtype, const TensorShape& shape,
- perftools::gputools::DeviceMemoryBase buffer,
- Allocator* allocator) {
+ se::DeviceMemoryBase buffer, Allocator* allocator) {
size_t expected_size = shape.num_elements() * DataTypeSize(dtype);
auto* tensor_buffer = new XlaTensorBuffer(buffer.opaque(), expected_size,
buffer.size(), allocator);
diff --git a/tensorflow/compiler/jit/xla_tensor.cc b/tensorflow/compiler/jit/xla_tensor.cc
index 84b2835c40..ce6456880b 100644
--- a/tensorflow/compiler/jit/xla_tensor.cc
+++ b/tensorflow/compiler/jit/xla_tensor.cc
@@ -31,16 +31,15 @@ namespace tensorflow {
return FromTensor(const_cast<Tensor*>(tensor));
}
-/*static*/ perftools::gputools::DeviceMemoryBase
-XlaTensor::DeviceMemoryFromTensor(const Tensor& tensor) {
+/*static*/ se::DeviceMemoryBase XlaTensor::DeviceMemoryFromTensor(
+ const Tensor& tensor) {
const XlaTensor* xla_tensor = FromTensor(&tensor);
if (xla_tensor) {
CHECK(xla_tensor->has_shaped_buffer());
return xla_tensor->shaped_buffer().root_buffer();
} else {
- return perftools::gputools::DeviceMemoryBase(
- const_cast<char*>(tensor.tensor_data().data()),
- tensor.tensor_data().size());
+ return se::DeviceMemoryBase(const_cast<char*>(tensor.tensor_data().data()),
+ tensor.tensor_data().size());
}
}
diff --git a/tensorflow/compiler/jit/xla_tensor.h b/tensorflow/compiler/jit/xla_tensor.h
index 2334fd272b..922a918973 100644
--- a/tensorflow/compiler/jit/xla_tensor.h
+++ b/tensorflow/compiler/jit/xla_tensor.h
@@ -43,8 +43,7 @@ class XlaTensor {
// which case the returned value is shaped_buffer()->root_buffer(), or a
// normal Tensor in which case the returned value is
// {tensor.tensor_data().data(), tensor.tensor_data().size}.
- static perftools::gputools::DeviceMemoryBase DeviceMemoryFromTensor(
- const Tensor& tensor);
+ static se::DeviceMemoryBase DeviceMemoryFromTensor(const Tensor& tensor);
// Assign the internal ShapedBuffer to new memory for the given dtype and
// shape. If a ShapedBuffer exists already (has_shaped_buffer() == true), it
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 46b86c53aa..0c72093256 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -309,6 +309,25 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "eager_test",
+ size = "small",
+ srcs = ["eager_test.py"],
+ disabled_backends = [
+ # TODO(b/78199195) Support XLA CPU devices in eager runtime
+ "cpu",
+ "cpu_ondemand",
+ # TODO(b/78468222) Enable GPU backend
+ "gpu",
+ ],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "fft_test",
size = "medium",
srcs = ["fft_test.py"],
@@ -904,3 +923,15 @@ tf_xla_py_test(
"//tensorflow/python:platform_test",
],
)
+
+tf_xla_py_test(
+ name = "placeholder_test",
+ size = "small",
+ srcs = ["placeholder_test.py"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:platform_test",
+ ],
+)
diff --git a/tensorflow/compiler/tests/eager_test.py b/tensorflow/compiler/tests/eager_test.py
new file mode 100644
index 0000000000..bdd0185dfe
--- /dev/null
+++ b/tensorflow/compiler/tests/eager_test.py
@@ -0,0 +1,137 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Test cases for eager execution using XLA."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.platform import googletest
+
+
+class EagerTest(XLATestCase):
+
+ def testBasic(self):
+ with self.test_scope():
+ three = constant_op.constant(3)
+ five = constant_op.constant(5)
+ product = three * five
+ self.assertAllEqual(15, product)
+
+ def testExecuteListOutputLen0(self):
+ with self.test_scope():
+ empty = constant_op.constant([], dtype=dtypes.int32)
+ result = array_ops.unstack(empty, 0)
+ self.assertTrue(isinstance(result, list))
+ self.assertEqual(0, len(result))
+
+ def testExecuteListOutputLen1(self):
+ with self.test_scope():
+ split_dim = constant_op.constant(1)
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ result = array_ops.split(value, 1, axis=split_dim)
+ self.assertTrue(isinstance(result, list))
+ self.assertEqual(1, len(result))
+ self.assertAllEqual([[0, 1, 2], [3, 4, 5]], result[0])
+
+ def testExecuteListOutputLen3(self):
+ with self.test_scope():
+ split_dim = constant_op.constant(1)
+ value = constant_op.constant([[0, 1, 2], [3, 4, 5]])
+ result = array_ops.split(value, 3, axis=split_dim)
+ self.assertTrue(isinstance(result, list))
+ self.assertEqual(3, len(result))
+ self.assertAllEqual([[0], [3]], result[0])
+ self.assertAllEqual([[1], [4]], result[1])
+ self.assertAllEqual([[2], [5]], result[2])
+
+ def testBasicGraph(self):
+ # Run some ops eagerly
+ with self.test_scope():
+ three = constant_op.constant(3)
+ five = constant_op.constant(5)
+ product = three * five
+ self.assertAllEqual(15, product)
+
+ # Run some ops in graph mode
+ with context.graph_mode(), self.test_session() as sess:
+ with self.test_scope():
+ three = constant_op.constant(3)
+ five = constant_op.constant(5)
+ product = three * five
+ self.assertAllEqual(15, sess.run(product))
+
+ def testDegenerateSlices(self):
+ with self.test_scope():
+ npt = np.arange(1, 19, dtype=np.float32).reshape(3, 2, 3)
+ t = constant_op.constant(npt)
+ # degenerate by offering a forward interval with a negative stride
+ self.assertAllEqual(npt[0:-1:-1, :, :], t[0:-1:-1, :, :])
+ # degenerate with a reverse interval with a positive stride
+ self.assertAllEqual(npt[-1:0, :, :], t[-1:0, :, :])
+ # empty interval in every dimension
+ self.assertAllEqual(npt[-1:0, 2:2, 2:3:-1], t[-1:0, 2:2, 2:3:-1])
+
+ def testIdentity(self):
+ with self.test_scope():
+ self.assertAllEqual(2, array_ops.identity(2))
+
+ def testIdentityOnVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(True)
+ i = array_ops.identity(v)
+ self.assertAllEqual(True, i.numpy())
+
+ def testAssignAddVariable(self):
+ with self.test_scope():
+ v = resource_variable_ops.ResourceVariable(1.0)
+ v.assign_add(2.0)
+ self.assertEqual(3.0, v.numpy())
+
+ def testGradient(self):
+ def f(x):
+ return x
+
+ with self.test_scope():
+ grad_fn = backprop.gradients_function(f)
+ self.assertAllEqual(2., grad_fn(1., dy=2.)[0])
+
+ def testVariableGradient(self):
+ with self.test_scope():
+ v0 = resource_variable_ops.ResourceVariable(1.0)
+
+ def f():
+ x = v0 * v0
+ return x
+
+ grads = backprop.implicit_grad(f)()
+ self.assertEqual(2., grads[0][0].numpy())
+
+
+if __name__ == "__main__":
+ ops.enable_eager_execution(
+ config=config_pb2.ConfigProto(log_device_placement=True))
+ googletest.main()
diff --git a/tensorflow/compiler/tests/placeholder_test.py b/tensorflow/compiler/tests/placeholder_test.py
new file mode 100644
index 0000000000..5e6d1313bd
--- /dev/null
+++ b/tensorflow/compiler/tests/placeholder_test.py
@@ -0,0 +1,48 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for xla handling of placeholder_with_default."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import googletest
+
+
+class PlaceholderTest(XLATestCase):
+
+ def test_placeholder_with_default_default(self):
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(4.0)
+ ph = array_ops.placeholder_with_default(v, shape=[])
+ out = ph * 2
+ sess.run(variables.variables_initializer([v]))
+ self.assertEqual(8.0, sess.run(out))
+
+ def test_placeholder_with_default_fed(self):
+ with self.test_session() as sess, self.test_scope():
+ v = resource_variable_ops.ResourceVariable(4.0)
+ ph = array_ops.placeholder_with_default(v, shape=[])
+ out = ph * 2
+ sess.run(variables.variables_initializer([v]))
+ self.assertEqual(2.0, sess.run(out, {ph: 1.0}))
+
+
+if __name__ == '__main__':
+ googletest.main()
diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py
index 75a2cf07c5..ef047005b6 100644
--- a/tensorflow/compiler/tests/ternary_ops_test.py
+++ b/tensorflow/compiler/tests/ternary_ops_test.py
@@ -69,40 +69,41 @@ class TernaryOpsTest(XLATestCase):
expected=np.array([1, 3, 5], dtype=np.int32))
def testSelect(self):
- self._testTernary(
- array_ops.where,
- np.array(0, dtype=np.bool),
- np.array(2, dtype=np.float32),
- np.array(7, dtype=np.float32),
- expected=np.array(7, dtype=np.float32))
+ for dtype in self.numeric_types:
+ self._testTernary(
+ array_ops.where,
+ np.array(0, dtype=np.bool),
+ np.array(2, dtype=dtype),
+ np.array(7, dtype=dtype),
+ expected=np.array(7, dtype=dtype))
- self._testTernary(
- array_ops.where,
- np.array(1, dtype=np.bool),
- np.array([1, 2, 3, 4], dtype=np.float32),
- np.array([5, 6, 7, 8], dtype=np.float32),
- expected=np.array([1, 2, 3, 4], dtype=np.float32))
+ self._testTernary(
+ array_ops.where,
+ np.array(1, dtype=np.bool),
+ np.array([1, 2, 3, 4], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array([1, 2, 3, 4], dtype=dtype))
- self._testTernary(
- array_ops.where,
- np.array(0, dtype=np.bool),
- np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
- np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
- expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32))
+ self._testTernary(
+ array_ops.where,
+ np.array(0, dtype=np.bool),
+ np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
+ np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+ expected=np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype))
- self._testTernary(
- array_ops.where,
- np.array([0, 1, 1, 0], dtype=np.bool),
- np.array([1, 2, 3, 4], dtype=np.float32),
- np.array([5, 6, 7, 8], dtype=np.float32),
- expected=np.array([5, 2, 3, 8], dtype=np.float32))
+ self._testTernary(
+ array_ops.where,
+ np.array([0, 1, 1, 0], dtype=np.bool),
+ np.array([1, 2, 3, 4], dtype=dtype),
+ np.array([5, 6, 7, 8], dtype=dtype),
+ expected=np.array([5, 2, 3, 8], dtype=dtype))
- self._testTernary(
- array_ops.where,
- np.array([0, 1, 0], dtype=np.bool),
- np.array([[1, 2], [3, 4], [5, 6]], dtype=np.float32),
- np.array([[7, 8], [9, 10], [11, 12]], dtype=np.float32),
- expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=np.float32))
+ self._testTernary(
+ array_ops.where,
+ np.array([0, 1, 0], dtype=np.bool),
+ np.array([[1, 2], [3, 4], [5, 6]], dtype=dtype),
+ np.array([[7, 8], [9, 10], [11, 12]], dtype=dtype),
+ expected=np.array([[7, 8], [3, 4], [11, 12]], dtype=dtype))
def testSlice(self):
for dtype in self.numeric_types:
diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD
index ba5c3a1484..942504e6bd 100644
--- a/tensorflow/compiler/tf2xla/BUILD
+++ b/tensorflow/compiler/tf2xla/BUILD
@@ -412,7 +412,6 @@ cc_library(
hdrs = ["functionalize_control_flow.h"],
deps = [
":tf2xla_util",
- "//tensorflow/compiler/jit:graph_to_functiondef",
"//tensorflow/compiler/jit:union_find",
"//tensorflow/compiler/tf2xla:dump_graph",
"//tensorflow/compiler/tf2xla/ops:xla_ops",
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 23629d85ae..8d1f268490 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -21,13 +21,13 @@ limitations under the License.
#include <unordered_set>
#include <vector>
-#include "tensorflow/compiler/jit/graph_to_functiondef.h"
#include "tensorflow/compiler/jit/union_find.h"
#include "tensorflow/compiler/tf2xla/dump_graph.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/graph/algorithm.h"
#include "tensorflow/core/graph/control_flow.h"
diff --git a/tensorflow/compiler/tf2xla/kernels/identity_op.cc b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
index 39af662b63..e72200bfbc 100644
--- a/tensorflow/compiler/tf2xla/kernels/identity_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/identity_op.cc
@@ -38,6 +38,7 @@ class IdentityOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("Identity").CompilationOnly(), IdentityOp);
REGISTER_XLA_OP(Name("IdentityN").CompilationOnly(), IdentityOp);
+REGISTER_XLA_OP(Name("PlaceholderWithDefault"), IdentityOp);
REGISTER_XLA_OP(Name("PreventGradient"), IdentityOp);
REGISTER_XLA_OP(Name("StopGradient"), IdentityOp);
REGISTER_XLA_OP(Name("Snapshot"), IdentityOp);
diff --git a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
index 7f72a6073d..9bf5821b54 100644
--- a/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
+++ b/tensorflow/compiler/tf2xla/lib/triangular_solve.cc
@@ -83,15 +83,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
block_size);
}
- // Returns [b1, b2, ... , bn, indices[0], indices[1]].
- auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
- std::vector<int64> output(ndims);
- std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
- std::copy(indices.begin(), indices.end(),
- output.begin() + batch_dimensions.size());
- return output;
- };
-
// Applies a complex conjugation operation if `a` is complex and `conjugate_a`
// is true, otherwise returns its argument.
auto maybe_conj = [&](xla::ComputationBuilder* builder,
@@ -108,11 +99,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
std::unique_ptr<xla::ComputationBuilder> sub = builder->CreateSubBuilder(
tensorflow::strings::StrCat("trsm_base_", k));
- auto a_param =
- sub->Parameter(0,
- xla::ShapeUtil::MakeShape(b_shape->element_type(),
- prepend_batch_dims({k, k})),
- "a");
+ auto a_param = sub->Parameter(
+ 0,
+ xla::ShapeUtil::MakeShape(
+ b_shape->element_type(),
+ PrependMajorDims(sub.get(), batch_dimensions, {k, k})),
+ "a");
std::array<int64, 2> b_lastd;
if (left_side) {
@@ -120,11 +112,12 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolve(
} else {
b_lastd = {m, k};
}
- auto b_param =
- sub->Parameter(1,
- xla::ShapeUtil::MakeShape(b_shape->element_type(),
- prepend_batch_dims(b_lastd)),
- "b");
+ auto b_param = sub->Parameter(
+ 1,
+ xla::ShapeUtil::MakeShape(
+ b_shape->element_type(),
+ PrependMajorDims(sub.get(), batch_dimensions, b_lastd)),
+ "b");
// We use a left-looking subroutine on the block diagonal in some common
// cases, while falling back to a recursive call in unsupported cases. The
@@ -380,14 +373,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
batch_dimensions.push_back(a_size);
}
- auto prepend_batch_dims = [&](std::array<int64, 2> indices) {
- std::vector<int64> output(ndims);
- std::copy(batch_dimensions.begin(), batch_dimensions.end(), output.begin());
- std::copy(indices.begin(), indices.end(),
- output.begin() + batch_dimensions.size());
- return output;
- };
-
auto maybe_conj = [&](xla::ComputationBuilder* builder,
xla::ComputationDataHandle x) {
auto perform_conj = a_shape->element_type() == xla::C64 && conjugate_a;
@@ -479,30 +464,6 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
auto body_b = bodyb->GetTupleElement(input_tuple, 3);
auto zero = bodyb->ConstantR0<int32>(0);
- // Set up some helper functions.
- auto prepend_zeros = [&](std::array<xla::ComputationDataHandle, 2> starts) {
- auto zero = bodyb->Reshape(bodyb->ConstantR0<int32>(0), {1});
- std::vector<xla::ComputationDataHandle> padded_starts(ndims, zero);
- padded_starts[ndims - 2] = bodyb->Reshape(starts[0], {1});
- padded_starts[ndims - 1] = bodyb->Reshape(starts[1], {1});
- return bodyb->ConcatInDim(padded_starts, 0);
- };
-
- auto dynamic_slice = [&](xla::ComputationDataHandle x,
- std::array<xla::ComputationDataHandle, 2> starts,
- std::array<int64, 2> sizes) {
- auto padded_starts = prepend_zeros(starts);
- auto padded_sizes = prepend_batch_dims(sizes);
- return bodyb->DynamicSlice(x, padded_starts, padded_sizes);
- };
-
- auto update = [&](xla::ComputationDataHandle x,
- xla::ComputationDataHandle update,
- std::array<xla::ComputationDataHandle, 2> starts) {
- auto padded_starts = prepend_zeros(starts);
- return bodyb->DynamicUpdateSlice(x, update, padded_starts);
- };
-
// We'd like to implement this:
// if transpose_a:
// a_row = T(a[..., i+1:, i:i+1])
@@ -516,22 +477,29 @@ xla::StatusOr<xla::ComputationDataHandle> TriangularSolveLeftLooking(
// all zeros and use that as zero-padding (doing unnecessary FLOPs).
xla::ComputationDataHandle a_row;
if (transpose_a) {
- a_row = dynamic_slice(body_a, {zero, i}, {m, 1});
+ TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {zero, i}, {m, 1}));
} else {
- a_row = dynamic_slice(body_a, {i, zero}, {1, m});
+ TF_ASSIGN_OR_RETURN(a_row, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {i, zero}, {1, m}));
}
TF_ASSIGN_OR_RETURN(auto b_update, BatchDot(bodyb.get(), a_row, body_out,
/*transpose_x=*/transpose_a,
/*transpose_y=*/false,
/*conjugate_x=*/conjugate_a,
/*conjugate_y=*/false));
- auto result_row =
- bodyb->Sub(dynamic_slice(body_b, {i, zero}, {1, n}), b_update);
+ TF_ASSIGN_OR_RETURN(
+ auto result_row_slice,
+ DynamicSliceInMinorDims(bodyb.get(), body_b, {i, zero}, {1, n}));
+ auto result_row = bodyb->Sub(result_row_slice, b_update);
// body_out[..., i:i+1, :] = result_row / a[..., i:i+1, i:i+1]
- auto a_elt = dynamic_slice(body_a, {i, i}, {1, 1});
+ TF_ASSIGN_OR_RETURN(auto a_elt, DynamicSliceInMinorDims(bodyb.get(), body_a,
+ {i, i}, {1, 1}));
auto div_result = bodyb->Div(result_row, maybe_conj(bodyb.get(), a_elt));
- body_out = update(body_out, div_result, {i, zero});
+ TF_ASSIGN_OR_RETURN(body_out,
+ DynamicUpdateSliceInMinorDims(bodyb.get(), body_out,
+ div_result, {i, zero}));
// if transpose_a:
// return (i - 1, body_out, a, b)
diff --git a/tensorflow/compiler/xla/BUILD b/tensorflow/compiler/xla/BUILD
index 88f37433a5..1af9cb6d2a 100644
--- a/tensorflow/compiler/xla/BUILD
+++ b/tensorflow/compiler/xla/BUILD
@@ -605,8 +605,8 @@ cc_library(
":util",
":window_util",
":xla_data_proto",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_evaluator",
"//tensorflow/compiler/xla/service:shape_inference",
diff --git a/tensorflow/compiler/xla/client/client.cc b/tensorflow/compiler/xla/client/client.cc
index f0f94298a0..328e1b8fa8 100644
--- a/tensorflow/compiler/xla/client/client.cc
+++ b/tensorflow/compiler/xla/client/client.cc
@@ -235,6 +235,11 @@ StatusOr<Computation> Client::LoadSnapshot(const SessionModule& module) {
return Computation(stub_, response.computation());
}
+StatusOr<XlaComputation> Client::LoadSnapshot(const HloSnapshot& module) {
+ TF_RET_CHECK(module.has_hlo() && module.hlo().has_hlo_module());
+ return XlaComputation(module.hlo().hlo_module());
+}
+
StatusOr<std::unique_ptr<GlobalData>> Client::Execute(
const Computation& computation,
tensorflow::gtl::ArraySlice<GlobalData*> arguments,
diff --git a/tensorflow/compiler/xla/client/client.h b/tensorflow/compiler/xla/client/client.h
index 14c685d94e..a63ff4c56d 100644
--- a/tensorflow/compiler/xla/client/client.h
+++ b/tensorflow/compiler/xla/client/client.h
@@ -255,6 +255,9 @@ class Client {
StatusOr<Computation> LoadSnapshot(const SessionModule& module);
+ // TODO(b/74197823): This is part of a refactor that is not yet ready.
+ StatusOr<XlaComputation> LoadSnapshot(const HloSnapshot& module);
+
ServiceInterface* stub() { return stub_; }
private:
diff --git a/tensorflow/compiler/xla/client/local_client.cc b/tensorflow/compiler/xla/client/local_client.cc
index d0e945b70f..1c12705903 100644
--- a/tensorflow/compiler/xla/client/local_client.cc
+++ b/tensorflow/compiler/xla/client/local_client.cc
@@ -166,12 +166,8 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
if (executable_->dumping()) {
return ExecuteAndDump(&service_options, arguments);
}
- TF_ASSIGN_OR_RETURN(
- ShapedBuffer result,
- executable_->ExecuteOnStreamWrapper(
- &service_options, run_options.execution_profile(), arguments));
-
- return ScopedShapedBuffer(std::move(result), run_options.allocator());
+ return executable_->ExecuteOnStreamWrapper(
+ &service_options, run_options.execution_profile(), arguments);
}
StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
@@ -181,12 +177,12 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::ExecuteAndDump(
backend_->platform()->Name());
TF_RETURN_IF_ERROR(RecordArguments(arguments, executable_->session_module()));
TF_ASSIGN_OR_RETURN(
- ShapedBuffer result,
+ ScopedShapedBuffer result,
executable_->ExecuteOnStream(run_options, arguments,
/*hlo_execution_profile=*/nullptr));
TF_RETURN_IF_ERROR(RecordResult(&result, executable_->session_module()));
TF_RETURN_IF_ERROR(executable_->DumpSessionModule());
- return ScopedShapedBuffer(std::move(result), run_options->allocator());
+ return std::move(result);
}
tensorflow::Status LocalExecutable::RecordArguments(
diff --git a/tensorflow/compiler/xla/client/xla_client/BUILD b/tensorflow/compiler/xla/client/xla_client/BUILD
index 31fa1241ee..0d6e207971 100644
--- a/tensorflow/compiler/xla/client/xla_client/BUILD
+++ b/tensorflow/compiler/xla/client/xla_client/BUILD
@@ -31,9 +31,9 @@ cc_library(
hdrs = ["xla_computation.h"],
deps = [
"//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/service:hlo_proto",
- "//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
index a6752c6010..72e3935696 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.cc
@@ -17,7 +17,9 @@ limitations under the License.
#include <utility>
+#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
+#include "tensorflow/compiler/xla/util.h"
namespace xla {
@@ -26,4 +28,13 @@ StatusOr<ProgramShape> XlaComputation::GetProgramShape() const {
return proto_.program_shape();
}
+StatusOr<std::unique_ptr<HloSnapshot>> XlaComputation::Snapshot() const {
+ if (IsNull()) {
+ return InvalidArgument("Computation is invalid.");
+ }
+ auto session = MakeUnique<HloSnapshot>();
+ *session->mutable_hlo()->mutable_hlo_module() = proto_;
+ return std::move(session);
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/client/xla_client/xla_computation.h b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
index 7ad212aa24..b70b57e9ff 100644
--- a/tensorflow/compiler/xla/client/xla_client/xla_computation.h
+++ b/tensorflow/compiler/xla/client/xla_client/xla_computation.h
@@ -48,6 +48,10 @@ class XlaComputation {
const HloModuleProto& proto() const { return proto_; }
+ // Requests that we snapshot the computation into a serializable protocol
+ // buffer form.
+ StatusOr<std::unique_ptr<HloSnapshot>> Snapshot() const;
+
// Returns true if this object is a null Computation.
bool IsNull() const { return unique_id_ == -1; }
diff --git a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
index 70ae95bf47..bc8405703b 100644
--- a/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
+++ b/tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -43,7 +43,7 @@ void SetDebugOptionsDefaults(DebugOptions* flags) {
#ifdef INTEL_MKL
flags->set_xla_cpu_use_mkl_dnn(true);
#endif // INTEL_MKL
- flags->set_xla_gpu_max_kernel_unroll_factor(1);
+ flags->set_xla_gpu_max_kernel_unroll_factor(4);
// Set cudnn batchnorm off by default; it does not provide a performance win
// on average.
flags->set_xla_gpu_use_cudnn_batchnorm(false);
diff --git a/tensorflow/compiler/xla/ptr_util.h b/tensorflow/compiler/xla/ptr_util.h
index c58c19db2c..bfcdfc62f9 100644
--- a/tensorflow/compiler/xla/ptr_util.h
+++ b/tensorflow/compiler/xla/ptr_util.h
@@ -28,26 +28,8 @@ limitations under the License.
#include "tensorflow/core/util/ptr_util.h"
namespace xla {
-
-template <typename T>
-std::unique_ptr<T> WrapUnique(T* ptr) {
- return tensorflow::WrapUnique<T>(ptr);
-}
-
-template <typename T, typename... Args>
-typename tensorflow::helper::MakeUniqueResult<T>::scalar MakeUnique(
- Args&&... args) {
- return tensorflow::MakeUnique<T, Args...>(std::forward<Args>(args)...);
-}
-
-// Overload for array of unknown bound.
-// The allocation of arrays needs to use the array form of new,
-// and cannot take element constructor arguments.
-template <typename T>
-typename tensorflow::helper::MakeUniqueResult<T>::array MakeUnique(size_t n) {
- return tensorflow::MakeUnique<T>(n);
-}
-
+using tensorflow::MakeUnique;
+using tensorflow::WrapUnique;
} // namespace xla
#endif // TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
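The wrappers deleted above simply forwarded to tensorflow::MakeUnique; the removed comment records the one subtlety, namely that an array of unknown bound must be allocated with the array form of new and cannot take element constructor arguments. A minimal sketch of that scalar/array split, assuming nothing beyond the C++11 standard library (from C++14 onward std::make_unique provides exactly this):

#include <cstddef>
#include <memory>
#include <type_traits>
#include <utility>

// Scalar form: forwards constructor arguments to new T(...).
template <typename T, typename... Args,
          typename std::enable_if<!std::is_array<T>::value, int>::type = 0>
std::unique_ptr<T> MyMakeUnique(Args&&... args) {
  return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}

// Array-of-unknown-bound form: uses new U[n] and value-initializes;
// element constructor arguments cannot be passed here.
template <typename T,
          typename std::enable_if<std::is_array<T>::value &&
                                      std::extent<T>::value == 0,
                                  int>::type = 0>
std::unique_ptr<T> MyMakeUnique(size_t n) {
  using U = typename std::remove_extent<T>::type;
  return std::unique_ptr<T>(new U[n]());
}

// Usage:
//   auto p = MyMakeUnique<int>(42);              // scalar
//   auto a = MyMakeUnique<unsigned char[]>(64);  // 64 zeroed bytes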
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index ad3a28e119..df9dbc5830 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -18,7 +18,7 @@ limitations under the License.
#include <array>
#include <utility>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/service/cpu/runtime_single_threaded_matmul.h"
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -90,7 +90,7 @@ std::unique_ptr<Array2D<T>> MatmulArray2DImpl(
Padding padding) {
return ConvArray3DGeneralDimensionsDilated(
lhs, rhs, kernel_stride, padding, 1, 1,
- ComputationBuilder::CreateDefaultConvDimensionNumbers(1));
+ XlaBuilder::CreateDefaultConvDimensionNumbers(1));
}
/*static*/ std::unique_ptr<Array3D<float>>
@@ -140,7 +140,7 @@ ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
std::pair<int64, int64> kernel_stride, Padding padding) {
return ConvArray4DGeneralDimensions(
lhs, rhs, kernel_stride, padding,
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
}
/* static */ std::unique_ptr<Array4D<float>>
diff --git a/tensorflow/compiler/xla/service/BUILD b/tensorflow/compiler/xla/service/BUILD
index afb344e5ae..d55da3686c 100644
--- a/tensorflow/compiler/xla/service/BUILD
+++ b/tensorflow/compiler/xla/service/BUILD
@@ -359,6 +359,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
+ "//tensorflow/core:lib",
],
)
@@ -1954,12 +1955,10 @@ cc_library(
deps = [
":computation_layout",
":hlo",
- ":hlo_dce",
":hlo_graph_dumper",
":hlo_pass",
":logical_buffer",
":tuple_points_to_analysis",
- ":tuple_simplifier",
"//tensorflow/compiler/xla:shape_layout",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
@@ -2436,7 +2435,6 @@ cc_library(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:status_macros",
"//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/core:lib",
],
)
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.cc b/tensorflow/compiler/xla/service/allocation_tracker.cc
index 6bf65825cd..cf1231bcce 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.cc
+++ b/tensorflow/compiler/xla/service/allocation_tracker.cc
@@ -31,23 +31,35 @@ limitations under the License.
namespace xla {
StatusOr<GlobalDataHandle> AllocationTracker::Register(
- ShapedBuffer shaped_buffer, const string& tag) {
+ ScopedShapedBuffer shaped_buffer, const string& tag) {
tensorflow::mutex_lock lock(mutex_);
VLOG(2) << "Register";
- std::vector<ShapedBuffer> replicated_buffers;
+ std::vector<ScopedShapedBuffer> replicated_buffers;
replicated_buffers.emplace_back(std::move(shaped_buffer));
return RegisterInternal(std::move(replicated_buffers), tag);
}
StatusOr<GlobalDataHandle> AllocationTracker::RegisterReplicatedBuffers(
- std::vector<ShapedBuffer> replicated_buffers, const string& tag) {
+ std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
tensorflow::mutex_lock lock(mutex_);
VLOG(2) << "RegisterReplicatedBuffers";
return RegisterInternal(std::move(replicated_buffers), tag);
}
+// ReleaseIfScopedShapedBuffer lets RegisterInternal<ShapedBufferTy>(b) call
+// b.release() if b is a ScopedShapedBuffer, or otherwise pass b through
+// unmodified.
+static ShapedBuffer ReleaseIfScopedShapedBuffer(ShapedBuffer b) { return b; }
+static ShapedBuffer ReleaseIfScopedShapedBuffer(ScopedShapedBuffer b) {
+ return b.release();
+}
+
+template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
- std::vector<ShapedBuffer> replicated_buffers, const string& tag) {
+ std::vector<ShapedBufferTy> replicated_buffers, const string& tag) {
+ static_assert(std::is_same<ShapedBufferTy, ShapedBuffer>::value ||
+ std::is_same<ShapedBufferTy, ScopedShapedBuffer>::value,
+ "ShapedBufferTy must be ShapedBuffer or ScopedShapedBuffer.");
VLOG(2) << "RegisterInternal("
<< "tag: \"" << tag << "\" with " << replicated_buffers.size()
<< " shaped_buffers.";
@@ -65,17 +77,22 @@ StatusOr<GlobalDataHandle> AllocationTracker::RegisterInternal(
int64 handle = next_handle_++;
for (auto& shaped_buffer : replicated_buffers) {
std::vector<ShapeIndex> shape_indices;
- ShapeUtil::ForEachSubshape(shaped_buffer.on_device_shape(),
- [this, &shape_indices](const Shape& /*subshape*/,
- const ShapeIndex& index) {
- shape_indices.push_back(index);
- });
+ ShapeUtil::ForEachSubshape(
+ shaped_buffer.on_device_shape(),
+ [&](const Shape& /*subshape*/, const ShapeIndex& index) {
+ shape_indices.push_back(index);
+ });
+ // Add shaped_buffer's buffers to opaque_to_allocation_map_, which owns
+ // them.
for (const ShapeIndex& index : shape_indices) {
AddAllocationOrIncrementRefCount(shaped_buffer.buffer(index),
shaped_buffer.device_ordinal());
}
- handle_to_shaped_buffers_[handle].emplace_back(
- MakeUnique<ShapedBuffer>(std::move(shaped_buffer)));
+ // If ShapedBufferTy is ScopedShapedBuffer, release the ScopedShapedBuffer
+ // into a regular ShapedBuffer, which is stored in
+ // handle_to_shaped_buffers_.
+ handle_to_shaped_buffers_[handle].emplace_back(MakeUnique<ShapedBuffer>(
+ ReleaseIfScopedShapedBuffer(std::move(shaped_buffer))));
}
GlobalDataHandle result;
@@ -102,10 +119,6 @@ tensorflow::Status AllocationTracker::Unregister(const GlobalDataHandle& data) {
shaped_buffer->device_ordinal()));
}
}
- return Reset(data);
-}
-
-Status AllocationTracker::Reset(const GlobalDataHandle& data) {
// Keep a nullptr as a tombstone for unregistered handles. This enables
// better error messages. That is, "handle has been deallocated" versus
// "handle does not exist".
@@ -152,7 +165,7 @@ StatusOr<std::vector<GlobalDataHandle>> AllocationTracker::DeconstructTuple(
element_buffer.set_buffer(shaped_buffer->buffer(/*index=*/{i}),
/*index=*/{});
std::vector<ShapedBuffer> replicated_buffers;
- replicated_buffers.emplace_back(std::move(element_buffer));
+ replicated_buffers.push_back(std::move(element_buffer));
TF_ASSIGN_OR_RETURN(
GlobalDataHandle element_handle,
RegisterInternal(std::move(replicated_buffers), "deconstructed tuple"));
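The ReleaseIfScopedShapedBuffer overloads above are a small overload-resolution trick: the templated RegisterInternal hands either buffer type through the same expression, and only the owning type actually gets released. A stand-alone sketch of the pattern with made-up Plain/Owning types, not the real XLA classes:

#include <iostream>
#include <string>
#include <type_traits>
#include <utility>

struct Plain {
  std::string data;
};

// Owning wrapper: release() gives up ownership and returns the plain view.
struct Owning {
  Plain inner;
  Plain release() { return std::move(inner); }
};

// Pass-through for the non-owning type...
static Plain ReleaseIfOwning(Plain b) { return b; }
// ...and an overload that releases the owning type.
static Plain ReleaseIfOwning(Owning b) { return b.release(); }

template <typename BufferTy>
void RegisterInternal(BufferTy buffer) {
  static_assert(std::is_same<BufferTy, Plain>::value ||
                    std::is_same<BufferTy, Owning>::value,
                "BufferTy must be Plain or Owning.");
  // The same call site works for both types; only Owning gets released.
  Plain stored = ReleaseIfOwning(std::move(buffer));
  std::cout << "registered: " << stored.data << "\n";
}

int main() {
  RegisterInternal(Plain{"view"});
  RegisterInternal(Owning{Plain{"owned"}});
}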
diff --git a/tensorflow/compiler/xla/service/allocation_tracker.h b/tensorflow/compiler/xla/service/allocation_tracker.h
index 2bfcd53712..1174fa641c 100644
--- a/tensorflow/compiler/xla/service/allocation_tracker.h
+++ b/tensorflow/compiler/xla/service/allocation_tracker.h
@@ -45,13 +45,13 @@ class AllocationTracker {
// Registers a shaped buffer of device memory, and returns a corresponding
// handle that can be used for talking to XLA clients. The given shaped buffer
// will be treated as the buffer corresponding to the only replica.
- StatusOr<GlobalDataHandle> Register(ShapedBuffer shaped_buffer,
+ StatusOr<GlobalDataHandle> Register(ScopedShapedBuffer shaped_buffer,
const string& tag);
// Registers a vector of shaped buffers of device memory, one per replica, and
// returns a corresponding handle that can be used for talking to XLA clients.
StatusOr<GlobalDataHandle> RegisterReplicatedBuffers(
- std::vector<ShapedBuffer> replicated_buffers, const string& tag);
+ std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag);
// Unregister the allocation for the given data handle.
Status Unregister(const GlobalDataHandle& data);
@@ -87,21 +87,21 @@ class AllocationTracker {
};
// Internal helper which resolves the given GlobalDataHandle to a
- // ShapedBuffer.
+ // list of ScopedShapedBuffers.
StatusOr<std::vector<const ShapedBuffer*>> ResolveInternal(
const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
// Internal helper which registers a vector of shaped buffers, one per
- // replica.
+ // replica. ShapedBufferTy is either ScopedShapedBuffer or ShapedBuffer. If
+ // it's ShapedBuffer, all of the given buffers must already be tracked by this
+ // object -- presumably this is a call from DeconstructTuple.
+ template <typename ShapedBufferTy>
StatusOr<GlobalDataHandle> RegisterInternal(
- std::vector<ShapedBuffer> replicated_buffers, const string& tag)
+ std::vector<ShapedBufferTy> replicated_buffers, const string& tag)
EXCLUSIVE_LOCKS_REQUIRED(mutex_);
- // Resets the shaped buffers corresponding to the given handle.
- Status Reset(const GlobalDataHandle& data) EXCLUSIVE_LOCKS_REQUIRED(mutex_);
-
// Adds the given device address to the allocation tracker, or if it already
- // exists, then increment it's reference count.
+ // exists, then increment its reference count.
void AddAllocationOrIncrementRefCount(se::DeviceMemoryBase device_memory,
int device_ordinal)
EXCLUSIVE_LOCKS_REQUIRED(mutex_);
@@ -133,7 +133,19 @@ class AllocationTracker {
// buffers for different replicas.
//
// The ShapedBuffers in this map's vectors need to be unique_ptrs, because our
- // public API returns pointers to them.
+ // public API returns pointers to them. We expect the concrete class to be
+ // ShapedBuffer and never ScopedShapedBuffer; deallocation of buffers is
+ // handled by opaque_to_allocation_map_.
+ //
+ // The elements of the vectors need to be unique_ptrs because we return
+ // pointers to them. (In theory we could use std::list or something instead,
+ // but we also want to be able to null out these elements.)
+ //
+ // The reason that the elements can't be unique_ptr<ScopedShapedBuffer>s is
+ // the existence of DeconstructTuple(). This function allows us to create a
+ // non-owning "view" into a tuple's sub-buffers. The sub-buffers are then
+ // freed when both the view *and* the original tuple are Unregistered. This
+ // refcounting is managed in opaque_to_allocation_map_.
tensorflow::gtl::FlatMap<int64, std::vector<std::unique_ptr<ShapedBuffer>>>
handle_to_shaped_buffers_ GUARDED_BY(mutex_);
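The comment block above describes refcounting keyed by the raw device address: the original tuple and any view produced by DeconstructTuple each hold a reference, and the memory is released only when the last registered handle is unregistered. A compressed sketch of that bookkeeping, with a plain void*-to-count map standing in for the real opaque_to_allocation_map_:

#include <cassert>
#include <cstdlib>
#include <unordered_map>

// Toy tracker: each registered handle bumps the refcount of the addresses it
// covers; memory is released only when the last handle is unregistered.
class ToyTracker {
 public:
  void AddRef(void* addr) { ++refcounts_[addr]; }

  void DropRef(void* addr) {
    auto it = refcounts_.find(addr);
    assert(it != refcounts_.end());
    if (--it->second == 0) {
      std::free(addr);
      refcounts_.erase(it);
    }
  }

 private:
  std::unordered_map<void*, int> refcounts_;
};

int main() {
  ToyTracker tracker;
  void* buffer = std::malloc(16);
  tracker.AddRef(buffer);   // original tuple registered
  tracker.AddRef(buffer);   // deconstructed view of the same sub-buffer
  tracker.DropRef(buffer);  // unregister the view: memory still alive
  tracker.DropRef(buffer);  // unregister the tuple: now freed
}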
diff --git a/tensorflow/compiler/xla/service/backend.cc b/tensorflow/compiler/xla/service/backend.cc
index a582dbffd6..b1d616ec35 100644
--- a/tensorflow/compiler/xla/service/backend.cc
+++ b/tensorflow/compiler/xla/service/backend.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/compiler/xla/service/computation_layout.cc b/tensorflow/compiler/xla/service/computation_layout.cc
index cb61f3da39..d2d4f14fce 100644
--- a/tensorflow/compiler/xla/service/computation_layout.cc
+++ b/tensorflow/compiler/xla/service/computation_layout.cc
@@ -23,15 +23,12 @@ limitations under the License.
namespace xla {
-ComputationLayout::ComputationLayout(const ProgramShape& program_shape,
- bool ignore_layouts)
+ComputationLayout::ComputationLayout(const ProgramShape& program_shape)
: result_layout_(program_shape.result()) {
for (auto& shape : program_shape.parameters()) {
parameter_layouts_.emplace_back(shape);
}
- if (ignore_layouts) {
- SetToDefaultLayout();
- }
+ SetToDefaultLayout();
}
void ComputationLayout::SetToDefaultLayout() {
diff --git a/tensorflow/compiler/xla/service/computation_layout.h b/tensorflow/compiler/xla/service/computation_layout.h
index 53c3a3f7b7..80e102411c 100644
--- a/tensorflow/compiler/xla/service/computation_layout.h
+++ b/tensorflow/compiler/xla/service/computation_layout.h
@@ -34,9 +34,8 @@ class ComputationLayout {
public:
// Constructs a ComputationLayout from a ProgramShape. The layouts of the
// parameters and results are set to the default layout. Layouts in the
- // ProgramShape are ignored if ignore_layouts is true.
- explicit ComputationLayout(const ProgramShape& program_shape,
- bool ignore_layouts = true);
+ // ProgramShape are ignored.
+ explicit ComputationLayout(const ProgramShape& program_shape);
// Returns the layout of a particular parameter.
const ShapeLayout& parameter_layout(int64 param_no) const {
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index 246b802861..04fda3b2df 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -89,12 +89,10 @@ cc_library(
":cpu_instruction_fusion",
":cpu_layout_assignment",
":cpu_options",
- ":cpu_parallelization_preparation",
":disassembler",
":dot_op_emitter",
":ir_emission_utils",
":ir_emitter",
- ":parallel_cpu_executable",
":parallel_task_assignment",
":simple_orc_jit",
"//tensorflow/compiler/xla:literal_util",
@@ -233,35 +231,6 @@ cc_library(
)
cc_library(
- name = "parallel_cpu_executable",
- srcs = ["parallel_cpu_executable.cc"],
- hdrs = [
- "parallel_cpu_executable.h",
- ],
- deps = [
- ":cpu_runtime",
- ":shape_partition",
- ":simple_orc_jit",
- "//tensorflow/compiler/xla:shape_util",
- "//tensorflow/compiler/xla:status_macros",
- "//tensorflow/compiler/xla:statusor",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/service:buffer_assignment",
- "//tensorflow/compiler/xla/service:device_memory_allocator",
- "//tensorflow/compiler/xla/service:executable",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_execution_profile",
- "//tensorflow/compiler/xla/service:logical_buffer",
- "//tensorflow/compiler/xla/service:shaped_buffer",
- "//tensorflow/core:lib",
- "//tensorflow/core:stream_executor_no_cuda",
- "@llvm//:orc_jit",
- ],
-)
-
-cc_library(
name = "ir_emitter",
srcs = [
"elemental_ir_emitter.cc",
@@ -662,25 +631,6 @@ cc_library(
)
cc_library(
- name = "cpu_parallelization_preparation",
- srcs = ["cpu_parallelization_preparation.cc"],
- hdrs = [
- "cpu_parallelization_preparation.h",
- ],
- deps = [
- ":ir_emission_utils",
- ":parallel_task_assignment",
- ":shape_partition",
- "//tensorflow/compiler/xla:types",
- "//tensorflow/compiler/xla:util",
- "//tensorflow/compiler/xla/service:hlo",
- "//tensorflow/compiler/xla/service:hlo_cost_analysis",
- "//tensorflow/compiler/xla/service:hlo_pass",
- "//tensorflow/core:lib",
- ],
-)
-
-cc_library(
name = "ir_emission_utils",
srcs = ["ir_emission_utils.cc"],
hdrs = ["ir_emission_utils.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index e8472fd36b..3c0c367df3 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -56,12 +56,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_layout_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_options.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
#include "tensorflow/compiler/xla/service/cpu/disassembler.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
-#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
@@ -308,10 +306,7 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
module->config().intra_op_parallelism_threads() > 0
? module->config().intra_op_parallelism_threads()
: tensorflow::port::NumSchedulableCPUs();
- if (options::CpuParallelBackendRequested(module->config())) {
- pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
- ShapeSizeBytesFunction());
- } else if (!is_aot_compile) {
+ if (!is_aot_compile) {
// Run ParallelTaskAssigner to assign parallel tasks to HLOs in module.
// Note this is not run for AOT because it would bring in thread pool
// and thread synchronization dependencies which would likely increase
@@ -329,13 +324,6 @@ Status CpuCompiler::RunHloPasses(HloModule* module, bool is_aot_compile) {
pipeline.AddPass<HloDCE>();
pipeline.AddPass<FlattenCallGraph>();
pipeline.AddPass<CpuCopyInsertion>();
- if (options::CpuParallelBackendRequested(module->config())) {
- // Re-run the outlining, in case any copies were inserted into the entry
- // computation.
- pipeline.AddPass<ParallelizationPreparation>(max_parallelism,
- ShapeSizeBytesFunction());
- pipeline.AddPass<CpuCopyInsertion>();
- }
pipeline.AddPass<HloDCE>();
return pipeline.Run(module).status();
}
@@ -522,190 +510,80 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
const string xla_dump_optimized_hlo_proto_to =
module->config().debug_options().xla_dump_optimized_hlo_proto_to();
- if (options::CpuParallelBackendRequested(module->config())) {
- VLOG(1) << "Using parallel cpu backend";
-
- // Run buffer analysis on the HLO graph. This analysis figures out which
- // temporary buffers are required to run the computation.
- // DependencyHloOrdering is used for the parallel emitter because the order
- // of HLO instruction execution is not known ahead of time.
- // DependencyHloOrdering is the most conservative partial order and only
- // uses data dependencies for determining order.
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()),
- BufferSizeBytesFunction(), memory_alignment));
- // BufferAssignment::ToString() includes a header, so no need for us to
- // print one ourselves.
- XLA_VLOG_LINES(2, assignment->ToString());
-
- if (!xla_dump_optimized_hlo_proto_to.empty()) {
- HloProto proto = MakeHloProto(*module, *assignment);
- TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_optimized_hlo_proto_to, module->name()));
- }
-
- // If we are using the parallel CPU backend, we need to create map from
- // HloInstruction to the corresponding generated function name.
- std::map<HloComputation*, HloInstruction*> parallel_computations;
- std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
- aligned_constants;
- for (auto instruction : entry_computation->MakeInstructionPostOrder()) {
- // Parameters and constants don't get their own computation.
- if (instruction->opcode() == HloOpcode::kParameter) {
- continue;
- }
- if (instruction->opcode() == HloOpcode::kConstant) {
- // Copy the constant out of the ProtocolBuffer so that we can give it a
- // higher alignment.
- const void* data = instruction->literal().untyped_data();
- int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape());
- auto iter = aligned_constants.emplace(
- instruction, xla::MakeUnique<unsigned char[]>(size));
- CHECK_EQ(iter.second, true);
- unsigned char* aligned_data = iter.first->second.get();
- memcpy(aligned_data, data, size);
- continue;
- }
- // The parallel preparation should have ensured that the top-level
- // computation consists solely of Call instructions.
- TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall)
- << module->ToString();
- HloComputation* to_apply = instruction->to_apply();
- parallel_computations.emplace(to_apply, instruction);
- }
-
- IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
- std::move(instruction_to_profile_idx),
- std::move(computation_to_profile_idx),
- jit->target_machine(), jit->external_constant_pool());
-
- std::unique_ptr<HloInstructionMap<string>> function_names(
- new HloInstructionMap<string>());
- for (auto embedded_computation :
- entry_computation->MakeEmbeddedComputationsList()) {
- if (embedded_computation->IsFusionComputation()) {
- continue;
- }
- auto parallel_computation_iter =
- parallel_computations.find(embedded_computation);
- // All parallel computations are considered to be an entry computation for
- // IR generation purposes.
- bool computation_is_parallel =
- parallel_computation_iter != parallel_computations.end();
- TF_ASSIGN_OR_RETURN(
- llvm::Function * ir_function,
- ir_emitter.EmitComputation(
- embedded_computation, embedded_computation->name(),
- /*is_top_level_computation=*/computation_is_parallel,
- /*instruction_order=*/nullptr));
- // If this computation is parallel, remember it in the function name map.
- // This way we know what function to execute when we try to run code for
- // the Call instruction.
- if (computation_is_parallel) {
- HloInstruction* call_instruction = parallel_computation_iter->second;
- InsertOrDie(function_names.get(), call_instruction,
- llvm_ir::AsString(ir_function->getName()));
- }
- }
-
- string ir_module_string;
- if (embed_ir_in_executable) {
- ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
- }
- TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
-
- // JIT compile the LLVM IR module to in-memory machine code.
- jit->AddModule(std::move(llvm_module));
- cpu_executable.reset(new ParallelCpuExecutable(
- std::move(jit), std::move(assignment), std::move(module),
- std::move(function_names), std::move(aligned_constants),
- std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
-
- if (embed_ir_in_executable) {
- static_cast<CpuExecutable&>(*cpu_executable)
- .set_ir_module_string(ir_module_string);
- }
- } else {
- VLOG(1) << "Using sequential cpu backend";
-
- // Select an order for emitting the HLO instructions for each
- // computation. Using this sequence enables tighter buffer liveness analysis
- // and reduced memory usage (as compared to using DependencyHloOrdering).
- TF_ASSIGN_OR_RETURN(
- SequentialHloOrdering::HloModuleSequence module_sequence,
- CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction()));
-
- // Run buffer analysis on the HLO graph. This analysis figures out which
- // temporary buffers are required to run the computation.
- TF_ASSIGN_OR_RETURN(
- std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(module.get(),
- xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment));
- // BufferAssignment::ToString() includes a header, so no need for us to
- // print one ourselves.
- XLA_VLOG_LINES(2, assignment->ToString());
-
- if (!xla_dump_optimized_hlo_proto_to.empty()) {
- HloProto proto = MakeHloProto(*module, *assignment);
- TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
- proto, xla_dump_optimized_hlo_proto_to, module->name()));
- }
-
- // Each computation is a single function. Emit all embedded computations
- // before the entry computation. The order of computations returned from
- // GetEmbeddedComputations guarantees that a called computation occurs
- // before a caller computation.
+ // Select an order for emitting the HLO instructions for each
+ // computation. Using this sequence enables tighter buffer liveness analysis
+ // and reduced memory usage (as compared to using DependencyHloOrdering).
+ TF_ASSIGN_OR_RETURN(
+ SequentialHloOrdering::HloModuleSequence module_sequence,
+ CreateMemoryMinimizingSequence(*module, BufferSizeBytesFunction()));
+
+ // Run buffer analysis on the HLO graph. This analysis figures out which
+ // temporary buffers are required to run the computation.
+ TF_ASSIGN_OR_RETURN(
+ std::unique_ptr<BufferAssignment> assignment,
+ BufferAssigner::Run(
+ module.get(),
+ xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
+ BufferSizeBytesFunction(), memory_alignment));
+ // BufferAssignment::ToString() includes a header, so no need for us to
+ // print one ourselves.
+ XLA_VLOG_LINES(2, assignment->ToString());
+
+ if (!xla_dump_optimized_hlo_proto_to.empty()) {
+ HloProto proto = MakeHloProto(*module, *assignment);
+ TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
+ proto, xla_dump_optimized_hlo_proto_to, module->name()));
+ }
- IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
- std::move(instruction_to_profile_idx),
- std::move(computation_to_profile_idx),
- jit->target_machine(), jit->external_constant_pool());
+ // Each computation is a single function. Emit all embedded computations
+ // before the entry computation. The order of computations returned from
+ // GetEmbeddedComputations guarantees that a called computation occurs
+ // before a caller computation.
- for (auto embedded_computation :
- entry_computation->MakeEmbeddedComputationsList()) {
- if (embedded_computation->IsFusionComputation()) {
- continue;
- }
- TF_RETURN_IF_ERROR(
- ir_emitter
- .EmitComputation(embedded_computation,
- embedded_computation->name(),
- /*is_top_level_computation=*/false,
- &module_sequence.at(embedded_computation))
- .status());
- }
- string function_name_prefix = entry_computation->name().empty()
- ? "__compute"
- : entry_computation->name();
- TF_ASSIGN_OR_RETURN(
- llvm::Function * entry_function,
- ir_emitter.EmitComputation(entry_computation, function_name_prefix,
- /*is_top_level_computation=*/true,
- &module_sequence.at(entry_computation)));
+ IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
+ std::move(instruction_to_profile_idx),
+ std::move(computation_to_profile_idx),
+ jit->target_machine(), jit->external_constant_pool());
- string function_name = llvm_ir::AsString(entry_function->getName());
- string ir_module_string;
- if (embed_ir_in_executable) {
- ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
+ for (auto embedded_computation :
+ entry_computation->MakeEmbeddedComputationsList()) {
+ if (embedded_computation->IsFusionComputation()) {
+ continue;
}
- TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
+ TF_RETURN_IF_ERROR(
+ ir_emitter
+ .EmitComputation(embedded_computation, embedded_computation->name(),
+ /*is_top_level_computation=*/false,
+ &module_sequence.at(embedded_computation))
+ .status());
+ }
+ string function_name_prefix = entry_computation->name().empty()
+ ? "__compute"
+ : entry_computation->name();
+ TF_ASSIGN_OR_RETURN(
+ llvm::Function * entry_function,
+ ir_emitter.EmitComputation(entry_computation, function_name_prefix,
+ /*is_top_level_computation=*/true,
+ &module_sequence.at(entry_computation)));
+
+ string function_name = llvm_ir::AsString(entry_function->getName());
+ string ir_module_string;
+ if (embed_ir_in_executable) {
+ ir_module_string = llvm_ir::DumpModuleToString(*llvm_module);
+ }
+ TF_RETURN_IF_ERROR(VerifyLlvmModule(*llvm_module));
- XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
+ XLA_VLOG_LINES(2, "LLVM IR:\n" + llvm_ir::DumpModuleToString(*llvm_module));
- // JIT compile the LLVM IR module to in-memory machine code.
- jit->AddModule(std::move(llvm_module));
- cpu_executable.reset(new CpuExecutable(
- std::move(jit), std::move(assignment), std::move(module), function_name,
- std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
+ // JIT compile the LLVM IR module to in-memory machine code.
+ jit->AddModule(std::move(llvm_module));
+ cpu_executable.reset(new CpuExecutable(
+ std::move(jit), std::move(assignment), std::move(module), function_name,
+ std::move(hlo_profile_printer_data), std::move(hlo_profile_index_map)));
- if (embed_ir_in_executable) {
- static_cast<CpuExecutable&>(*cpu_executable)
- .set_ir_module_string(ir_module_string);
- }
+ if (embed_ir_in_executable) {
+ static_cast<CpuExecutable&>(*cpu_executable)
+ .set_ir_module_string(ir_module_string);
}
VLOG(1) << "Compilation finished";
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
index 97e550abe4..aabf4d5161 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.cc
@@ -243,14 +243,14 @@ static Status DeallocateTempBuffers(
return Status::OK();
}
-StatusOr<ShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
+StatusOr<ScopedShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
std::vector<bool>* buffers_in_result) {
se::Stream* stream = run_options->stream();
- ShapedBuffer result_buffer(
+ ScopedShapedBuffer result_buffer(
/*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(),
- stream->parent()->platform(), stream->parent()->device_ordinal());
+ run_options->allocator(), stream->parent()->device_ordinal());
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer which is returned to the caller.
@@ -281,7 +281,7 @@ StatusOr<ShapedBuffer> CpuExecutable::CreateResultShapedBuffer(
return std::move(result_buffer);
}
-StatusOr<ShapedBuffer> CpuExecutable::ExecuteOnStream(
+StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
@@ -300,7 +300,7 @@ StatusOr<ShapedBuffer> CpuExecutable::ExecuteOnStream(
std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
TF_ASSIGN_OR_RETURN(
- ShapedBuffer result_buffer,
+ ScopedShapedBuffer result_buffer,
CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
// Free all buffers not in the result.
@@ -310,7 +310,7 @@ StatusOr<ShapedBuffer> CpuExecutable::ExecuteOnStream(
return std::move(result_buffer);
}
-StatusOr<ShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
+StatusOr<ScopedShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
if (hlo_profiling_enabled()) {
@@ -330,7 +330,7 @@ StatusOr<ShapedBuffer> CpuExecutable::ExecuteAsyncOnStream(
std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
TF_ASSIGN_OR_RETURN(
- ShapedBuffer result_buffer,
+ ScopedShapedBuffer result_buffer,
CreateResultShapedBuffer(run_options, buffers, &buffers_in_result));
LogLiveAddresses(buffers, buffers_in_result);
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_executable.h b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
index 06b6943cb5..68ad38cba8 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_executable.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_executable.h
@@ -55,12 +55,12 @@ class CpuExecutable : public Executable {
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
~CpuExecutable() override {}
- StatusOr<ShapedBuffer> ExecuteOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- StatusOr<ShapedBuffer> ExecuteAsyncOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
@@ -102,13 +102,13 @@ class CpuExecutable : public Executable {
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
HloExecutionProfile* hlo_execution_profile);
- // Creates a ShapedBuffer for holding the result of the computation. The
+ // Creates a ScopedShapedBuffer for holding the result of the computation. The
// addresses (DeviceMemoryBases) are set according to buffer assignment.
// 'buffers_in_result' should point to a vector of the same size as
// 'allocated_buffers'. An element in buffers_in_result is set to true if the
// corresponding buffer is live out of the computation (and thus contained in
// the returned ShapedBuffer).
- StatusOr<ShapedBuffer> CreateResultShapedBuffer(
+ StatusOr<ScopedShapedBuffer> CreateResultShapedBuffer(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> allocated_buffers,
std::vector<bool>* buffers_in_result);
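
For readers following the ShapedBuffer -> ScopedShapedBuffer change in this header, a hedged caller-side sketch of why the scoped type matters (RunOnce is a hypothetical helper; only the ExecuteOnStream signature comes from this header): the ScopedShapedBuffer owns its device memory through the allocator it was constructed with, so result buffers are released automatically when it goes out of scope.

#include "tensorflow/compiler/xla/service/cpu/cpu_executable.h"
#include "tensorflow/compiler/xla/status_macros.h"

namespace xla {
namespace cpu {

// Hypothetical helper, not part of this change: run the executable once and
// let the ScopedShapedBuffer release the result's device memory on scope exit.
Status RunOnce(CpuExecutable* executable,
               const ServiceExecutableRunOptions* run_options,
               tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
  TF_ASSIGN_OR_RETURN(
      ScopedShapedBuffer result,
      executable->ExecuteOnStream(run_options, arguments,
                                  /*hlo_execution_profile=*/nullptr));
  // ... consume `result`, e.g. transfer it back to the host ...
  return Status::OK();
}  // `result` is destroyed here; its buffers go back to the allocator.

}  // namespace cpu
}  // namespace xla
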
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
index 09f028463a..f9c51f243c 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc
@@ -19,7 +19,6 @@ limitations under the License.
namespace {
-const char* const kXlaParallelCpuOption = "xla_cpu_parallel";
const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size";
const char* const kXlaDisableVectorizedReduce = "xla_disable_vectorized_reduce";
const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor";
@@ -30,12 +29,6 @@ namespace xla {
namespace cpu {
namespace options {
-bool CpuParallelBackendRequested(const HloModuleConfig& config) {
- const auto& extra_options_map =
- config.debug_options().xla_backend_extra_options();
- return extra_options_map.count(kXlaParallelCpuOption) > 0;
-}
-
bool OptimizeForSizeRequested(const HloModuleConfig& config) {
const auto& extra_options_map =
config.debug_options().xla_backend_extra_options();
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.h b/tensorflow/compiler/xla/service/cpu/cpu_options.h
index 6ba0fd2453..be62ff3cc1 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_options.h
+++ b/tensorflow/compiler/xla/service/cpu/cpu_options.h
@@ -24,7 +24,6 @@ namespace xla {
namespace cpu {
namespace options {
-bool CpuParallelBackendRequested(const HloModuleConfig& config);
bool OptimizeForSizeRequested(const HloModuleConfig& config);
bool VectorizedReduceDisabled(const HloModuleConfig& config);
tensorflow::gtl::optional<int64> LlvmIrGemvTilingFactor(
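
The predicates declared here all share one lookup pattern, the same one used by the removed CpuParallelBackendRequested: an option counts as set when its key appears in the module's xla_backend_extra_options map. A small sketch of that pattern (BackendExtraOptionPresent is a hypothetical helper shown only to make the pattern explicit):

#include <string>

#include "tensorflow/compiler/xla/service/hlo_module_config.h"

// Hypothetical helper: returns true when `key` is present in the module's
// --xla_backend_extra_options map, which is how the remaining predicates in
// cpu_options.cc decide whether their option is enabled.
bool BackendExtraOptionPresent(const xla::HloModuleConfig& config,
                               const std::string& key) {
  const auto& extra_options_map =
      config.debug_options().xla_backend_extra_options();
  return extra_options_map.count(key) > 0;
}

// e.g. BackendExtraOptionPresent(config, "xla_cpu_optimize_for_size") computes
// the same thing as OptimizeForSizeRequested(config).
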
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
deleted file mode 100644
index 662ee60923..0000000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.cc
+++ /dev/null
@@ -1,192 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h"
-
-#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
-#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-
-namespace xla {
-namespace cpu {
-
-StatusOr<bool> ParallelizationPreparation::Run(HloModule* module) {
- XLA_VLOG_LINES(2, "ParallelizationPreparation ENTRY");
- XLA_VLOG_LINES(2, module->ToString());
-
- bool changed = false;
- TF_ASSIGN_OR_RETURN(changed, RunParallelTaskAssignment(module));
-
- HloComputation* entry_computation = module->entry_computation();
- std::unordered_set<HloInstruction*> outlined;
- std::vector<HloInstruction*> instructions_to_outline;
- for (HloInstruction* instruction :
- entry_computation->MakeInstructionPostOrder()) {
- // If the instruction has been outlined, it no longer exists and we must not
- // dereference it.
- if (outlined.count(instruction) > 0) {
- continue;
- }
-
- // Skip parameters and constants, there is nothing to parallelize.
- if (instruction->opcode() == HloOpcode::kParameter ||
- instruction->opcode() == HloOpcode::kConstant) {
- continue;
- }
-
- // Outline 'instruction' in isolation if it was assigned parallel tasks.
- if (OutlineParallelizableInstruction(instruction)) {
- outlined.insert(instruction);
- changed = true;
- continue;
- }
-
- instructions_to_outline.clear();
- HloInstruction* outline_candidate = instruction;
- instructions_to_outline.push_back(outline_candidate);
-
- // Outline sole users with the current instruction.
- while (CanOutlineWithUser(outline_candidate)) {
- HloInstruction* prior_candidate = outline_candidate;
- outline_candidate = *outline_candidate->users().begin();
- if (std::any_of(outline_candidate->operands().begin(),
- outline_candidate->operands().end(),
- [&](const HloInstruction* operand) {
- // Do not consider any candidates which have operands
- // other than the prior candidate, constants or
- // parameters. Otherwise, we'd increase the fan-in which
- // would reduce parallelism.
- return operand->opcode() != HloOpcode::kParameter &&
- operand->opcode() != HloOpcode::kConstant &&
- operand != prior_candidate;
- })) {
- break;
- }
- instructions_to_outline.push_back(outline_candidate);
- }
-
- outlined.insert(instructions_to_outline.begin(),
- instructions_to_outline.end());
-
- // Optimization to avoid replacing a single existing kCall with another
- // kCall that just calls the first one.
- if (instructions_to_outline.size() == 1 &&
- instructions_to_outline[0]->opcode() == HloOpcode::kCall) {
- continue;
- }
-
- module->OutlineExpressionFromComputation(
- instructions_to_outline,
- tensorflow::strings::StrCat("pp_", instruction->name()),
- entry_computation);
- changed = true;
- }
-
- XLA_VLOG_LINES(2, "ParallelizationPreparation EXIT");
- XLA_VLOG_LINES(2, module->ToString());
- return changed;
-}
-
-StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
- HloModule* module) {
- VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_;
- bool changed = false;
- // Initialize ParallelTaskAssignment.
- ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_,
- module);
- // Assign parallel tasks to HLOs in entry computation.
- HloComputation* computation = module->entry_computation();
- for (auto* instruction : computation->instructions()) {
- // Calculate target parallel task count in [1, max_parallelism_].
- const int64 target_parallel_task_count =
- parallel_task_assignment.GetTargetParallelTaskCount(instruction);
- if (target_parallel_task_count == 1) {
- continue;
- }
-
- // Assign feasible dimension partitions (based on actual dimension sizes).
- auto dim_partition_counts = ShapePartitionAssigner(instruction->shape())
- .Run(target_parallel_task_count);
- const int64 total_partition_count =
- ShapePartitionAssigner::GetTotalPartitionCount(dim_partition_counts);
- if (total_partition_count <= 1) {
- // Feasible partition calculation resulting in no partitioning, so skip.
- continue;
- }
- VLOG(2) << "Assigning parallel task count: " << total_partition_count
- << " to instruction: " << instruction->name();
- // Map 'instruction' to assigned dimension partitioning.
- instruction->set_outer_dimension_partitions(dim_partition_counts);
- }
-
- return changed;
-}
-
-bool ParallelizationPreparation::OutlineParallelizableInstruction(
- HloInstruction* instruction) {
- if (instruction->outer_dimension_partitions().empty()) {
- return false;
- }
- // Store dimension partition counts before outlining (which clones
- // 'instruction').
- std::vector<int64> dim_partition_counts =
- instruction->outer_dimension_partitions();
- // Outline 'instruction' in its own sub-computation.
- HloModule* module = instruction->parent()->parent();
- auto* call = module->OutlineExpressionFromComputation(
- {instruction}, tensorflow::strings::StrCat("pp_", instruction->name()),
- module->entry_computation());
- // Map previously assigned 'dim_partition_counts' to cloned root instruction.
- VLOG(1) << "Outlining parallelizable"
- << " caller: " << call->name()
- << " callee: " << call->to_apply()->root_instruction()->name();
- call->to_apply()->root_instruction()->set_outer_dimension_partitions(
- dim_partition_counts);
- return true;
-}
-
-bool ParallelizationPreparation::CanOutlineWithUser(
- HloInstruction* instruction) {
- if (instruction->users().size() != 1) {
- // Do not outline 'instruction' with multiple users.
- return false;
- }
- if (AssignedParallelTasks(instruction) ||
- AssignedParallelTasks(*instruction->users().begin())) {
- // Do not outline if 'instruction' (or user) were assigned parallel tasks.
- return false;
- }
- return true;
-}
-
-bool ParallelizationPreparation::AssignedParallelTasks(
- HloInstruction* instruction) {
- return !instruction->outer_dimension_partitions().empty() ||
- (instruction->opcode() == HloOpcode::kCall &&
- !instruction->to_apply()
- ->root_instruction()
- ->outer_dimension_partitions()
- .empty());
-}
-
-} // namespace cpu
-} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h b/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
deleted file mode 100644
index 87be758ef5..0000000000
--- a/tensorflow/compiler/xla/service/cpu/cpu_parallelization_preparation.h
+++ /dev/null
@@ -1,80 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
-
-#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_pass_interface.h"
-
-namespace xla {
-namespace cpu {
-
-// This pass prepares an HLO module for parallel execution by transforming
-// subgraphs of the top-level computation into embedded computations which can
-// be executed in parallel.
-// TODO(b/29630486): Currently, it is limited to turning all instructions (which
-// are not constants or parameters) in the entry computation into embedded
-// computations. However, it could make sense to coarsen the parallelization to
-// improve cache locality. Also, we will need to do something to intelligently
-// handle While constructs.
-class ParallelizationPreparation : public HloPassInterface {
- public:
- // 'max_parallelism': the maximum parallel task count per instruction.
- // 'shape_size': shape size function used by HloCostAnalysis during parallel
- // task assignment.
- ParallelizationPreparation(
- const int64 max_parallelism,
- const HloCostAnalysis::ShapeSizeFunction& shape_size)
- : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
- ~ParallelizationPreparation() override {}
-
- tensorflow::StringPiece name() const override {
- return "cpu-parallel-prepare";
- }
-
- // Run parallel preparation on the given computation. Returns whether the
- // computation was changed.
- StatusOr<bool> Run(HloModule* module) override;
-
- private:
- // Assigns parallel task partitions to conformant instructions in 'module'.
- // Returns true on success or error status otherwise.
- StatusOr<bool> RunParallelTaskAssignment(HloModule* module);
-
- // Outlines 'instruction' from entry computation, if it had
- // been assigned parallel tasks in an earlier pass through the computation.
- // Returns true if 'instruction' was successfully outlined, false otherwise.
- bool OutlineParallelizableInstruction(HloInstruction* instruction);
-
- // Returns true if 'instruction' can be outlined into the same sub-computation
- // with its single user (parallelizable instructions are not outlined with
- // each other). Returns false otherwise.
- bool CanOutlineWithUser(HloInstruction* instruction);
-
- // Returns true if 'instruction' (or the root of the sub-computation that
- // 'instruction' calls) has had parallel tasks assigned in earlier pass.
- // Returns false otherwise.
- bool AssignedParallelTasks(HloInstruction* instruction);
-
- const int64 max_parallelism_;
- const HloCostAnalysis::ShapeSizeFunction shape_size_;
-};
-
-} // namespace cpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_CPU_PARALLELIZATION_PREPARATION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index f990ee2785..0b08ad8da3 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -93,8 +93,6 @@ IrEmitter::IrEmitter(
computation_to_profile_idx_(std::move(computation_to_profile_idx)),
alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
hlo_module_config_(hlo_module.config()),
- parallel_cpu_backend_(
- options::CpuParallelBackendRequested(hlo_module_config_)),
is_top_level_computation_(false),
target_machine_features_(target_machine),
external_constant_pool_(external_constant_pool) {
@@ -2163,8 +2161,7 @@ Status IrEmitter::HandleCall(HloInstruction* call) {
TF_RETURN_IF_ERROR(EmitTargetAddressForOp(call));
- if (!computation->root_instruction()->outer_dimension_partitions().empty() &&
- !parallel_cpu_backend_) {
+ if (!computation->root_instruction()->outer_dimension_partitions().empty()) {
// ParallelTaskAssignment assigned partitions, emit call to
// ParallelForkJoin.
std::vector<llvm::Value*> call_args = GetArrayFunctionCallArguments(
@@ -2550,22 +2547,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) {
}
};
- // For the parallel cpu backend, we record the total for each embedded
- // computation callee with its caller kCall HLO.
- if (parallel_cpu_backend_ && is_top_level_computation_) {
- auto* computation = root->parent();
- auto* entry_computation = computation->parent()->entry_computation();
- if (computation != entry_computation) {
- for (HloInstruction* instruction : entry_computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kCall &&
- instruction->to_apply()->root_instruction() == root) {
- record_complete_computation(GetProfileCounterFor(*instruction));
- return Status::OK();
- }
- }
- }
- }
-
// For the entry computation this increment is cumulative of embedded
// computations since it includes cycles spent in computations invoked by
// While, Call etc.
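
The HandleCall path above now emits a ParallelForkJoin whenever the callee root carries outer_dimension_partitions. To make that encoding concrete, here is a standalone toy (plain C++, not XLA code; the shape and split counts are invented): the per-dimension split counts multiply into the parallel task count, and each task gets a [start, limit) range per partitioned dimension, analogous to the partition buffers built by Executor::GetPartitionBuffers in the deleted parallel_cpu_executable.cc further down.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> dims = {1024, 512};  // hypothetical output shape
  const std::vector<int64_t> splits = {4, 2};     // hypothetical partition counts
  int64_t tasks = 1;
  for (int64_t s : splits) tasks *= s;            // 4 * 2 = 8 parallel tasks
  std::cout << "parallel tasks: " << tasks << "\n";
  for (size_t d = 0; d < dims.size(); ++d) {
    const int64_t chunk = dims[d] / splits[d];
    for (int64_t i = 0; i < splits[d]; ++i) {
      std::cout << "dim " << d << " task " << i << ": [" << i * chunk << ", "
                << (i + 1) * chunk << ")\n";
    }
  }
  return 0;
}
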
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 5094402514..0f2f3d1817 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -532,8 +532,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
const HloModuleConfig& hlo_module_config_;
- const bool parallel_cpu_backend_;
-
bool is_top_level_computation_;
TargetMachineFeatures target_machine_features_;
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
deleted file mode 100644
index a2bd4fa195..0000000000
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.cc
+++ /dev/null
@@ -1,528 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h"
-
-#include <stdint.h>
-#include <algorithm>
-#include <deque>
-#include <iterator>
-#include <list>
-#include <unordered_set>
-#include <utility>
-#include <vector>
-
-#include "llvm/ExecutionEngine/Orc/IRCompileLayer.h"
-#include "tensorflow/compiler/xla/map_util.h"
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
-#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/shaped_buffer.h"
-#include "tensorflow/compiler/xla/shape_util.h"
-#include "tensorflow/compiler/xla/status_macros.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/compiler/xla/util.h"
-#include "tensorflow/compiler/xla/xla_data.pb.h"
-#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/lib/strings/stringprintf.h"
-#include "tensorflow/core/platform/env.h"
-#include "tensorflow/core/platform/logging.h"
-#include "tensorflow/core/platform/mem.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/types.h"
-
-namespace xla {
-namespace cpu {
-
-ParallelCpuExecutable::ParallelCpuExecutable(
- std::unique_ptr<SimpleOrcJIT> jit,
- std::unique_ptr<const BufferAssignment> assignment,
- std::unique_ptr<const HloModule> hlo_module,
- std::unique_ptr<const HloInstructionMap<string>> function_names,
- std::unordered_map<const HloInstruction*, std::unique_ptr<unsigned char[]>>
- aligned_constants,
- std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
- std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map)
- : Executable(std::move(hlo_module), std::move(hlo_profile_printer_data),
- std::move(hlo_profile_index_map)),
- jit_(std::move(jit)),
- assignment_(std::move(assignment)),
- function_names_(std::move(function_names)),
- aligned_constants_(std::move(aligned_constants)) {}
-
-// Type of the computation function we expect in the JIT.
-using ComputeFunctionType = void (*)(void*, const void*, const void**, void**,
- int64*, int64*);
-
-// Given a pointer to an output buffer (following the CPU JIT calling
-// conventions), mark addresses that are "live". The initial pointer itself is
-// trivially live. If the shape of the buffer is a tuple, this analysis looks
-// into the tuple's elements and marks them live as well (since tuples keep
-// pointers to buffers) and also works recursively.
-// address is an in-memory buffer address that contains some runtime XLA object.
-// shape is its shape. marked_addresses is the set of live addresses to
-// populate.
-static void MarkLiveAddressesInOutput(
- const void* address, const Shape& shape,
- std::unordered_set<const void*>* marked_addresses) {
- marked_addresses->insert(address);
- const uintptr_t* address_buffer = static_cast<const uintptr_t*>(address);
- if (ShapeUtil::IsTuple(shape)) {
- for (int i = 0; i < ShapeUtil::TupleElementCount(shape); ++i) {
- const uintptr_t* element_address = address_buffer + i;
- const void* element = reinterpret_cast<const void*>(*element_address);
- MarkLiveAddressesInOutput(
- element, ShapeUtil::GetTupleElementShape(shape, i), marked_addresses);
- }
- }
-}
-
-namespace {
-
-// Executor manages the concurrent execution of 'functions' for instructions
-// in 'pending' on 'thread_pool' (storing resulting data in 'results').
-class Executor {
- public:
- Executor(const HloInstructionMap<ComputeFunctionType>& functions,
- const ServiceExecutableRunOptions* run_options,
- std::list<HloInstruction*>* pending,
- HloInstructionMap<const void*>* results, void** temps_array,
- int64* profile_counters_array, const BufferAssignment* assignment)
- : functions_(functions),
- run_options_(run_options),
- pending_(pending),
- results_(results),
- temps_array_(temps_array),
- profile_counters_array_(profile_counters_array),
- thread_pool_(CHECK_NOTNULL(run_options_->xla_intra_op_thread_pool())),
- assignment_(assignment) {}
-
- // Executes pending list of instructions on thread pool.
- // Returns OK status on success, error status otherwise.
- Status Run();
-
- private:
- // Schedules a parallel invocation of compute function for 'instruction' on
- // 'thread_pool_', storing result in 'result_buffer'.
- // If 'partition_buffers' is non-null, parallel task will be invoked on
- // per-dimension partition [start, limit) values stored in
- // 'partition_buffers'.
- void Schedule(HloInstruction* instruction, int64* partition_buffers,
- void* result_buffer);
-
- // Returns true if 'instruction' has been assigned parallel tasks (returns
- // false otherwise).
- bool HasParallelTasks(HloInstruction* instruction);
-
- // Returns in 'partition_buffers' the partition [size, limit) for each
- // dimension.
- int64* GetPartitionBuffers(
- const std::vector<std::pair<int64, int64>>& partition);
-
- // Returns array of result buffers for all operands in 'instruction'.
- const void** GetOperandBuffers(HloInstruction* instruction);
-
- // Arguments passed into Executor.
- const HloInstructionMap<ComputeFunctionType>& functions_;
- const ServiceExecutableRunOptions* run_options_;
- std::list<HloInstruction*>* pending_;
- HloInstructionMap<const void*>* results_;
- void** temps_array_;
- int64* profile_counters_array_;
- tensorflow::thread::ThreadPool* thread_pool_;
- const BufferAssignment* assignment_;
-
- // Members used to manage instruction execution.
- tensorflow::mutex completion_queue_lock_;
- tensorflow::condition_variable completion_queue_cv_;
- std::deque<HloInstruction*> completion_queue_;
- int64 instructions_in_flight_ = 0;
- std::unordered_map<const HloInstruction*, int64> tasks_in_flight_;
-};
-
-Status Executor::Run() {
- while (!pending_->empty() || instructions_in_flight_ > 0) {
- auto pending_it = pending_->begin();
- while (pending_it != pending_->end()) {
- HloInstruction* instruction = *pending_it;
- // Skip pending instructions whose operands aren't ready.
- if (std::any_of(instruction->operands().begin(),
- instruction->operands().end(),
- [&](HloInstruction* operand) {
- return !ContainsKey(*results_, operand);
- })) {
- ++pending_it;
- continue;
- }
-
- // Get 'result_buffer' reference to result buffer for 'instruction'.
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
- assignment_->GetUniqueTopLevelSlice(instruction));
- void* result_buffer =
- static_cast<char*>(temps_array_[result_slice.index()]) +
- result_slice.offset();
-
- if (HasParallelTasks(instruction)) {
- // 'instruction' has been assigned parallel task partitions.
- CHECK_EQ(HloOpcode::kCall, instruction->opcode());
- HloInstruction* root = instruction->to_apply()->root_instruction();
-
- // Create ShapePartitionIterator to iterate through all outer dimension
- // partitions of 'instruction'.
- ShapePartitionIterator partition_iterator(
- root->shape(), root->outer_dimension_partitions());
-
- const int64 partition_count =
- partition_iterator.GetTotalPartitionCount();
-
- // Record total parallel task count for 'instruction' before dispatch.
- {
- tensorflow::mutex_lock l(completion_queue_lock_);
- tasks_in_flight_.insert(std::make_pair(instruction, partition_count));
- VLOG(2) << "Schedule PARALLEL"
- << " instruction: " << instruction->name()
- << " instruction.callee: "
- << instruction->to_apply()->root_instruction()->name()
- << " partition_count: " << partition_count;
- }
-
- for (int64 i = 0; i < partition_count; ++i) {
- // Get partition [start, limit) for each dimension.
- auto partition_buffers =
- GetPartitionBuffers(partition_iterator.GetPartition(i));
- Schedule(instruction, partition_buffers, result_buffer);
- }
-
- } else {
- // Set tasks in-flight to '1' for sequential instruction execution.
- {
- tensorflow::mutex_lock l(completion_queue_lock_);
- tasks_in_flight_.insert(std::make_pair(instruction, 1));
- VLOG(2) << "Schedule SEQUENTIAL"
- << " instruction: " << instruction->name()
- << " instruction.callee: "
- << instruction->to_apply()->root_instruction()->name();
- }
- Schedule(instruction, nullptr, result_buffer);
- }
-
- ++instructions_in_flight_;
- pending_it = pending_->erase(pending_it);
- }
- // Wait for a completed HLO instruction to be present in the queue. We will
- // pop it out of the queue and make the result available to its users.
- HloInstruction* instruction;
- do {
- tensorflow::mutex_lock l(completion_queue_lock_);
- if (completion_queue_.empty()) {
- completion_queue_cv_.wait(l);
- }
- if (!completion_queue_.empty()) {
- instruction = completion_queue_.front();
- completion_queue_.pop_front();
- break;
- }
- } while (true);
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
- assignment_->GetUniqueTopLevelSlice(instruction));
- void* result_buffer =
- static_cast<char*>(temps_array_[result_slice.index()]) +
- result_slice.offset();
- InsertOrDie(results_, instruction, result_buffer);
- --instructions_in_flight_;
- }
- return Status::OK();
-}
-
-void Executor::Schedule(HloInstruction* instruction, int64* partition_buffers,
- void* result_buffer) {
- // The thread pool entry takes ownership of |operand_buffers|.
- auto operand_buffers = GetOperandBuffers(instruction);
-
- auto function = FindOrDie(functions_, instruction);
- const auto* exec_run_options = &run_options_->run_options();
- thread_pool_->Schedule([this, instruction, result_buffer, operand_buffers,
- partition_buffers, exec_run_options, function]() {
- function(result_buffer, exec_run_options, operand_buffers, temps_array_,
- partition_buffers, profile_counters_array_);
-
- delete[] operand_buffers;
- delete[] partition_buffers;
- // Push the completed HLO instruction on the queue, the main
- // thread will pop it off and potentially launch more work which
- // uses the result.
- // TODO(b/27458679) Consider alternative task scheduling and synchronization
- // schemes. For example, we could avoid the overhead associate with the
- // condvar here if the thread just dequed the next instruction to execute
- // on completion.
- {
- tensorflow::mutex_lock l(completion_queue_lock_);
- // Decrement in-flight task count for this completion.
- if (--FindOrDie(tasks_in_flight_, instruction) == 0) {
- completion_queue_.push_back(instruction);
- completion_queue_cv_.notify_all();
- tasks_in_flight_.erase(instruction);
- }
- }
- });
-}
-
-int64* Executor::GetPartitionBuffers(
- const std::vector<std::pair<int64, int64>>& partition) {
- // Return in 'partition_buffers' partition [size, limit) for each dimension.
- auto partition_buffers = new int64[partition.size() * 2];
- for (int i = 0; i < partition.size(); ++i) {
- partition_buffers[2 * i + 0] = partition[i].first;
- partition_buffers[2 * i + 1] = partition[i].first + partition[i].second;
- }
- return partition_buffers;
-}
-
-bool Executor::HasParallelTasks(HloInstruction* instruction) {
- return instruction->opcode() == HloOpcode::kCall &&
- !instruction->to_apply()
- ->root_instruction()
- ->outer_dimension_partitions()
- .empty();
-}
-
-const void** Executor::GetOperandBuffers(HloInstruction* instruction) {
- // We cannot use a move-only RAII type like std::unique_ptr because the
- // list of operands is allocated on the main thread and transferred to the
- // worker via the lambda passed to enqueue_function. In order for the
- // lambda to take ownership, we would need to use generalized lambda
- // capture which is a feature new to C++14.
- // TODO(b/27458679) Avoid dynamic allocations in Executor.
- auto operand_buffers = new const void*[instruction->operand_count()];
- std::transform(instruction->operands().begin(), instruction->operands().end(),
- operand_buffers, [this](HloInstruction* operand) {
- return FindOrDie(*results_, operand);
- });
- return operand_buffers;
-}
-
-} // namespace
-
-Status ParallelCpuExecutable::AllocateBuffers(
- DeviceMemoryAllocator* memory_allocator, int device_ordinal,
- std::vector<se::DeviceMemoryBase>* buffers) {
- CHECK_EQ(buffers->size(), assignment_->Allocations().size());
- VLOG(3) << "Allocating " << assignment_->Allocations().size()
- << " allocations for module " << module().name();
- for (BufferAllocation::Index i = 0; i < assignment_->Allocations().size();
- ++i) {
- auto& allocation = assignment_->GetAllocation(i);
-
- VLOG(3) << allocation.ToString();
-
- if (allocation.is_entry_computation_parameter()) {
- VLOG(3) << "allocation #" << i << " is a parameter";
- continue;
- }
-
- if (allocation.is_thread_local()) {
- VLOG(3) << "buffer #" << i << " is thread-local";
- continue;
- }
-
- int64 buffer_size = allocation.size();
- if (!(*buffers)[i].is_null()) {
- VLOG(3) << "buffer #" << i
- << " is in the preallocated result ShapedBuffer";
- } else {
- TF_ASSIGN_OR_RETURN((*buffers)[i], memory_allocator->Allocate(
- device_ordinal, buffer_size));
-
- VLOG(3) << "buffer #" << i << " allocated " << buffer_size << " bytes ["
- << (*buffers)[i].opaque() << "]";
- }
-
- // Since the output buffer and all the temporary buffers were written into
- // by the JITed code, msan has no way of knowing their memory was
- // initialized. Mark them initialized so that msan doesn't flag loads from
- // these buffers.
- TF_ANNOTATE_MEMORY_IS_INITIALIZED((*buffers)[i].opaque(), buffer_size);
- }
-
- TF_ASSIGN_OR_RETURN(const BufferAllocation::Slice result_slice,
- assignment_->GetUniqueTopLevelOutputSlice());
- VLOG(3) << "result index: " << result_slice.index();
-
- return Status::OK();
-}
-
-Status ParallelCpuExecutable::ExecuteComputeFunctions(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- HloExecutionProfile* hlo_execution_profile) {
- // Allocate profiling counters for each hlo instruction that we would like to
- // profile.
- std::vector<int64>* profile_counters = nullptr;
- if (hlo_execution_profile) {
- profile_counters = hlo_execution_profile->mutable_profile_counters();
- }
-
- std::vector<void*> buffer_pointers;
- buffer_pointers.reserve(buffers.size());
- for (auto device_allocation : buffers) {
- buffer_pointers.push_back(device_allocation.opaque());
- }
-
- // Resolve functions for all the HLO instructions ahead of time.
- HloInstructionMap<ComputeFunctionType> functions;
- for (auto& entry : *function_names_) {
- tensorflow::mutex_lock lock(jit_mutex_);
- HloInstruction* instruction = entry.first;
- llvm::JITSymbol sym = jit_->FindCompiledSymbol(entry.second);
- TF_RET_CHECK(sym);
- InsertOrDie(
- &functions, instruction,
- reinterpret_cast<ComputeFunctionType>(cantFail(sym.getAddress())));
- }
-
- // Map containing pointers to result buffers for each instruction.
- HloInstructionMap<const void*> results;
-
- uint64 start_micros = tensorflow::Env::Default()->NowMicros();
-
- std::list<HloInstruction*> pending;
-
- // Call the function for each HLO instruction in topological order.
- const HloComputation& entry_computation = *module().entry_computation();
- for (auto* instruction : entry_computation.MakeInstructionPostOrder()) {
- // Parameters and constants have no functions associated with them. Instead
- // just copy the existing buffer into the map containing instruction
- // results..
- if (instruction->opcode() == HloOpcode::kParameter) {
- InsertOrDie(
- &results, instruction,
- arguments[instruction->parameter_number()]->root_buffer().opaque());
- } else if (instruction->opcode() == HloOpcode::kConstant) {
- unsigned char* aligned_data =
- FindOrDie(aligned_constants_, instruction).get();
- InsertOrDie(&results, instruction, aligned_data);
- } else {
- TF_RET_CHECK(instruction->opcode() == HloOpcode::kCall);
- pending.push_back(instruction);
- }
- }
-
- // TODO(b/27458679) Manage scheduling based on in-flight concurrency limits.
- // For example, if we expect a library conv/matmul call to run at max
- // concurrency, we should not dispatch runnable instructions until the
- // library call is finished (to avoid expensive cache invalidation).
- Executor executor(
- functions, run_options, &pending, &results, buffer_pointers.data(),
- profile_counters ? profile_counters->data() : nullptr, assignment_.get());
-
- TF_RETURN_IF_ERROR(executor.Run());
-
- uint64 end_micros = tensorflow::Env::Default()->NowMicros();
-
- {
- tensorflow::mutex_lock lock(mutex_);
- double nanoseconds = (end_micros - start_micros) * 1000.0;
- execution_profile_.set_compute_time_ns(std::max(nanoseconds, 1.0));
- }
-
- return Status::OK();
-}
-
-StatusOr<ShapedBuffer> ParallelCpuExecutable::ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- HloExecutionProfile* hlo_execution_profile) {
- if (GetRootPointsToSet().IsAmbiguous()) {
- return Unimplemented("Points-to set of root instruction is ambiguous");
- }
-
- se::Stream* stream = run_options->stream();
- DeviceMemoryAllocator* memory_allocator = run_options->allocator();
- std::vector<se::DeviceMemoryBase> buffers(assignment_->Allocations().size());
-
- ShapedBuffer result_buffer(
- /*on_host_shape=*/result_shape(), /*on_device_shape=*/result_shape(),
- stream->parent()->platform(), stream->parent()->device_ordinal());
-
- TF_RETURN_IF_ERROR(AllocateBuffers(
- memory_allocator, stream->parent()->device_ordinal(), &buffers));
-
- TF_RETURN_IF_ERROR(ExecuteComputeFunctions(run_options, arguments, buffers,
- hlo_execution_profile));
-
- // Copy DeviceMemoryBase values which into the respective location in
- // ShapedBuffer which is returned to the caller.
- std::vector<bool> buffers_in_result(assignment_->Allocations().size(), false);
- TF_RETURN_IF_ERROR(result_buffer.buffers().ForEachMutableElementWithStatus(
- [&](const ShapeIndex& index, se::DeviceMemoryBase* device_memory) {
- const auto& sources = this->GetRootPointsToSet().element(index);
-
- // The points to set is unambiguous so the set should be a singleton.
- CHECK_EQ(1, sources.size());
- const LogicalBuffer* buffer_source = sources[0];
- HloInstruction* src = buffer_source->instruction();
-
- // The source for this result buffer can be a nested buffer such as a
- // tuple element. The source instruction should have a non-parameter
- // buffer assigned.
- TF_ASSIGN_OR_RETURN(
- const BufferAllocation::Slice slice,
- this->assignment_->GetUniqueSlice(src, buffer_source->index()));
- CHECK(!slice.allocation()->is_entry_computation_parameter());
-
- const BufferAllocation::Index buffer_index = slice.index();
- const se::DeviceMemoryBase& buffer = buffers[buffer_index];
- CHECK(!buffer.is_null() || buffer.size() == 0);
- *device_memory = buffer;
- buffers_in_result[buffer_index] = true;
- return Status::OK();
- }));
-
- // Free all buffers not in the result.
- for (size_t i = 0; i < buffers.size(); ++i) {
- se::DeviceMemoryBase alloc = buffers[i];
- if (!buffers_in_result[i] && !alloc.is_null()) {
- VLOG(3) << "CpuExecutable deallocating buffer #" << i << " ["
- << alloc.opaque() << "]";
- TF_RETURN_IF_ERROR(memory_allocator->Deallocate(
- stream->parent()->device_ordinal(), &alloc));
- }
- }
-
- return std::move(result_buffer);
-}
-
-StatusOr<ShapedBuffer> ParallelCpuExecutable::ExecuteAsyncOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
- // TODO(b/30671675): Implement asynchronous execution mode.
- return Unimplemented(
- "Asynchronous execution on stream is not yet supported on CPU.");
-}
-
-const PointsToSet& ParallelCpuExecutable::GetRootPointsToSet() const {
- return assignment_->points_to_analysis().GetPointsToSet(
- module().entry_computation()->root_instruction());
-}
-
-} // namespace cpu
-} // namespace xla
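
The Executor removed above coordinated worker completions with a tensorflow::mutex, a condition_variable, and a deque. For readers unfamiliar with that idiom, a trimmed standalone sketch using plain standard-library types (the int work item stands in for an HloInstruction*):

#include <condition_variable>
#include <deque>
#include <mutex>

// Sketch of the completion-queue idiom from the deleted Executor: workers push
// finished items under a mutex and notify; the scheduling thread blocks on the
// condition variable and pops completions as they arrive.
class CompletionQueue {
 public:
  void Push(int item) {  // called from worker threads
    std::lock_guard<std::mutex> lock(mu_);
    queue_.push_back(item);
    cv_.notify_all();
  }

  int BlockingPop() {  // called from the scheduling thread
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    int item = queue_.front();
    queue_.pop_front();
    return item;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::deque<int> queue_;
};
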
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h b/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
deleted file mode 100644
index 5ce84fa996..0000000000
--- a/tensorflow/compiler/xla/service/cpu/parallel_cpu_executable.h
+++ /dev/null
@@ -1,137 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
-#define TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
-
-#include <stddef.h>
-#include <map>
-#include <memory>
-#include <string>
-#include <unordered_map>
-
-#include "tensorflow/compiler/xla/service/buffer_assignment.h"
-#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
-#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
-#include "tensorflow/compiler/xla/service/executable.h"
-#include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-#include "tensorflow/compiler/xla/service/shaped_buffer.h"
-#include "tensorflow/compiler/xla/statusor.h"
-#include "tensorflow/compiler/xla/types.h"
-#include "tensorflow/core/lib/gtl/array_slice.h"
-#include "tensorflow/core/platform/macros.h"
-#include "tensorflow/core/platform/mutex.h"
-#include "tensorflow/core/platform/stream_executor_no_cuda.h"
-#include "tensorflow/core/platform/thread_annotations.h"
-
-namespace xla {
-namespace cpu {
-
-// CPU-targeting parallel implementation of the XLA Executable interface.
-//
-// Wraps a JIT-ed object that can be executed "on device". We JIT for the host
-// architecture, so JIT-ed code and host code share the same ABI.
-class ParallelCpuExecutable : public Executable {
- public:
- ParallelCpuExecutable(
- std::unique_ptr<SimpleOrcJIT> jit,
- std::unique_ptr<const BufferAssignment> assignment,
- std::unique_ptr<const HloModule> hlo_module,
- std::unique_ptr<const HloInstructionMap<string>> function_names,
- std::unordered_map<const HloInstruction*,
- std::unique_ptr<unsigned char[]>>
- aligned_constants,
- std::unique_ptr<HloProfilePrinterData> hlo_profile_printer_data,
- std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map);
- ~ParallelCpuExecutable() override {}
-
- StatusOr<ShapedBuffer> ExecuteOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- HloExecutionProfile* hlo_execution_profile) override;
-
- StatusOr<ShapedBuffer> ExecuteAsyncOnStream(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
-
- // This should be called after set_ir_module_string.
- const string& ir_module_string() const { return ir_module_string_; }
-
- void set_ir_module_string(const string& ir_module_string) {
- ir_module_string_ = ir_module_string;
- }
-
- static int64 ShapeSizeBytes(const Shape& shape) {
- // On the cpu, opaques are pointers.
- if (ShapeUtil::IsOpaque(shape)) {
- return sizeof(void*);
- }
- return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
- }
-
- private:
- // Allocate buffers required for execution and assign them to the elements of
- // "buffers". "buffers" should be sized to the number of buffers in buffer
- // assignment. Each vector element corresponds to a particular Index. If
- // a vector element already contains a non-null DeviceMemoryBase, then no
- // buffer is assigned for this element.
- Status AllocateBuffers(DeviceMemoryAllocator* memory_allocator,
- int device_ordinal,
- std::vector<se::DeviceMemoryBase>* buffers);
-
- // Calls the generated functions in 'function_names_', performing the
- // computation with the given arguments using the supplied buffers.
- Status ExecuteComputeFunctions(
- const ServiceExecutableRunOptions* run_options,
- tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
- tensorflow::gtl::ArraySlice<se::DeviceMemoryBase> buffers,
- HloExecutionProfile* hlo_execution_profile);
-
- // Returns the points-to set of the root instruction of the entry
- // computation. Uses points-to analysis from buffer assignment.
- const PointsToSet& GetRootPointsToSet() const;
-
- // The JIT containing compiled modules.
- tensorflow::mutex jit_mutex_;
- const std::unique_ptr<SimpleOrcJIT> jit_ GUARDED_BY(jit_mutex_);
-
- // Buffer assignment for the buffers we need to allocate.
- const std::unique_ptr<const BufferAssignment> assignment_;
-
- // The LLVM IR, in string format, of the unoptimized module generated for this
- // ParallelCpuExecutable. We save a string instead of an llvm::Module* because
- // leaving llvm::Module* in a singleton can cause the heap checker to emit
- // false positives.
- string ir_module_string_;
-
- // Map containing the JITted function names for each HLO instruction.
- const std::unique_ptr<const HloInstructionMap<string>> function_names_;
-
- // Map from HLO Constant instructions to a pointer to their literal data.
- // The data stored in the protocol buffer might be insufficiently aligned,
- // we create a sufficiently aligned copy and store it in this map.
- const std::unordered_map<const HloInstruction*,
- std::unique_ptr<unsigned char[]>>
- aligned_constants_;
-
- TF_DISALLOW_COPY_AND_ASSIGN(ParallelCpuExecutable);
-};
-
-} // namespace cpu
-} // namespace xla
-
-#endif // TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_CPU_EXECUTABLE_H_
diff --git a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
index 56e35e2604..38b5efa9fb 100644
--- a/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/elemental_ir_emitter.cc
@@ -52,6 +52,13 @@ using tensorflow::strings::StrCat;
namespace {
+int64 GlobalRandomValue() {
+ static auto* mu = new tensorflow::mutex();
+ static std::mt19937_64 rng{42};
+ tensorflow::mutex_lock l(*mu);
+ return rng();
+}
+
llvm::Value* EmitReducePrecisionFloat(llvm::Value* x, int64 exponent_bits,
int64 mantissa_bits,
llvm::IRBuilder<>* ir_builder) {
@@ -1175,7 +1182,7 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
llvm::Value* increment = ir_builder_->getInt(
llvm::APInt(128, {0x14057B7EF767814F, 0x5851F42D4C957F2D}));
- auto random_value = [hlo]() {
+ auto random_value_from_hlo = [hlo]() {
const HloModule* module =
hlo->IsFused() ? hlo->parent()->FusionInstruction()->parent()->parent()
: hlo->parent()->parent();
@@ -1197,10 +1204,15 @@ llvm_ir::ElementGenerator ElementalIrEmitter::MakeRngElementGenerator(
/*Ty=*/ir_builder_->getInt64Ty(),
/*isConstant=*/false,
/*Linkage=*/llvm::GlobalValue::PrivateLinkage,
- /*Initializer=*/ir_builder_->getInt64(random_value()),
+ /*Initializer=*/ir_builder_->getInt64(random_value_from_hlo()),
/*Name=*/"state_ptr0");
+
+  // When the module config seed is 0, the prng should produce a genuinely
+  // random value, so we use a global random value as the graph seed instead of
+  // random_value_from_hlo: for a newly built hlo graph, random_value_from_hlo
+  // would always yield the same number.
uint64 graph_seed = hlo_module_config_.seed() != 0 ? hlo_module_config_.seed()
- : random_value();
+ : GlobalRandomValue();
llvm::GlobalVariable* state_ptr1 = new llvm::GlobalVariable(
/*M=*/*module_,
/*Ty=*/ir_builder_->getInt64Ty(),
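
The GlobalRandomValue helper added in this hunk is a process-wide seed source: a function-local static Mersenne Twister guarded by a static mutex. A minimal standard-library rendering of the same idea (std::mutex in place of tensorflow::mutex, otherwise mirroring the code above):

#include <cstdint>
#include <mutex>
#include <random>

// One shared, mutex-guarded RNG: repeated calls draw from a single
// process-wide stream instead of re-deriving the same value for every
// freshly built graph.
uint64_t GlobalRandomValue() {
  static std::mutex* mu = new std::mutex();  // intentionally never destroyed
  static std::mt19937_64 rng{42};
  std::lock_guard<std::mutex> lock(*mu);
  return rng();
}
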
diff --git a/tensorflow/compiler/xla/service/executable.cc b/tensorflow/compiler/xla/service/executable.cc
index 8218b5f7c8..021f09d310 100644
--- a/tensorflow/compiler/xla/service/executable.cc
+++ b/tensorflow/compiler/xla/service/executable.cc
@@ -29,12 +29,12 @@ using tensorflow::gtl::ArraySlice;
namespace xla {
-StatusOr<std::vector<ShapedBuffer>> Executable::ExecuteOnStreams(
+StatusOr<std::vector<ScopedShapedBuffer>> Executable::ExecuteOnStreams(
ArraySlice<const ServiceExecutableRunOptions> run_options,
ArraySlice<ArraySlice<const ShapedBuffer*>> arguments) {
TF_RET_CHECK(run_options.size() == arguments.size());
- std::vector<ShapedBuffer> return_values;
+ std::vector<ScopedShapedBuffer> return_values;
return_values.reserve(run_options.size());
if (run_options.size() == 1) {
@@ -60,7 +60,7 @@ StatusOr<std::vector<ShapedBuffer>> Executable::ExecuteOnStreams(
return std::move(return_values);
}
-StatusOr<ShapedBuffer> Executable::ExecuteOnStreamWrapper(
+StatusOr<ScopedShapedBuffer> Executable::ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
ArraySlice<const ShapedBuffer*> arguments) {
se::Stream* stream = run_options->stream();
@@ -80,7 +80,7 @@ StatusOr<ShapedBuffer> Executable::ExecuteOnStreamWrapper(
&hlo_profile_index_map())
: nullptr;
- StatusOr<ShapedBuffer> return_value =
+ StatusOr<ScopedShapedBuffer> return_value =
ExecuteOnStream(run_options, arguments, profile_ptr.get());
TF_RETURN_IF_ERROR(return_value.status());
@@ -163,9 +163,9 @@ Status Executable::DumpSessionModule() {
result);
}
-/* static */ Status Executable::DumpToDirectory(const string& directory_path,
- string filename,
- const HloSession& hlo_session) {
+/* static */ Status Executable::DumpToDirectory(
+ const string& directory_path, string filename,
+ const HloSnapshot& hlo_session) {
tensorflow::Env* env = tensorflow::Env::Default();
if (!env->IsDirectory(directory_path).ok()) {
// NB! CreateDir does not work reliably with multiple XLA threads -- two
diff --git a/tensorflow/compiler/xla/service/executable.h b/tensorflow/compiler/xla/service/executable.h
index bdbe119120..f7af1ca574 100644
--- a/tensorflow/compiler/xla/service/executable.h
+++ b/tensorflow/compiler/xla/service/executable.h
@@ -63,14 +63,14 @@ class Executable {
// enabled.
//
// Returns a shaped buffer containing the result of the computation.
- virtual StatusOr<ShapedBuffer> ExecuteOnStream(
+ virtual StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) = 0;
// Same as ExecuteOnStream(), but this call is non-blocking and returns as
// soon as all of the operations are enqueued for launch on the stream.
- virtual StatusOr<ShapedBuffer> ExecuteAsyncOnStream(
+ virtual StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) = 0;
@@ -78,7 +78,7 @@ class Executable {
// streams. arguments[i] contains the arguments to the execution on
// run_options[i]->stream() and the returned value is at index i of the
// returned vector.
- virtual StatusOr<std::vector<ShapedBuffer>> ExecuteOnStreams(
+ virtual StatusOr<std::vector<ScopedShapedBuffer>> ExecuteOnStreams(
tensorflow::gtl::ArraySlice<const ServiceExecutableRunOptions>
run_options,
tensorflow::gtl::ArraySlice<
@@ -98,7 +98,7 @@ class Executable {
// Convenience wrapper for calling Executable::ExecuteOnStream. Sets up a
// timer for the execution, sets up HLO profiling if enabled, and fills in the
// given ExecutionProfile if non-null.
- StatusOr<ShapedBuffer> ExecuteOnStreamWrapper(
+ StatusOr<ScopedShapedBuffer> ExecuteOnStreamWrapper(
const ServiceExecutableRunOptions* run_options, ExecutionProfile* profile,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
@@ -156,9 +156,9 @@ class Executable {
static Status DumpToDirectory(const string& directory_path, string filename,
const SessionModule& session_module);
- // Dump hlo_session to directory_path/filename.
+ // Dump hlo snapshot to directory_path/filename.
static Status DumpToDirectory(const string& directory_path, string filename,
- const HloSession& hlo_session);
+ const HloSnapshot& hlo_session);
protected:
mutable tensorflow::mutex mutex_;
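Returning ScopedShapedBuffer from these entry points makes the result own its device memory; callers (see the hlo_runner.cc and service.cc hunks further down) no longer wrap the result by hand. A standalone sketch of that ownership contract, with hypothetical stand-ins for the XLA types:

#include <cstdio>
#include <utility>

// Hypothetical stand-ins for DeviceMemoryAllocator / ScopedShapedBuffer.
struct DemoAllocator {
  void Deallocate(int handle) { std::printf("freed buffer %d\n", handle); }
};

class DemoScopedBuffer {
 public:
  DemoScopedBuffer(int handle, DemoAllocator* allocator)
      : handle_(handle), allocator_(allocator) {}
  DemoScopedBuffer(DemoScopedBuffer&& other)
      : handle_(other.handle_), allocator_(other.allocator_) {
    other.allocator_ = nullptr;  // moved-from object must not double-free
  }
  ~DemoScopedBuffer() {
    if (allocator_ != nullptr) allocator_->Deallocate(handle_);
  }

 private:
  int handle_;
  DemoAllocator* allocator_;
};

// A caller receives an owning result; no manual wrap or free step remains.
DemoScopedBuffer ExecuteDemo(DemoAllocator* allocator) {
  return DemoScopedBuffer(/*handle=*/0, allocator);
}

int main() {
  DemoAllocator allocator;
  DemoScopedBuffer result = ExecuteDemo(&allocator);  // freed at end of main
  return 0;
}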
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
index 62ce15bc59..980cc89fa0 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.cc
@@ -250,7 +250,7 @@ Status GpuExecutable::ExecuteThunks(
return Status::OK();
}
-StatusOr<ShapedBuffer> GpuExecutable::ExecuteOnStream(
+StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
@@ -297,8 +297,8 @@ StatusOr<ShapedBuffer> GpuExecutable::ExecuteOnStream(
HloInstruction* root = hlo_module_->entry_computation()->root_instruction();
auto device_ordinal = executor->device_ordinal();
- auto shaped_buffer = ShapedBuffer(root->shape(), root->shape(),
- executor->platform(), device_ordinal);
+ ScopedShapedBuffer shaped_buffer(root->shape(), root->shape(),
+ memory_allocator, device_ordinal);
// Copy DeviceMemoryBase values which contain the array(s) of the result into
// the respective location in ShapedBuffer.
@@ -335,7 +335,7 @@ StatusOr<ShapedBuffer> GpuExecutable::ExecuteOnStream(
return std::move(shaped_buffer);
}
-StatusOr<ShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
+StatusOr<ScopedShapedBuffer> GpuExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
// TODO(b/30671675): Implement asynchronous execution mode.
diff --git a/tensorflow/compiler/xla/service/gpu/gpu_executable.h b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
index 361bc30b2f..80ec38c3ac 100644
--- a/tensorflow/compiler/xla/service/gpu/gpu_executable.h
+++ b/tensorflow/compiler/xla/service/gpu/gpu_executable.h
@@ -74,12 +74,12 @@ class GpuExecutable : public Executable {
// ExecuteOnStream will fail if the compute capability of the stream doesn't
// match the compute capability passed to this object's constructor.
- StatusOr<ShapedBuffer> ExecuteOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- StatusOr<ShapedBuffer> ExecuteAsyncOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index 0c3eb7dcb4..aa6860880b 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -300,7 +300,7 @@ message HloProto {
// Encapsulates HloProto together with the arguments, result, and
// execution_platform. This message is used for purposes such as
// analysis/replay/file-storage.
-message HloSession {
+message HloSnapshot {
// The hlo graph.
HloProto hlo = 1;
diff --git a/tensorflow/compiler/xla/service/hlo_cse.cc b/tensorflow/compiler/xla/service/hlo_cse.cc
index cd7cbbdd71..3b22c93733 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.cc
+++ b/tensorflow/compiler/xla/service/hlo_cse.cc
@@ -97,6 +97,10 @@ StatusOr<bool> HloCSE::Run(HloModule* module) {
const std::function<bool(const HloComputation*, const HloComputation*)>
eq_computations = std::equal_to<const HloComputation*>();
for (auto* computation : module->computations()) {
+ if (only_fusion_computations_ && !computation->IsFusionComputation()) {
+ continue;
+ }
+
changed |= CombineConstants(computation, is_layout_sensitive_);
std::list<HloInstruction*> post_order =
diff --git a/tensorflow/compiler/xla/service/hlo_cse.h b/tensorflow/compiler/xla/service/hlo_cse.h
index 70096e07a2..5e2b348bdd 100644
--- a/tensorflow/compiler/xla/service/hlo_cse.h
+++ b/tensorflow/compiler/xla/service/hlo_cse.h
@@ -29,9 +29,11 @@ class HloCSE : public HloPassInterface {
public:
// If is_layout_sensitive is true, then the simplifier preserves layout during
// transformation. Otherwise, layout is ignored.
- explicit HloCSE(bool is_layout_sensitive)
- : is_layout_sensitive_(is_layout_sensitive) {}
- ~HloCSE() override {}
+ explicit HloCSE(bool is_layout_sensitive,
+ bool only_fusion_computations = false)
+ : is_layout_sensitive_(is_layout_sensitive),
+ only_fusion_computations_(only_fusion_computations) {}
+ ~HloCSE() override = default;
tensorflow::StringPiece name() const override { return "cse"; }
// Run CSE on the given module. Returns whether the module was changed (common
@@ -39,7 +41,8 @@ class HloCSE : public HloPassInterface {
StatusOr<bool> Run(HloModule* module) override;
private:
- bool is_layout_sensitive_;
+ const bool is_layout_sensitive_;
+ const bool only_fusion_computations_;
};
} // namespace xla
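The new flag defaults to false, so existing constructions of the pass keep their behavior; passing true restricts CSE to fusion computations via the early continue added in hlo_cse.cc above. A hedged usage fragment, not a call site from this commit (assumes an HloModule* named module and a Status-returning caller):

// Sketch only: layout-sensitive CSE restricted to fusion computations.
HloCSE fusion_cse(/*is_layout_sensitive=*/true,
                  /*only_fusion_computations=*/true);
TF_ASSIGN_OR_RETURN(bool changed, fusion_cse.Run(module));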
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index f3da3fc256..a5e9aecb9e 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -956,14 +956,6 @@ class HloInstruction {
void clear_sharding() { sharding_ = nullptr; }
// Return true if this operator has a sharding assigned.
bool has_sharding() const { return sharding_ != nullptr; }
- // Checks whether the instruction has compatible sharding with the other
- // instruction.
- bool has_compatible_sharding(const HloInstruction* other) const {
- if (!has_sharding()) {
- return !other->has_sharding();
- }
- return other->has_sharding() ? sharding() == other->sharding() : false;
- }
// When creating a new instruction which either replaces, or shifts up (kCopy
// insertion case), another instruction, we need to make sure the certain
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.cc b/tensorflow/compiler/xla/service/hlo_matchers.cc
index bc74c4bc10..69deac263e 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers.cc
@@ -132,6 +132,69 @@ bool HloCustomCallMatcher::MatchAndExplain(
return result;
}
+bool HloShapeMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (ShapeUtil::Compatible(instruction->shape(), shape_)) {
+ return true;
+ }
+ *listener << instruction->ToString() << " has incorrect shape (expected: "
+ << ShapeUtil::HumanString(shape_) << ")";
+ return false;
+}
+
+void HloShapeMatcher::DescribeTo(std::ostream* os) const {
+ *os << ShapeUtil::HumanString(shape_);
+}
+
+bool HloShapeAndLayoutMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (ShapeUtil::Equal(instruction->shape(), shape_)) {
+ return true;
+ }
+ *listener << instruction->ToString() << " has incorrect shape (expected: "
+ << ShapeUtil::HumanStringWithLayout(shape_) << ")";
+ return false;
+}
+
+void HloShapeAndLayoutMatcher::DescribeTo(std::ostream* os) const {
+ *os << ShapeUtil::HumanStringWithLayout(shape_);
+}
+
+bool HloShardingMatcher::MatchAndExplain(
+ const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const {
+ if (!sharding_.has_value()) {
+ if (!instruction->has_sharding()) {
+ return true;
+ }
+ *listener << instruction->ToString() << " expected to have no sharding.";
+ return false;
+ }
+ if (instruction->has_sharding()) {
+ if (instruction->sharding() == sharding_.value()) {
+ return true;
+ }
+ *listener << instruction->ToString()
+ << " has incorrect sharding (expected: " << sharding_->ToString()
+ << ")";
+ return false;
+ } else {
+ *listener << instruction->ToString()
+ << " has no sharding (expected: " << sharding_->ToString() << ")";
+ return false;
+ }
+}
+
+void HloShardingMatcher::DescribeTo(std::ostream* os) const {
+ if (sharding_.has_value()) {
+ *os << sharding_->ToString();
+ } else {
+ *os << "<no-sharding>";
+ }
+}
+
} // namespace testing
void PrintTo(const HloInstruction* inst, ::std::ostream* os) {
diff --git a/tensorflow/compiler/xla/service/hlo_matchers.h b/tensorflow/compiler/xla/service/hlo_matchers.h
index 103f04a2cb..f2ab9b5d9b 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers.h
+++ b/tensorflow/compiler/xla/service/hlo_matchers.h
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/test.h"
+#include "tensorflow/core/lib/gtl/optional.h"
namespace xla {
namespace testing {
@@ -86,6 +87,50 @@ class HloCustomCallMatcher : public HloMatcher {
::testing::Matcher<string> call_target_matcher_;
};
+class HloShapeMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShapeMatcher(const Shape& shape) : shape_(shape) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ Shape shape_;
+};
+
+class HloShapeAndLayoutMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShapeAndLayoutMatcher(const Shape& shape) : shape_(shape) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ Shape shape_;
+};
+
+// Verifies the sharding of an instruction against the provided HloSharding. If
+// nullopt is provided for the expected sharding, checks that the instruction
+// has no sharding.
+class HloShardingMatcher
+ : public ::testing::MatcherInterface<const HloInstruction*> {
+ public:
+ explicit HloShardingMatcher(
+ const tensorflow::gtl::optional<HloSharding>& sharding)
+ : sharding_(sharding) {}
+
+ bool MatchAndExplain(const HloInstruction* instruction,
+ ::testing::MatchResultListener* listener) const override;
+ void DescribeTo(std::ostream* os) const override;
+
+ private:
+ tensorflow::gtl::optional<HloSharding> sharding_;
+};
+
// HloInstruction* matchers for opcode and operands. Example:
// namespace op = xla::opcode_matchers;
// EXPECT_THAT(instruction,
@@ -231,6 +276,30 @@ inline ::testing::Matcher<const ::xla::HloInstruction*> CustomCall() {
new ::xla::testing::HloMatcher(HloOpcode::kCustomCall, {}));
}
+// Verifies the shape or the shape and the layout of an HLO instruction against
+// the provided shape object.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Shape(
+ const class Shape& shape) {
+ return ::testing::MakeMatcher(new ::xla::testing::HloShapeMatcher(shape));
+}
+inline ::testing::Matcher<const ::xla::HloInstruction*> ShapeWithLayout(
+ const class Shape& shape) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShapeAndLayoutMatcher(shape));
+}
+
+// Verifies the value of the HloSharding against the provided sharding object.
+inline ::testing::Matcher<const ::xla::HloInstruction*> Sharding(
+ const HloSharding& sharding) {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShardingMatcher(sharding));
+}
+// Verifies that no HloSharding is set for an HLO instruction.
+inline ::testing::Matcher<const ::xla::HloInstruction*> NoSharding() {
+ return ::testing::MakeMatcher(
+ new ::xla::testing::HloShardingMatcher(tensorflow::gtl::nullopt));
+}
+
#undef HLO_MATCHER
} // namespace opcode_matchers
diff --git a/tensorflow/compiler/xla/service/hlo_matchers_test.cc b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
index 1c21703a45..c6373b2e46 100644
--- a/tensorflow/compiler/xla/service/hlo_matchers_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_matchers_test.cc
@@ -100,5 +100,63 @@ TEST(HloMatchersTest, CustomCallMatcher) {
R"(custom-call with call target that is equal to "foo_target")");
}
+TEST(HloMatchersTest, ShapeMatcher) {
+ auto p0 = HloInstruction::CreateParameter(
+ 0, ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1}), "param");
+
+ EXPECT_THAT(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {5, 7})));
+ EXPECT_THAT(
+ p0.get(),
+ ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {5, 7}))));
+ EXPECT_THAT(p0.get(),
+ ::testing::Not(op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))));
+ EXPECT_THAT(
+ p0.get(),
+ ::testing::Not(op::ShapeWithLayout(ShapeUtil::MakeShape(F32, {7, 5}))));
+ EXPECT_THAT(p0.get(),
+ op::Shape(ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {0, 1})));
+ EXPECT_THAT(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout(
+ F32, {5, 7}, {0, 1})));
+ EXPECT_THAT(p0.get(),
+ ::testing::Not(op::ShapeWithLayout(
+ ShapeUtil::MakeShapeWithLayout(F32, {5, 7}, {1, 0}))));
+
+ EXPECT_THAT(Explain(p0.get(), op::Shape(ShapeUtil::MakeShape(F32, {7, 5}))),
+ "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
+ "(expected: f32[7,5])");
+ EXPECT_THAT(
+ Explain(p0.get(), op::ShapeWithLayout(ShapeUtil::MakeShapeWithLayout(
+ F32, {7, 5}, {1, 0}))),
+ "%param = f32[5,7]{0,1} parameter(0) has incorrect shape "
+ "(expected: f32[7,5]{1,0})");
+}
+
+TEST(HloMatchersTest, ShardingMatcher) {
+ auto p0 = HloInstruction::CreateParameter(0, ShapeUtil::MakeShape(F32, {5}),
+ "param.0");
+ p0->clear_sharding();
+ auto p1 = HloInstruction::CreateParameter(1, ShapeUtil::MakeShape(F32, {7}),
+ "param.1");
+ p1->set_sharding(HloSharding::AssignDevice(1));
+
+ EXPECT_THAT(p0.get(), op::NoSharding());
+ EXPECT_THAT(p0.get(),
+ ::testing::Not(op::Sharding(HloSharding::AssignDevice(1))));
+ EXPECT_THAT(p1.get(), ::testing::Not(op::NoSharding()));
+ EXPECT_THAT(p1.get(),
+ ::testing::Not(op::Sharding(HloSharding::AssignDevice(0))));
+ EXPECT_THAT(p1.get(), op::Sharding(HloSharding::AssignDevice(1)));
+
+ EXPECT_THAT(Explain(p0.get(), op::Sharding(HloSharding::AssignDevice(1))),
+ "%param.0 = f32[5]{0} parameter(0) has no sharding (expected: "
+ "{maximal device=1})");
+ EXPECT_THAT(Explain(p1.get(), op::NoSharding()),
+ "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
+ "expected to have no sharding.");
+ EXPECT_THAT(Explain(p1.get(), op::Sharding(HloSharding::AssignDevice(0))),
+ "%param.1 = f32[7]{0} parameter(1), sharding={maximal device=1} "
+ "has incorrect sharding (expected: {maximal device=0})");
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/hlo_runner.cc b/tensorflow/compiler/xla/service/hlo_runner.cc
index df5ffd0b7d..81c43db292 100644
--- a/tensorflow/compiler/xla/service/hlo_runner.cc
+++ b/tensorflow/compiler/xla/service/hlo_runner.cc
@@ -126,16 +126,12 @@ StatusOr<std::unique_ptr<Literal>> HloRunner::Execute(
}
TF_ASSIGN_OR_RETURN(
- ShapedBuffer result,
+ ScopedShapedBuffer result,
executable->ExecuteOnStreamWrapper(
&service_run_options, /*profile=*/nullptr, argument_buffer_ptrs));
- // Create a ScopedShapedBuffer of the result to manage deallocation. This will
- // deallocate all the device memory when it goes out of scope.
- ScopedShapedBuffer scoped_result(std::move(result), run_options.allocator());
-
auto result_literal = backend().transfer_manager()->TransferLiteralFromDevice(
- stream.parent(), scoped_result);
+ stream.parent(), result);
if (result_literal.ok()) {
VLOG(4) << "Executed binary and got result: "
<< result_literal.ValueOrDie()->ToString();
@@ -248,18 +244,16 @@ StatusOr<std::vector<std::unique_ptr<Literal>>> HloRunner::ExecuteReplicated(
}
LOG(INFO) << "Replicated execution started";
- TF_ASSIGN_OR_RETURN(std::vector<ShapedBuffer> results,
+ TF_ASSIGN_OR_RETURN(std::vector<ScopedShapedBuffer> results,
executable->ExecuteOnStreams(service_run_options,
argument_buffer_slices));
LOG(INFO) << "Replicated execution terminated";
std::vector<std::unique_ptr<Literal>> exec_results;
for (int64 i = 0; i < options.num_replicas; ++i) {
- ScopedShapedBuffer result(std::move(results[i]),
- backend().memory_allocator());
TF_ASSIGN_OR_RETURN(std::unique_ptr<Literal> literal,
backend().transfer_manager()->TransferLiteralFromDevice(
- streams[i]->parent(), result));
+ streams[i]->parent(), results[i]));
exec_results.push_back(std::move(literal));
}
return std::move(exec_results);
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index 1b42349b0b..994de44123 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -256,37 +256,24 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
", input_shape=", ShapeUtil::HumanString(shape));
}
- // The tile shape must not be the same as the input shape without maximal_
- // also set. If this is the case, we're not actually sharded and the correct
- // constructor should have been used.
- if (ShapeUtil::Equal(shape, tile_shape_)) {
+  // The correct constructor has to be used to create tile maximal shardings.
+ if (tile_assignment_.num_elements() == 1) {
return tensorflow::errors::InvalidArgument(
- "Tile shape is the same as the input shape. If a replicated sharding "
- "was intended, use HloSharding::Replicated(). If a device placement "
- "was intended, use HloSharding::AssignDevice()");
+ "Tile assignment only contains a single device. If a replicated "
+ "sharding was intended, use HloSharding::Replicated(). If a device "
+ "placement was intended, use HloSharding::AssignDevice()");
}
- // The tile shape must not be greater than the input shape in any dimension.
- for (int64 i = 0, e = ShapeUtil::Rank(shape); i != e; ++i) {
- auto tile_dim = tile_shape_.dimensions(i);
- auto shape_dim = shape.dimensions(i);
- if (tile_dim > shape_dim) {
- return tensorflow::errors::InvalidArgument(
- StrCat("Tile is larger than input shape (dimension ", i, ", ",
- tile_dim, " > ", shape_dim));
- }
- }
-
- // The tile assignment tensor must be exactly dimensioned to ceil(shape[dim]
- // tile[dim]) for every dimension contained within tile.
+  // The tile assignment tensor must contain enough elements to cover the full
+ // shape with tiles of the specified size.
for (int64 i = 0, e = tile_assignment_.dimensions().size(); i != e; ++i) {
- int64 expected_dim =
- CeilOfRatio(shape.dimensions(i), tile_shape_.dimensions(i));
- if (tile_assignment_.dimensions()[i] != expected_dim) {
+ int64 total_tile_size = tile_assignment_.dim(i) * tile_shape_.dimensions(i);
+ if (shape.dimensions(i) > total_tile_size) {
return tensorflow::errors::InvalidArgument(
- StrCat("Tile assignment tensor has incorrect shape. Dimension ", i,
- " expected ", expected_dim, " but got ",
- tile_assignment_.dimensions()[i]));
+          StrCat("Tile assignment tensor has too few elements to cover the "
+                 "full "
+ "shape. Dimension ",
+ i, ", shape ", shape.dimensions(i), ", total size ",
+ total_tile_size));
}
}
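The replacement rule is plain arithmetic: along every dimension, the number of tiles in the assignment times the tile extent must be at least the shape extent. A standalone sketch with demo types (not the XLA classes), using the same numbers as the updated test below:

#include <cassert>
#include <cstdint>
#include <vector>

// Sketch of the coverage rule used by HloSharding::ValidateNonTuple above:
// tile_assignment_dims[i] * tile_shape_dims[i] must be >= shape_dims[i].
bool TileAssignmentCoversShape(const std::vector<int64_t>& shape_dims,
                               const std::vector<int64_t>& tile_shape_dims,
                               const std::vector<int64_t>& tile_assignment_dims) {
  for (size_t i = 0; i < shape_dims.size(); ++i) {
    const int64_t total_tile_size =
        tile_assignment_dims[i] * tile_shape_dims[i];
    if (shape_dims[i] > total_tile_size) return false;
  }
  return true;
}

int main() {
  // A 4x6 shape with 2x3 tiles and a 2x2 assignment is covered:
  // dimension 0: 2*2 >= 4, dimension 1: 2*3 >= 6.
  assert(TileAssignmentCoversShape({4, 6}, {2, 3}, {2, 2}));
  // A 6x3 shape fails in dimension 0: only 2*2 = 4 elements are covered.
  assert(!TileAssignmentCoversShape({6, 3}, {2, 3}, {2, 2}));
  return 0;
}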
diff --git a/tensorflow/compiler/xla/service/hlo_sharding_test.cc b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
index 69ea4233e4..3bf0d25efb 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding_test.cc
@@ -88,7 +88,7 @@ TEST_F(HloShardingTest, Tile) {
}
{
- // Test should pass.
+    // Test should fail because more devices are used than `num_devices`.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
@@ -97,17 +97,8 @@ TEST_F(HloShardingTest, Tile) {
}
{
- // Test should fail due to the tile being larger than the input space.
- Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
- HloSharding sharding =
- HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
- EXPECT_IS_NOT_OK(sharding.Validate(ShapeUtil::MakeShape(F32, {2, 2}),
- /*num_devices=*/4));
- }
-
- {
- // Test should fail due to the tile not dividing the input space into 4
- // sections (even with padding).
+    // Test should fail because the total tiled size in dimension 0 is 4 but
+    // there are 6 elements along that dimension.
Shape tile_shape = ShapeUtil::MakeShape(U32, {2, 3});
HloSharding sharding =
HloSharding::Tile(tile_shape, MakeArray({2, 2}, {0, 1, 2, 3}));
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.cc b/tensorflow/compiler/xla/service/interpreter/executable.cc
index 6553000336..61f199bc9e 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.cc
+++ b/tensorflow/compiler/xla/service/interpreter/executable.cc
@@ -45,7 +45,7 @@ InterpreterExecutable::InterpreterExecutable(
InterpreterExecutable::~InterpreterExecutable() {}
-StatusOr<ShapedBuffer> InterpreterExecutable::ExecuteOnStream(
+StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) {
@@ -88,8 +88,8 @@ StatusOr<ShapedBuffer> InterpreterExecutable::ExecuteOnStream(
evaluator.Evaluate<std::unique_ptr<Literal>>(*computation, arg_literals));
// Transform the result literal back into a ShapedBuffer.
- TF_ASSIGN_OR_RETURN(ShapedBuffer result,
- transfer_manager->AllocateShapedBuffer(
+ TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
+ transfer_manager->AllocateScopedShapedBuffer(
result_literal->shape(), run_options->allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(transfer_manager->TransferLiteralToDevice(
@@ -106,7 +106,7 @@ StatusOr<ShapedBuffer> InterpreterExecutable::ExecuteOnStream(
return std::move(result);
}
-StatusOr<ShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
+StatusOr<ScopedShapedBuffer> InterpreterExecutable::ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
return tensorflow::errors::Unimplemented(
diff --git a/tensorflow/compiler/xla/service/interpreter/executable.h b/tensorflow/compiler/xla/service/interpreter/executable.h
index c825a9a368..b0b797ca7d 100644
--- a/tensorflow/compiler/xla/service/interpreter/executable.h
+++ b/tensorflow/compiler/xla/service/interpreter/executable.h
@@ -43,12 +43,12 @@ class InterpreterExecutable : public Executable {
InterpreterExecutable(std::unique_ptr<const HloModule> hlo_module);
~InterpreterExecutable() override;
- StatusOr<ShapedBuffer> ExecuteOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
HloExecutionProfile* hlo_execution_profile) override;
- StatusOr<ShapedBuffer> ExecuteAsyncOnStream(
+ StatusOr<ScopedShapedBuffer> ExecuteAsyncOnStream(
const ServiceExecutableRunOptions* run_options,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) override;
diff --git a/tensorflow/compiler/xla/service/interpreter/platform.cc b/tensorflow/compiler/xla/service/interpreter/platform.cc
index ce2f4d378c..92e069a8c6 100644
--- a/tensorflow/compiler/xla/service/interpreter/platform.cc
+++ b/tensorflow/compiler/xla/service/interpreter/platform.cc
@@ -71,8 +71,8 @@ port::StatusOr<StreamExecutor*> XlaInterpreterPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
XlaInterpreterPlatform::GetUncachedExecutor(
const StreamExecutorConfig& config) {
- auto executor = port::MakeUnique<StreamExecutor>(
- this, port::MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
+ auto executor = MakeUnique<StreamExecutor>(
+ this, MakeUnique<XlaInterpreterExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/compiler/xla/service/layout_assignment.cc b/tensorflow/compiler/xla/service/layout_assignment.cc
index 7067b6f86a..2494569db5 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/layout_assignment.cc
@@ -31,12 +31,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/logical_buffer.h"
-#include "tensorflow/compiler/xla/service/tuple_simplifier.h"
#include "tensorflow/compiler/xla/shape_layout.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
@@ -402,9 +400,9 @@ string LayoutConstraints::ToString() const {
}
Status LayoutAssignment::AddMandatoryConstraints(
- const ComputationLayout* computation_layout,
- ChannelLayoutConstraints* channel_constraints, HloComputation* computation,
- LayoutConstraints* constraints) {
+ const ComputationLayout& computation_layout,
+ const ChannelLayoutConstraints* channel_constraints,
+ HloComputation* computation, LayoutConstraints* constraints) {
VLOG(3) << "Adding mandatory layout constraints to computation "
<< computation->name();
@@ -426,16 +424,11 @@ Status LayoutAssignment::AddMandatoryConstraints(
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
instruction->outfeed_shape(), instruction, 0));
} else if (instruction->opcode() == HloOpcode::kParameter) {
- if (computation_layout != nullptr) {
- const ShapeLayout& parameter_layout =
- computation_layout->parameter_layout(
- instruction->parameter_number());
- if (parameter_layout.LayoutIsSet()) {
- // Parameter layouts must match the respective layout in
- // ComputationLayout, if there is one.
- shape_with_layout = &parameter_layout.shape();
- }
- }
+ // Parameter layouts must match the respective layout in
+ // ComputationLayout.
+ shape_with_layout =
+ &computation_layout.parameter_layout(instruction->parameter_number())
+ .shape();
}
if (shape_with_layout != nullptr) {
TF_RETURN_IF_ERROR(
@@ -500,8 +493,9 @@ Status LayoutAssignment::AddMandatoryConstraints(
HloComputation* body = instruction->while_body();
HloComputation* condition = instruction->while_condition();
const HloInstruction* init = instruction->operand(0);
- ComputationLayout& body_layout = FindOrDie(computation_layouts_, body);
- ComputationLayout& condition_layout =
+ const ComputationLayout& body_layout =
+ FindOrDie(computation_layouts_, body);
+ const ComputationLayout& condition_layout =
FindOrDie(computation_layouts_, condition);
// Check a few invariants irrespective of layout.
@@ -514,19 +508,26 @@ Status LayoutAssignment::AddMandatoryConstraints(
condition_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(body_layout.result_shape(), init->shape()));
- if (body_layout.result_layout() != body_layout.parameter_layout(0)) {
- VLOG(2) << "Reset %while body parameter layout: body=" << body->name()
- << " while=" << instruction->name()
- << " shape=" << body_layout.result_layout().ToString();
- *body_layout.mutable_parameter_layout(0) = body_layout.result_layout();
+      // Return an error if an earlier layout assignment of the embedded
+      // computations has produced conflicting layouts.
+ if (!ShapeUtil::Equal(body_layout.result_shape(),
+ body_layout.parameter_shape(0))) {
+ return InternalError(
+ "Parameter and result of body computation %s of while instruction "
+ "%s have different layouts: %s vs %s",
+ body->name().c_str(), instruction->name().c_str(),
+ ShapeUtil::HumanString(body_layout.result_shape()).c_str(),
+ ShapeUtil::HumanString(body_layout.parameter_shape(0)).c_str());
}
- if (condition_layout.parameter_layout(0) !=
- body_layout.parameter_layout(0)) {
- VLOG(2) << "Reset %while condition parameter layout: cond="
- << condition->name() << " while=" << instruction->name()
- << " shape=" << body_layout.parameter_layout(0).ToString();
- *condition_layout.mutable_parameter_layout(0) =
- body_layout.parameter_layout(0);
+ if (!ShapeUtil::Equal(body->root_instruction()->shape(),
+ condition->parameter_instruction(0)->shape())) {
+ return InternalError(
+ "Parameter of condition computation %s of while instruction "
+ "%s does not match body computation %s result: %s vs %s",
+ condition->name().c_str(), instruction->name().c_str(),
+ body->name().c_str(),
+ ShapeUtil::HumanString(condition_layout.parameter_shape(0)).c_str(),
+ ShapeUtil::HumanString(body_layout.result_shape()).c_str());
}
// Constrain the output and the operand of the while instruction to match
@@ -556,20 +557,7 @@ Status LayoutAssignment::AddMandatoryConstraints(
true_computation_layout.parameter_shape(0)));
DCHECK(ShapeUtil::Compatible(
false_operand->shape(), false_computation_layout.parameter_shape(0)));
- if (true_computation_layout.result_layout() !=
- false_computation_layout.result_layout()) {
- // We assign layouts in DFS fashion, so the true and false computations
- // might have negotiated a different layout. But for the conditional
- // instruction POV the layout must match, so we run again on the false
- // computation, this time with proper computation layout.
- VLOG(2) << "Reset %conditional false computation result layout: "
- "false_computation="
- << false_computation->name()
- << " conditional=" << instruction->name() << " shape="
- << true_computation_layout.result_layout().ToString();
- *false_computation_layout.mutable_result_layout() =
- true_computation_layout.result_layout();
- }
+
TF_RETURN_IF_ERROR(constraints->SetInstructionLayout(
true_computation_layout.result_shape(), instruction));
TF_RETURN_IF_ERROR(constraints->SetOperandLayout(
@@ -605,14 +593,10 @@ Status LayoutAssignment::AddMandatoryConstraints(
}
}
}
- // Finally set the result layout to match ComputationLayout, if there is one.
- if (computation_layout != nullptr) {
- const ShapeLayout& result_layout = computation_layout->result_layout();
- if (result_layout.LayoutIsSet()) {
- TF_RETURN_IF_ERROR(constraints->SetResultLayout(result_layout.shape()));
- }
- }
- return Status::OK();
+
+ // Finally set the result layout to match ComputationLayout.
+ return constraints->SetResultLayout(
+ computation_layout.result_layout().shape());
}
namespace {
@@ -776,7 +760,6 @@ StatusOr<HloInstruction*> LayoutAssignment::CreateCopyWithNewLayout(
HloInstruction* copy =
instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
instruction->shape(), HloOpcode::kCopy, instruction));
- RegisterAddedCopy(copy);
SetupCopiedInstruction(*instruction, copy, {});
LayoutUtil::ClearLayout(copy->mutable_shape());
TF_RETURN_IF_ERROR(LayoutUtil::CopyLayoutBetweenShapes(
@@ -800,19 +783,13 @@ Status LayoutAssignment::CopyOperandIfLayoutsDiffer(
TF_RET_CHECK(LayoutUtil::HasLayout(operand->shape()));
if (ShapeUtil::Equal(operand_layout.shape(), operand->shape())) {
- VLOG(5) << "Operand " << operand->ToString() << " layout matches in "
- << instruction->ToString();
// Operand layout already matches our constraint. Nothing to do.
return Status::OK();
}
- VLOG(4) << "Operand " << operand->ToString() << " layout does not match "
- << operand_layout.ToString() << " in " << instruction->ToString();
TF_ASSIGN_OR_RETURN(HloInstruction * operand_copy,
CreateCopyWithNewLayout(operand_layout.shape(), operand));
- VLOG(4) << "New copy of " << operand->ToString() << " is "
- << operand_copy->ToString();
return instruction->ReplaceOperandWith(operand_no, operand_copy);
}
@@ -919,16 +896,15 @@ Status LayoutAssignment::CheckLayouts(HloModule* module) {
}
}
}
- // Finally verify the result layout, if set, matches the layout of the entry
+
+ // Finally verify the result layout matches the layout of the entry
// computation root.
- const ShapeLayout& result_layout =
+ TF_RET_CHECK(ShapeUtil::Equal(
+ module->entry_computation()->root_instruction()->shape(),
FindOrDie(computation_layouts_, module->entry_computation())
- .result_layout();
- if (result_layout.LayoutIsSet()) {
- TF_RET_CHECK(ShapeUtil::Equal(
- module->entry_computation()->root_instruction()->shape(),
- result_layout.shape()));
- }
+ .result_layout()
+ .shape()));
+
return Status::OK();
}
@@ -937,13 +913,18 @@ LayoutAssignment::LayoutAssignment(
ChannelLayoutConstraints* channel_constraints)
: entry_computation_layout_(entry_computation_layout),
channel_layout_constraints_(channel_constraints) {
- VLOG(1) << "Entry computation layout given to layout assignment: "
+ VLOG(1) << "entry computation layout given to layout assignment: "
<< entry_computation_layout_->ToString();
// Layouts of all parameter instructions must be set.
for (const ShapeLayout& parameter_layout :
entry_computation_layout_->parameter_layouts()) {
CHECK(parameter_layout.LayoutIsSet());
}
+ // If the result layout is not set, then choose the default.
+ // TODO(b/29118294): Choose a better layout in this case.
+ if (!entry_computation_layout_->result_layout().LayoutIsSet()) {
+ entry_computation_layout_->mutable_result_layout()->SetToDefaultLayout();
+ }
}
std::unique_ptr<Layout> LayoutAssignment::ChooseOperandLayoutFromOutputLayout(
@@ -1503,60 +1484,16 @@ Status LayoutAssignment::AssignLayouts(const LayoutConstraints& constraints,
return Status::OK();
}
-Status LayoutAssignment::CalculateComputationLayout(
- HloComputation* computation) {
- ComputationLayout computation_layout(computation->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- InsertOrDie(&computation_layouts_, computation, computation_layout);
- VLOG(2) << " Calculated ComputationLayout = "
- << computation_layout.ToString();
- return Status::OK();
-}
-
-Status LayoutAssignment::ClearComputationLayouts(HloComputation* computation) {
- // Clear existing layouts of the instructions. All layouts must be assigned
- // by the LayoutAssignment pass, except for those on infeeds, parameters,
- // and the computation result. The latter two are specified in
- // computation_layout, so we only need to keep the existing layouts for
- // infeeds. Clearing the layouts here avoids hiding potential bugs in the
- // layout assignment pass that may accidently use the existing layout.
- for (HloInstruction* instruction : computation->instructions()) {
- if (instruction->opcode() == HloOpcode::kBitcast) {
- // bitcasts are inherently layout sensitive and so a bitcast instruction
- // present in the IR before layout assignment is a bug.
- return InternalError(
- "Unexpected bitcast operation seen during layout assignment: %s.",
- instruction->ToString().c_str());
- }
- if (instruction->opcode() != HloOpcode::kInfeed) {
- LayoutUtil::ClearLayout(instruction->mutable_shape());
- }
- }
- return Status::OK();
-}
-
Status LayoutAssignment::RunOnComputation(
- ComputationLayout* computation_layout,
+ const ComputationLayout& computation_layout,
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints) {
+ DCHECK(computation_layout.LayoutIsSet());
+ InsertOrDie(&computation_layouts_, computation, computation_layout);
VLOG(2) << "LayoutAssignment::RunOnComputation(" << computation->name()
<< ")";
- TF_RETURN_IF_ERROR(ClearComputationLayouts(computation));
- if (computation_layout != nullptr) {
- auto it = computation_layouts_.find(computation);
- if (it == computation_layouts_.end()) {
- VLOG(2) << " New ComputationLayout = " << computation_layout->ToString();
- computation_layouts_.emplace(computation, *computation_layout);
- } else {
- TF_RET_CHECK(computation_layout == &it->second ||
- computation_layout == entry_computation_layout_);
- VLOG(2) << " Existing ComputationLayout = "
- << computation_layout->ToString();
- }
- } else {
- VLOG(2) << " No ComputationLayout specified (will be calculated)";
- }
+ VLOG(2) << " ComputationLayout = " << computation_layout.ToString();
// Construct LayoutConstraints with all layout constraints of the computation.
LayoutConstraints constraints(points_to_analysis, computation);
@@ -1599,19 +1536,12 @@ Status LayoutAssignment::RunOnComputation(
CHECK_LT(constraints.unconstrained_buffer_ids().size(),
unconstrained_count);
}
+
// All logical buffers should have constraints at this point. All that
// remains is assign the constraints to the buffers and infer layouts for
// aliased buffers.
TF_RETURN_IF_ERROR(AssignLayouts(constraints, computation));
- // If the computation layout wasn't specified, now it is the time to compute
- // it according to the parameters and root instruction layouts.
- // This allows the first pass through this API to record the best flowing
- // layout to parameters and root instruction.
- if (computation_layout == nullptr) {
- TF_RETURN_IF_ERROR(CalculateComputationLayout(computation));
- }
-
// Record the layouts assigned for any communication ops in
// channel_constraints so that they are constrained for future modules.
for (HloInstruction* instruction : computation->instructions()) {
@@ -1626,34 +1556,6 @@ Status LayoutAssignment::RunOnComputation(
return Status::OK();
}
-Status LayoutAssignment::PropagateComputationLayouts(
- HloComputation* computation, ComputationLayout* computation_layout) {
- ComputationLayout computed_computation_layout(
- computation->ComputeProgramShape(),
- /*ignore_layouts=*/false);
- for (int64 i = 0; i < computed_computation_layout.parameter_count(); ++i) {
- ShapeLayout* param_layout = computation_layout->mutable_parameter_layout(i);
- if (!param_layout->LayoutIsSet()) {
- VLOG(4) << "Assigning layout to parameter " << i << " of computation "
- << computation->name() << ": "
- << computed_computation_layout.parameter_layout(i).ToString();
- *param_layout = computed_computation_layout.parameter_layout(i);
- } else {
- TF_RET_CHECK(computed_computation_layout.parameter_layout(i) ==
- *param_layout);
- }
- }
- ShapeLayout* result_layout = computation_layout->mutable_result_layout();
- if (!result_layout->LayoutIsSet()) {
- VLOG(4) << "Assigning result layout of computation " << computation->name()
- << ": " << computed_computation_layout.result_layout().ToString();
- *result_layout = computed_computation_layout.result_layout();
- } else {
- TF_RET_CHECK(computed_computation_layout.result_layout() == *result_layout);
- }
- return Status::OK();
-}
-
StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
VLOG(2) << "Running layout assignment on module " << module->name();
XLA_VLOG_LINES(3, module->ToString());
@@ -1662,45 +1564,52 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
"before layout assignment",
module->config().debug_options());
}
- TF_RETURN_IF_ERROR(Init());
-
- // We do two passes. The first one we pass a nullptr ComputationLayout to
- // the RunOnComputation() calls (for non entry computations), and we register
- // the ComputationLayout which are naturally flowing in DFS fashion to the
- // parameters and root instruction.
- // Walking in DFS mode though, means that we can end up with incorrect layouts
- // when seen from an outer instruction, which has across-computation
- // constraints to impose.
- // For example, the kWhile instruction needs to enforce the same layouts for
- // the parameters and root of the bosy, as well as the condition parameters.
- // Similarly, the kConditional instruction needs to enforce the same layouts
- // for the root of the true and false computations.
- // So in the first pass, while allowing the layouts to flow to parameters and
- // root, we also fix up the eventually inconsistent ComputationLayout, which
- // will be then made mandatory by the second pass.
- for (int64 i = 0; i < 2; ++i) {
- TF_RETURN_IF_ERROR(ClearPreviousPassSideEffects(module));
- TF_ASSIGN_OR_RETURN(auto points_to_analysis,
- TuplePointsToAnalysis::Run(module));
- for (auto* computation : module->MakeComputationPostOrder()) {
- if (computation->IsFusionComputation()) {
- continue;
+
+ TF_ASSIGN_OR_RETURN(auto points_to_analysis,
+ TuplePointsToAnalysis::Run(module));
+
+ // Assign layouts to computations in an order such that a callee computation
+ // is handled before its caller computation. This ensures that the layout of
+ // all callers of a computation will agree.
+ std::list<HloComputation*> computation_post_order =
+ module->MakeComputationPostOrder();
+ for (auto* computation : module->MakeComputationPostOrder()) {
+ if (computation->IsFusionComputation()) {
+ continue;
+ }
+ // Clear existing layouts of the instructions. All layouts must be assigned
+ // by the LayoutAssignment pass, except for those on infeeds, parameters,
+ // and the computation result. The latter two are specified in
+ // computation_layout, so we only need to keep the existing layouts for
+ // infeeds. Clearing the layouts here avoids hiding potential bugs in the
+    // layout assignment pass that may accidentally use the existing layout.
+ for (HloInstruction* instruction : computation->instructions()) {
+ if (instruction->opcode() == HloOpcode::kBitcast) {
+ // bitcasts are inherently layout sensitive and so a bitcast instruction
+        // Bitcasts are inherently layout sensitive, so a bitcast instruction
+ return InternalError(
+ "Unexpected bitcast operation seen during layout assignment: %s.",
+ instruction->ToString().c_str());
}
- if (computation == module->entry_computation()) {
- TF_RETURN_IF_ERROR(RunOnComputation(
- entry_computation_layout_, *points_to_analysis,
- module->entry_computation(), channel_layout_constraints_));
- } else {
- ComputationLayout* computation_layout =
- (i == 0) ? nullptr : &FindOrDie(computation_layouts_, computation);
- TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
- *points_to_analysis, computation,
- channel_layout_constraints_));
+ if (instruction->opcode() != HloOpcode::kInfeed) {
+ LayoutUtil::ClearLayout(instruction->mutable_shape());
}
}
+ if (computation == module->entry_computation()) {
+ TF_RETURN_IF_ERROR(RunOnComputation(
+ *entry_computation_layout_, *points_to_analysis,
+ module->entry_computation(), channel_layout_constraints_));
+ } else {
+ ComputationLayout computation_layout(computation->ComputeProgramShape());
+ // Setting all embedded computations to the default layout is potentially
+ // suboptimal.
+ computation_layout.SetToDefaultLayout();
+ TF_RETURN_IF_ERROR(RunOnComputation(computation_layout,
+ *points_to_analysis, computation,
+ channel_layout_constraints_));
+ }
}
- TF_RETURN_IF_ERROR(PropagateComputationLayouts(module->entry_computation(),
- entry_computation_layout_));
+
TF_RETURN_IF_ERROR(CheckLayouts(module));
VLOG(3) << "After layout assignment:";
@@ -1710,54 +1619,9 @@ StatusOr<bool> LayoutAssignment::Run(HloModule* module) {
"after layout assignment",
module->config().debug_options());
}
+
// All layouts are reset then reassigned by this pass.
return true;
}
-Status LayoutAssignment::Init() {
- computation_layouts_.clear();
- return Status::OK();
-}
-
-Status LayoutAssignment::ClearPreviousPassSideEffects(HloModule* module) {
- // Clear all the copies which have been added, and all the related
- // instructions (like GTE and tuples).
- int64 removed_copies = 0;
- for (HloComputation* computation : module->computations()) {
- for (HloInstruction* instruction :
- computation->MakeInstructionPostOrder()) {
- if (instruction->opcode() == HloOpcode::kCopy &&
- added_copies_.count(instruction) > 0) {
- VLOG(5) << "Removing added copy: " << instruction->ToString();
- TF_RETURN_IF_ERROR(
- instruction->ReplaceAllUsesWith(instruction->mutable_operand(0)));
- TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
- ++removed_copies;
- }
- }
- }
- added_copies_.clear();
- if (removed_copies > 0) {
- TupleSimplifier tuple_simplifier;
- HloDCE dce;
- TF_RETURN_IF_ERROR(tuple_simplifier.Run(module).status());
- TF_RETURN_IF_ERROR(dce.Run(module).status());
- }
- return Status::OK();
-}
-
-Status LayoutAssignment::AddCopyForOperand(HloInstruction* instruction,
- int64 operand_number) {
- HloInstruction* operand = instruction->mutable_operand(operand_number);
- if (operand->opcode() != HloOpcode::kCopy || operand->user_count() > 1) {
- HloInstruction* copy =
- instruction->parent()->AddInstruction(HloInstruction::CreateUnary(
- operand->shape(), HloOpcode::kCopy, operand));
- SetupCopiedInstruction(*operand, copy, {});
- LayoutUtil::ClearLayout(copy->mutable_shape());
- TF_RETURN_IF_ERROR(instruction->ReplaceOperandWith(operand_number, copy));
- }
- return Status::OK();
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/layout_assignment.h b/tensorflow/compiler/xla/service/layout_assignment.h
index 8b4e07995a..ae4986d6ad 100644
--- a/tensorflow/compiler/xla/service/layout_assignment.h
+++ b/tensorflow/compiler/xla/service/layout_assignment.h
@@ -39,7 +39,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
-#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
@@ -363,15 +362,12 @@ class LayoutAssignment : public HloPassInterface {
int64 operand_no);
private:
- // Initializes the layout assignment object for a new Run() call.
- Status Init();
-
// Adds constraints which must be satisfied for correctness on all
// backends. Called once prior to propagating constraints.
- Status AddMandatoryConstraints(const ComputationLayout* computation_layout,
- ChannelLayoutConstraints* channel_constraints,
- HloComputation* computation,
- LayoutConstraints* constraints);
+ Status AddMandatoryConstraints(
+ const ComputationLayout& computation_layout,
+ const ChannelLayoutConstraints* channel_constraints,
+ HloComputation* computation, LayoutConstraints* constraints);
// This method can be overridden to add backend-specific constraints to the
// layout of the instructions of a computation. This method is called after
@@ -382,12 +378,10 @@ class LayoutAssignment : public HloPassInterface {
}
  // Construct constraints and assign layouts to all instructions in the
- // computation satisfying the given ComputationLayout, if not nullptr.
- // Otherwise the ComputationLayout will be calculated by propagating the
- // computation instruction contraints.
- // Layouts constraints are added, then propagated until all LogicalBuffers in
- // the computation are constrained.
- Status RunOnComputation(ComputationLayout* computation_layout,
+  // computation satisfying the given ComputationLayout. Layout constraints are
+ // added, then propagated until all LogicalBuffers in the computation are
+ // constrained.
+ Status RunOnComputation(const ComputationLayout& computation_layout,
const TuplePointsToAnalysis& points_to_analysis,
HloComputation* computation,
ChannelLayoutConstraints* channel_constraints);
@@ -408,25 +402,6 @@ class LayoutAssignment : public HloPassInterface {
// necessary conditions.
Status CheckLayouts(HloModule* module);
- // Computes the ComputationLayout of the given computation based of the
- // layouts assigned to parameters and root instruction, and inserts it to the
- // computation_layouts_ map.
- Status CalculateComputationLayout(HloComputation* computation);
-
- // Clears all the layouts which can be cleared within a computation.
- Status ClearComputationLayouts(HloComputation* computation);
-
- // Clears the side effects of a previous pass, like added copy instructions.
- Status ClearPreviousPassSideEffects(HloModule* module);
-
- // Propagates the layouts computed by the layout assignment pass on the given
- // computation, to the computation layout passed in to this API.
- // This API propagates missing layout, and also checks that the caller
- // specified have been respected, by comparing those with the parameters and
- // root computation instruction.
- Status PropagateComputationLayouts(HloComputation* computation,
- ComputationLayout* computation_layout);
-
ComputationLayout* entry_computation_layout_;
protected:
@@ -443,37 +418,21 @@ class LayoutAssignment : public HloPassInterface {
// Creates and returns a copy of the given instruction with a different
// layout. Tuple-shaped instructions will be deep-copied, and the last Tuple
// instruction producing the copy is returned.
- StatusOr<HloInstruction*> CreateCopyWithNewLayout(
+ static StatusOr<HloInstruction*> CreateCopyWithNewLayout(
const Shape& shape_with_layout, HloInstruction* instruction);
// Creates a copy of the given operand if the operand's layout does not match
// the given layout. This copy replaces the use in the given instruction.
// Tuple operands will be deep-copied.
- Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
- HloInstruction* instruction,
- int64 operand_no);
-
- // Registers a copy instruction added by the layout assignment pass.
- void RegisterAddedCopy(HloInstruction* copy) {
- CHECK_EQ(copy->opcode(), HloOpcode::kCopy);
- added_copies_.insert(copy);
- }
-
- // Adds a copy for the operand of an instruction, unless such operand is
- // already a copy, and has a single user (which is forcibly the instruction
- // itself).
- Status AddCopyForOperand(HloInstruction* instruction, int64 operand_number);
+ static Status CopyOperandIfLayoutsDiffer(const ShapeLayout& operand_layout,
+ HloInstruction* instruction,
+ int64 operand_no);
// Map containing the layouts of all computations assigned so
// far. Computations are handled in a topological sort where computations are
// handled before their caller instructions so the layouts of caller
// instructions can be set to match the computation.
std::map<HloComputation*, ComputationLayout> computation_layouts_;
-
- // Every copy added to the module by the layout assignment pass is registered
- // here.
- tensorflow::gtl::FlatSet<HloInstruction*> added_copies_;
-
ChannelLayoutConstraints* channel_layout_constraints_;
};
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index a73118c68a..086bd61dd0 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -308,10 +308,7 @@ StatusOr<std::unique_ptr<HloModuleConfig>> Service::CreateModuleConfig(
computation_layout->mutable_result_layout()->CopyLayoutFromShape(
shape_with_output_layout));
} else {
- // TODO(b/78356948): We are forcing the default layout here. We should fix
- // clients which expect a default layout, to be explicit about it, by
- // passing the proper ExecutionOptions with shape_with_output_layout set.
- computation_layout->mutable_result_layout()->SetToDefaultLayout();
+ computation_layout->mutable_result_layout()->Clear();
}
config->set_replica_count(options_.number_of_replicas());
@@ -553,7 +550,7 @@ Service::ExecuteParallelAndRegisterResult(
// Stream executors for the replicas of the current computation.
TF_ASSIGN_OR_RETURN(auto replicas, Replicas(*backend, device_handles[i]));
CHECK_EQ(replicas.size(), arguments[i].size());
- std::vector<ShapedBuffer> result_buffers;
+ std::vector<ScopedShapedBuffer> result_buffers;
for (int64 replica = 0; replica < replicas.size(); ++replica) {
TF_ASSIGN_OR_RETURN(Pool<se::Stream>::SmartPtr stream,
backend->BorrowStream(replicas[replica]));
@@ -585,7 +582,7 @@ Service::ExecuteParallelAndRegisterResult(
backend->StreamBorrower());
// Asynchronously launch the computation.
- TF_ASSIGN_OR_RETURN(ShapedBuffer result,
+ TF_ASSIGN_OR_RETURN(ScopedShapedBuffer result,
executables[i]->ExecuteAsyncOnStream(
&run_options, arguments[i][replica]));
@@ -1237,7 +1234,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
streams.push_back(std::move(stream));
}
- std::vector<ShapedBuffer> result_buffers;
+ std::vector<ScopedShapedBuffer> result_buffers;
for (size_t i = 0; i < streams.size(); ++i) {
const auto& stream = streams[i];
ExecutableRunOptions options;
@@ -1250,7 +1247,7 @@ tensorflow::Status Service::ExecuteAsync(const ExecuteAsyncRequest* arg,
ServiceExecutableRunOptions service_options(
options, execute_backend_->StreamBorrower());
- TF_ASSIGN_OR_RETURN(ShapedBuffer this_result_buffer,
+ TF_ASSIGN_OR_RETURN(ScopedShapedBuffer this_result_buffer,
executable->ExecuteAsyncOnStream(
&service_options, replicated_arguments[i]));
@@ -1350,11 +1347,11 @@ tensorflow::Status Service::TransferToServer(const TransferToServerRequest* arg,
}
// Allocate memory in each replica and transfer the data to all replicas.
- std::vector<ShapedBuffer> replicated_buffers;
+ std::vector<ScopedShapedBuffer> replicated_buffers;
for (se::StreamExecutor* executor : replicas) {
TF_ASSIGN_OR_RETURN(
- ShapedBuffer shaped_buffer,
- execute_backend_->transfer_manager()->AllocateShapedBuffer(
+ ScopedShapedBuffer shaped_buffer,
+ execute_backend_->transfer_manager()->AllocateScopedShapedBuffer(
shape, execute_backend_->memory_allocator(),
executor->device_ordinal()));
TF_RETURN_IF_ERROR(
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.cc b/tensorflow/compiler/xla/service/shaped_buffer.cc
index 0b5a383f6f..fb3b5f06da 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.cc
+++ b/tensorflow/compiler/xla/service/shaped_buffer.cc
@@ -117,7 +117,7 @@ ScopedShapedBuffer::ScopedShapedBuffer(ShapedBuffer shaped_buffer,
: ShapedBuffer(std::move(shaped_buffer)), allocator_(allocator) {}
ScopedShapedBuffer::ScopedShapedBuffer(ScopedShapedBuffer&& s)
- : ShapedBuffer(std::move(s)), allocator_(s.allocator_) {
+ : ShapedBuffer(static_cast<ShapedBuffer&&>(s)), allocator_(s.allocator_) {
// Null out s.allocator_ so it doesn't try to free anything in its destructor.
s.allocator_ = nullptr;
}
@@ -151,7 +151,7 @@ ScopedShapedBuffer::~ScopedShapedBuffer() {
}
ShapedBuffer ScopedShapedBuffer::release() {
- ShapedBuffer shaped_buffer(std::move(*this));
+ ShapedBuffer shaped_buffer(static_cast<ShapedBuffer&&>(*this));
buffers_ = ShapeTree<se::DeviceMemoryBase>();
return shaped_buffer;
}
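
Taken together, the move constructor, release(), and the destructor touched above are what let service.cc pass ScopedShapedBuffer around by value: the scoped object frees its device allocations through allocator_ when it dies, and a moved-from instance nulls the allocator so nothing is freed twice. A minimal standalone sketch of that ownership pattern, using invented stand-in types (FakeAllocator, Buffer, ScopedBuffer) rather than the real XLA classes:

#include <iostream>
#include <utility>
#include <vector>

// Invented stand-in: a fake device allocator keyed by integer handles.
struct FakeAllocator {
  void Deallocate(int handle) { std::cout << "freed buffer " << handle << "\n"; }
};

// Unscoped buffer: holds handles but never frees them (ShapedBuffer's role).
struct Buffer {
  std::vector<int> handles;
};

// Scoped buffer: same data plus the allocator used to free everything on
// destruction (ScopedShapedBuffer's role).
class ScopedBuffer : public Buffer {
 public:
  ScopedBuffer(std::vector<int> h, FakeAllocator* a) : allocator_(a) {
    handles = std::move(h);
  }
  ScopedBuffer(ScopedBuffer&& other) : allocator_(other.allocator_) {
    handles = std::move(other.handles);
    other.allocator_ = nullptr;  // a moved-from object must not free anything
  }
  ~ScopedBuffer() {
    if (allocator_ != nullptr) {
      for (int h : handles) allocator_->Deallocate(h);
    }
  }

 private:
  FakeAllocator* allocator_;
};

int main() {
  FakeAllocator alloc;
  {
    ScopedBuffer result({1, 2, 3}, &alloc);
  }  // device memory released here; no manual cleanup at the call site
  return 0;
}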
diff --git a/tensorflow/compiler/xla/service/shaped_buffer.h b/tensorflow/compiler/xla/service/shaped_buffer.h
index f1b0527474..e10fca9e94 100644
--- a/tensorflow/compiler/xla/service/shaped_buffer.h
+++ b/tensorflow/compiler/xla/service/shaped_buffer.h
@@ -30,6 +30,8 @@ limitations under the License.
namespace xla {
+class ScopedShapedBuffer;
+
// Class which encapsulates a buffer or set of buffers containing data of a
// particular XLA shape.
class ShapedBuffer {
@@ -49,6 +51,10 @@ class ShapedBuffer {
ShapedBuffer(const ShapedBuffer&) = delete;
ShapedBuffer& operator=(const ShapedBuffer&) = delete;
+ // Prevent (some forms of) accidental object slicing.
+ ShapedBuffer(const ScopedShapedBuffer&) = delete;
+ ShapedBuffer& operator=(const ScopedShapedBuffer&) = delete;
+
virtual ~ShapedBuffer();
// Returns the shape of the on-host representation of the data held by this
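
Two details in shaped_buffer.cc and shaped_buffer.h work together: the header deletes ShapedBuffer(const ScopedShapedBuffer&) to catch accidental slicing, and the source file therefore casts with static_cast<ShapedBuffer&&> rather than std::move, since std::move(s) produces a ScopedShapedBuffer&& that would prefer the deleted overload over the base-class move constructor. That reading is mine, not stated in the diff; the stand-in types below (Base, Scoped) only reproduce the overload-resolution point:

#include <utility>

struct Scoped;  // forward declaration, mirroring the header change above

struct Base {
  Base() = default;
  Base(Base&&) = default;        // the move constructor we actually want
  Base(const Scoped&) = delete;  // slicing guard, as added in shaped_buffer.h
};

struct Scoped : Base {
  Scoped() = default;
  // Base(std::move(s)) would pass a Scoped&&, for which the deleted
  // Base(const Scoped&) is the better match; the explicit cast makes the
  // argument a Base&& and selects Base(Base&&) instead.
  Scoped(Scoped&& s) : Base(static_cast<Base&&>(s)) {}
};

int main() {
  Scoped a;
  Scoped b(std::move(a));  // the constructor body compiles only because of the cast
  return 0;
}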
diff --git a/tensorflow/compiler/xla/service/transfer_manager.cc b/tensorflow/compiler/xla/service/transfer_manager.cc
index 98d0111d04..8b71a41509 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.cc
+++ b/tensorflow/compiler/xla/service/transfer_manager.cc
@@ -175,7 +175,7 @@ Status TransferManager::TransferBufferToDevice(
return Status::OK();
}
-StatusOr<ShapedBuffer> TransferManager::AllocateShapedBuffer(
+StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal) {
if (!LayoutUtil::HasLayout(on_host_shape)) {
@@ -187,8 +187,8 @@ StatusOr<ShapedBuffer> TransferManager::AllocateShapedBuffer(
const Shape on_device_shape = HostShapeToDeviceShape(on_host_shape);
TF_RET_CHECK(LayoutUtil::HasLayout(on_device_shape));
- ShapedBuffer shaped_buffer(on_host_shape, on_device_shape,
- allocator->platform(), device_ordinal);
+ ScopedShapedBuffer shaped_buffer(on_host_shape, on_device_shape, allocator,
+ device_ordinal);
// Allocate an appropriate sized buffer for each element in the shape
// including the tuple pointer arrays.
@@ -204,13 +204,4 @@ StatusOr<ShapedBuffer> TransferManager::AllocateShapedBuffer(
return std::move(shaped_buffer);
}
-StatusOr<ScopedShapedBuffer> TransferManager::AllocateScopedShapedBuffer(
- const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
- int device_ordinal) {
- TF_ASSIGN_OR_RETURN(
- ShapedBuffer unscoped_buffer,
- AllocateShapedBuffer(on_host_shape, allocator, device_ordinal));
- return ScopedShapedBuffer(std::move(unscoped_buffer), allocator);
-}
-
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/transfer_manager.h b/tensorflow/compiler/xla/service/transfer_manager.h
index a6451c4bb1..d82b4f0f81 100644
--- a/tensorflow/compiler/xla/service/transfer_manager.h
+++ b/tensorflow/compiler/xla/service/transfer_manager.h
@@ -104,12 +104,9 @@ class TransferManager {
// region for a host-to-device transfer.
virtual int64 GetByteSizeRequirement(const Shape& shape) const = 0;
- // Allocate a ShapedBuffer which can hold data with the given on-host
+ // Allocates a ScopedShapedBuffer which can hold data with the given on-host
// shape. The on-device shape may be different as indicated by
// HostShapeToDeviceShape.
- StatusOr<ShapedBuffer> AllocateShapedBuffer(const Shape& on_host_shape,
- DeviceMemoryAllocator* allocator,
- int device_ordinal);
StatusOr<ScopedShapedBuffer> AllocateScopedShapedBuffer(
const Shape& on_host_shape, DeviceMemoryAllocator* allocator,
int device_ordinal);
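
After this cleanup there is a single allocation entry point, and ownership is established at creation time: the transfer manager hands callers a ScopedShapedBuffer directly instead of an unscoped buffer they must remember to wrap. A toy sketch of that API shape, with invented names (Allocation, AllocateScoped) standing in for the real types:

#include <iostream>
#include <memory>
#include <string>

// Invented stand-in for an allocated, shaped device buffer.
struct Allocation {
  std::string shape;
};

// The factory returns the owning type directly; no separate wrapping step.
std::unique_ptr<Allocation> AllocateScoped(const std::string& shape) {
  return std::unique_ptr<Allocation>(new Allocation{shape});
}

int main() {
  auto buffer = AllocateScoped("f32[2,3]");
  std::cout << "allocated " << buffer->shape << "\n";
  return 0;
}  // released automatically, mirroring ScopedShapedBuffer's destructor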
diff --git a/tensorflow/compiler/xla/service/tuple_simplifier.cc b/tensorflow/compiler/xla/service/tuple_simplifier.cc
index d668855084..113c2e2bd9 100644
--- a/tensorflow/compiler/xla/service/tuple_simplifier.cc
+++ b/tensorflow/compiler/xla/service/tuple_simplifier.cc
@@ -69,7 +69,6 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// Tuple
//
HloInstruction* top_tuple = nullptr;
- HloInstruction* first_gte = nullptr;
bool can_simplify = true;
for (int64 operand_number = 0;
operand_number < instruction->operand_count(); ++operand_number) {
@@ -79,17 +78,11 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
can_simplify = false;
break;
}
- if (first_gte == nullptr) {
- first_gte = operand;
- } else if (!first_gte->has_compatible_sharding(operand)) {
- can_simplify = false;
- break;
- }
+
if (top_tuple == nullptr) {
top_tuple = operand->mutable_operand(0);
if (!ShapeUtil::Compatible(top_tuple->shape(),
- instruction->shape()) ||
- !instruction->has_compatible_sharding(top_tuple)) {
+ instruction->shape())) {
can_simplify = false;
break;
}
@@ -115,17 +108,15 @@ StatusOr<bool> TupleSimplifier::Run(HloModule* module) {
// |
// GTE
if (instruction->operand(0)->opcode() == HloOpcode::kTuple) {
+ changed = true;
HloInstruction* element_source =
instruction->mutable_operand(0)->mutable_operand(
instruction->tuple_index());
- if (instruction->has_compatible_sharding(element_source)) {
- changed = true;
- TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
- for (HloInstruction* user : element_source->users()) {
- if (user->opcode() == HloOpcode::kTuple ||
- user->opcode() == HloOpcode::kGetTupleElement) {
- worklist.push(user);
- }
+ TF_RETURN_IF_ERROR(instruction->ReplaceAllUsesWith(element_source));
+ for (HloInstruction* user : element_source->users()) {
+ if (user->opcode() == HloOpcode::kTuple ||
+ user->opcode() == HloOpcode::kGetTupleElement) {
+ worklist.push(user);
}
}
}
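
With the sharding-compatibility checks removed, the pass now rewrites unconditionally: a kGetTupleElement whose operand is a kTuple is replaced by the tuple's corresponding operand, and affected tuple/GTE users go back on the worklist. A toy sketch of the core forwarding step, using an invented Node type rather than HloInstruction (the real pass also replaces all uses and drives the worklist shown above):

#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  bool is_tuple;
  bool is_gte;
  int tuple_index;
  std::vector<Node*> operands;
};

// Core rewrite: GetTupleElement(Tuple(a, b, ...), i) forwards to operand i.
Node* SimplifyGte(Node* gte) {
  if (!gte->is_gte || gte->operands.empty()) return nullptr;
  Node* source = gte->operands[0];
  if (!source->is_tuple) return nullptr;
  return source->operands[gte->tuple_index];
}

int main() {
  Node a{"a", false, false, 0, {}};
  Node b{"b", false, false, 0, {}};
  Node tup{"tuple", true, false, 0, {&a, &b}};
  Node gte{"gte", false, true, 1, {&tup}};
  std::cout << "gte forwards to: " << SimplifyGte(&gte)->name << "\n";  // prints "b"
  return 0;
}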
diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h
index 4c83750f3e..a1dce758cd 100644
--- a/tensorflow/compiler/xla/shape_layout.h
+++ b/tensorflow/compiler/xla/shape_layout.h
@@ -48,8 +48,7 @@ class ShapeLayout {
bool MatchesLayoutInShape(const Shape& shape) const;
// Copies the layout from the given shape into this ShapeLayout. 'other_shape'
- // must be compatible with the ShapeLayout's shape, and 'other_shape' must
- // have a layout (LayoutUtil::HasLayout).
+ // must be compatible with the ShapeLayout's shape.
tensorflow::Status CopyLayoutFromShape(const Shape& other_shape);
// Clears (Layout::Clear) all the Layouts stored in this object.
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 63da9154cf..5fa728e7c2 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -31,6 +31,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/compiler/xla/statusor.h b/tensorflow/compiler/xla/statusor.h
index 641b5e9a6a..cccbce5fc8 100644
--- a/tensorflow/compiler/xla/statusor.h
+++ b/tensorflow/compiler/xla/statusor.h
@@ -113,17 +113,19 @@ class StatusOr : private internal_statusor::StatusOrData<T>,
StatusOr& operator=(StatusOr&&) = default;
// Conversion copy/move constructor, T must be convertible from U.
- // TODO(b/62186717): These should not participate in overload resolution if U
- // is not convertible to T.
- template <typename U>
+ template <typename U, typename std::enable_if<
+ std::is_convertible<U, T>::value>::type* = nullptr>
StatusOr(const StatusOr<U>& other);
- template <typename U>
+ template <typename U, typename std::enable_if<
+ std::is_convertible<U, T>::value>::type* = nullptr>
StatusOr(StatusOr<U>&& other);
// Conversion copy/move assignment operator, T must be convertible from U.
- template <typename U>
+ template <typename U, typename std::enable_if<
+ std::is_convertible<U, T>::value>::type* = nullptr>
StatusOr& operator=(const StatusOr<U>& other);
- template <typename U>
+ template <typename U, typename std::enable_if<
+ std::is_convertible<U, T>::value>::type* = nullptr>
StatusOr& operator=(StatusOr<U>&& other);
// Constructs a new StatusOr with the given value. After calling this
@@ -233,12 +235,14 @@ StatusOr<T>& StatusOr<T>::operator=(Status&& status) {
}
template <typename T>
-template <typename U>
+template <typename U,
+ typename std::enable_if<std::is_convertible<U, T>::value>::type*>
inline StatusOr<T>::StatusOr(const StatusOr<U>& other)
: Base(static_cast<const typename StatusOr<U>::Base&>(other)) {}
template <typename T>
-template <typename U>
+template <typename U,
+ typename std::enable_if<std::is_convertible<U, T>::value>::type*>
inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
if (other.ok())
this->Assign(other.ValueOrDie());
@@ -248,12 +252,14 @@ inline StatusOr<T>& StatusOr<T>::operator=(const StatusOr<U>& other) {
}
template <typename T>
-template <typename U>
+template <typename U,
+ typename std::enable_if<std::is_convertible<U, T>::value>::type*>
inline StatusOr<T>::StatusOr(StatusOr<U>&& other)
: Base(static_cast<typename StatusOr<U>::Base&&>(other)) {}
template <typename T>
-template <typename U>
+template <typename U,
+ typename std::enable_if<std::is_convertible<U, T>::value>::type*>
inline StatusOr<T>& StatusOr<T>::operator=(StatusOr<U>&& other) {
if (other.ok()) {
this->Assign(std::move(other).ValueOrDie());
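
The added enable_if parameters make the converting constructors and assignment operators drop out of overload resolution unless U is convertible to T, which is what the removed TODO(b/62186717) asked for. A self-contained sketch of the same constraint applied to a simplified wrapper (Wrapper is invented here; it is not the real StatusOr):

#include <string>
#include <type_traits>
#include <utility>

template <typename T>
class Wrapper {
 public:
  explicit Wrapper(T value) : value_(std::move(value)) {}

  // Participates in overload resolution only when U is convertible to T,
  // mirroring the SFINAE constraint added to StatusOr above.
  template <typename U, typename std::enable_if<
                            std::is_convertible<U, T>::value>::type* = nullptr>
  Wrapper(const Wrapper<U>& other) : value_(other.value()) {}

  const T& value() const { return value_; }

 private:
  T value_;
};

// int converts to double, so the constrained constructor participates.
static_assert(std::is_constructible<Wrapper<double>, Wrapper<int>>::value, "");
// std::string does not convert to int, so SFINAE removes the constructor.
static_assert(!std::is_constructible<Wrapper<int>, Wrapper<std::string>>::value, "");

int main() {
  Wrapper<int> i(42);
  Wrapper<double> d(i);  // uses the constrained converting constructor
  return d.value() == 42.0 ? 0 : 1;
}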
diff --git a/tensorflow/compiler/xla/tests/BUILD b/tensorflow/compiler/xla/tests/BUILD
index 1f90a44d8b..c28d14ba8a 100644
--- a/tensorflow/compiler/xla/tests/BUILD
+++ b/tensorflow/compiler/xla/tests/BUILD
@@ -153,6 +153,8 @@ tf_cc_binary(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:client_library",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service/cpu:cpu_compiler",
"//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
"//tensorflow/core:lib",
@@ -191,6 +193,7 @@ cc_library(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:interpreter_plugin", # reference backend
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -257,8 +260,8 @@ cc_library(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:client_library",
- "//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:computation_placer",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
@@ -288,6 +291,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -311,6 +316,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -330,6 +337,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -371,6 +380,8 @@ xla_test(
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:platform_util",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -390,6 +401,7 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -442,6 +454,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -461,6 +475,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
@@ -478,6 +494,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -514,6 +532,8 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -535,6 +555,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -554,6 +576,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -578,6 +602,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -604,6 +630,7 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -670,6 +697,8 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -702,9 +731,6 @@ xla_test(
"cpu": [
"--xla_cpu_multi_thread_eigen=false",
],
- "cpu_parallel": [
- "--xla_cpu_multi_thread_eigen=false",
- ],
},
shard_count = 20,
tags = ["optonly"],
@@ -715,6 +741,8 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -738,6 +766,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -760,6 +790,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -801,7 +833,6 @@ xla_test(
backend_tags = {
# TODO(b/31436974): Fix msan failure. Failed on 2016-09-12.
"cpu": ["nomsan"],
- "cpu_parallel": ["nomsan"],
},
shard_count = 30,
deps = [
@@ -813,6 +844,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -836,6 +869,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client:padding",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -898,6 +933,8 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
@@ -923,6 +960,8 @@ xla_test(
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -963,6 +1002,8 @@ xla_test(
"//tensorflow/compiler/xla:array3d",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1038,6 +1079,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1196,6 +1239,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1235,6 +1280,8 @@ xla_test(
"//tensorflow/compiler/xla:reference_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1256,6 +1303,8 @@ xla_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1294,6 +1343,8 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1310,6 +1361,8 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1335,6 +1388,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -1355,6 +1410,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:lib",
@@ -1428,6 +1485,8 @@ xla_test(
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/client/lib:arithmetic",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1472,6 +1531,8 @@ xla_test(
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1514,6 +1575,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:test_utils",
@@ -1532,6 +1595,8 @@ xla_test(
deps = [
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1595,6 +1660,8 @@ xla_test(
":client_library_test_base",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1608,6 +1675,8 @@ xla_test(
":client_library_test_base",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
@@ -1625,11 +1694,11 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:xla_data_proto",
- "//tensorflow/compiler/xla/client:computation",
- "//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
- "//tensorflow/compiler/xla/service:session_proto",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
+ "//tensorflow/compiler/xla/service:hlo_proto",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1713,6 +1782,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
@@ -1740,6 +1811,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/compiler/xla/service:hlo_runner",
"//tensorflow/compiler/xla/service:platform_util",
@@ -1777,6 +1850,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:shaped_buffer",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1802,6 +1877,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/service:device_memory_allocator",
"//tensorflow/compiler/xla/service:local_service",
"//tensorflow/compiler/xla/service:platform_util",
@@ -1860,6 +1937,8 @@ xla_test(
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
@@ -1886,6 +1965,8 @@ xla_test(
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
@@ -1982,6 +2063,8 @@ xla_test(
":test_utils",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla/client:computation_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_builder",
+ "//tensorflow/compiler/xla/client/xla_client:xla_computation",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
"//tensorflow/core:test",
],
diff --git a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
index 4b4dc6dd9d..e8a5efe796 100644
--- a/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/array_elementwise_ops_test.cc
@@ -22,7 +22,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
@@ -214,7 +213,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantZeroElementC64s) {
}
XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<uint64> lhs{0xFFFFFFFF,
static_cast<uint64>(-1),
@@ -255,7 +254,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, AddTwoConstantU64s) {
}
XLA_TEST_F(ArrayElementwiseOpTest, SubTwoConstantS64s) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<int64> lhs{static_cast<int64>(0x8000000000000000LL),
static_cast<int64>(0x8000000000000000LL),
@@ -1332,7 +1331,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowZeroElementF32s) {
// Some Pow cases that can be implemented more efficiently.
XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values = {1.0f, 2.0f, 3.2f, -4.0f};
std::vector<float> exponents = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1360,7 +1359,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowSpecialF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1385,7 +1384,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, PowOfExpF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, 4.0f, 0.5f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1410,7 +1409,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, LogOfPowerF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1435,7 +1434,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, MulOfExpF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.0f, 5.7f};
std::vector<float> values1 = {0.0f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1460,7 +1459,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfExpF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1492,7 +1491,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_lhs_F32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -1525,7 +1524,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Div3_rhs_F32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, 1.0f, 0.5f};
@@ -1558,7 +1557,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, DivOfPowerF32) {
}
XLA_TEST_F(ArrayElementwiseOpTest, Div4F32) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
std::vector<float> values0 = {1.0f, 2.0f, 3.2f, -4.0f, 0.45f, 5.7f};
std::vector<float> values1 = {0.1f, 1.0f, 2.0f, 0.5f, -1.0f, -0.5f};
@@ -2357,7 +2356,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, Add1DTo2DF32) {
XLA_TEST_F(ArrayElementwiseOpTest, Compare1DTo2DS32Eq) {
// Test broadcasting in Eq comparison.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto v = builder.ConstantR1<int32>({42, 73});
auto m = builder.ConstantR2<int32>({{42, 73}, {42, 52}});
@@ -2783,7 +2782,7 @@ XLA_TEST_F(ArrayElementwiseOpTest, NonIdentityBroadcastOfSameRankIsDisallowed) {
// Regression test for b/31927799. "slice - y" is fused and requires implicit
// broadcast.
XLA_TEST_F(ArrayElementwiseOpTest, ImplictBroadcastInFusedExpressions) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x_literal = Literal::CreateR1<float>({1, 2, 3});
auto y_literal = Literal::CreateR1<float>({4, 5});
auto x_data = client_->TransferToServer(*x_literal).ConsumeValueOrDie();
diff --git a/tensorflow/compiler/xla/tests/axpy_simple_test.cc b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
index ec3b46acfe..fcd9ff55e3 100644
--- a/tensorflow/compiler/xla/tests/axpy_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/axpy_simple_test.cc
@@ -15,7 +15,6 @@ limitations under the License.
#include <vector>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -42,7 +41,7 @@ TEST_F(AxpySimpleTest, AxTenValues) {
}
XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
- ComputationBuilder builder(client_, "axpy_10");
+ XlaBuilder builder("axpy_10");
auto alpha = builder.ConstantR0<float>(3.1415926535);
auto x = builder.ConstantR1<float>({});
auto y = builder.ConstantR1<float>({});
@@ -54,7 +53,7 @@ XLA_TEST_F(AxpySimpleTest, AxpyZeroValues) {
}
TEST_F(AxpySimpleTest, AxpyTenValues) {
- ComputationBuilder builder(client_, "axpy_10");
+ XlaBuilder builder("axpy_10");
auto alpha = builder.ConstantR0<float>(3.1415926535);
auto x = builder.ConstantR1<float>(
{-1.0, 1.0, 2.0, -2.0, -3.0, 3.0, 4.0, -4.0, -5.0, 5.0});
diff --git a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
index e4bf1827ac..22c3394e6f 100644
--- a/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
+++ b/tensorflow/compiler/xla/tests/bad_rng_shape_validation_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -34,13 +34,13 @@ namespace {
class BadRngShapeValidationTest : public ClientLibraryTestBase {};
TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR0<float>(0.0);
auto one = builder.ConstantR0<float>(1.0);
Shape default_constructed;
builder.RngUniform(zero, one, default_constructed);
- StatusOr<Computation> computation = builder.Build();
+ StatusOr<XlaComputation> computation = builder.Build();
EXPECT_FALSE(computation.ok());
LOG(INFO) << "status received: " << computation.status();
EXPECT_THAT(computation.status().error_message(),
@@ -48,7 +48,7 @@ TEST_F(BadRngShapeValidationTest, DefaultConstructedShapeCreatesError) {
}
TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR0<float>(0.0);
auto one = builder.ConstantR0<float>(1.0);
Shape sans_layout;
@@ -57,7 +57,7 @@ TEST_F(BadRngShapeValidationTest, ShapeWithoutLayoutIsOk) {
builder.RngUniform(zero, one, sans_layout);
- StatusOr<Computation> computation = builder.Build();
+ StatusOr<XlaComputation> computation = builder.Build();
ASSERT_TRUE(computation.ok());
LOG(INFO) << computation.status();
}
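
The client-API migrations in these test files all follow the same mechanical pattern; collected here from hunks elsewhere in this diff for reference only (not an additional change, and not compiled as shown):

// Before: the builder was bound to a client and Build() yielded a Computation.
//   ComputationBuilder builder(client_, TestName());
//   auto x = builder.ConstantR0<float>(2.0f);
//   auto y = builder.ConstantR0<float>(1.0f);
//   builder.Add(x, y);
//   StatusOr<Computation> computation = builder.Build();
//
// After: XlaBuilder takes only a name, and Build() yields an XlaComputation.
//   XlaBuilder builder(TestName());
//   auto x = builder.ConstantR0<float>(2.0f);
//   auto y = builder.ConstantR0<float>(1.0f);
//   builder.Add(x, y);
//   StatusOr<XlaComputation> computation = builder.Build();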
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index b853dfaa15..4e65cf11f3 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -19,10 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -52,7 +51,7 @@ class Bfloat16Test : public ClientLibraryTestBase {
};
XLA_TEST_F(Bfloat16Test, ScalarOperation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.0f));
auto y = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(1.0f));
builder.Add(x, y);
@@ -62,7 +61,7 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) {
}
XLA_TEST_F(Bfloat16Test, LogOperation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f));
builder.Log(x);
@@ -71,7 +70,7 @@ XLA_TEST_F(Bfloat16Test, LogOperation) {
}
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(-2.1f), {},
@@ -80,7 +79,7 @@ XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
const int kFeatureIndex = 2;
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
{{{{static_cast<bfloat16>(1.f)}, {static_cast<bfloat16>(2.f)}},
@@ -117,7 +116,7 @@ XLA_TEST_F(Bfloat16Test, BatchNormTraining) {
XLA_TEST_F(Bfloat16Test, BatchNormGrad) {
const int kFeatureIndex = 2;
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto operand = builder.ConstantR4FromArray4D<bfloat16>(
Array4D<bfloat16>(2, 2, 2, 1, static_cast<bfloat16>(0.0f)));
diff --git a/tensorflow/compiler/xla/tests/binop_scaling_test.cc b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
index 97fec89b63..48203b1d40 100644
--- a/tensorflow/compiler/xla/tests/binop_scaling_test.cc
+++ b/tensorflow/compiler/xla/tests/binop_scaling_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
@@ -32,7 +32,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_32x4) {
auto alhs = MakeLinspaceArray2D(0.0, 1.0, 32, 4);
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 4);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
builder.Add(lhs, rhs);
@@ -48,7 +48,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixRowVector_129x129) {
auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 129);
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 1, 129);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
builder.Add(lhs, rhs);
@@ -64,7 +64,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_9x5) {
auto alhs = MakeLinspaceArray2D(0.0, 1.0, 9, 5);
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 9, 1);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
builder.Add(lhs, rhs);
@@ -80,7 +80,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
auto alhs = MakeLinspaceArray2D(0.0, 1.0, 129, 257);
auto arhs = MakeLinspaceArray2D(0.0, 1.0, 129, 1);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR2FromArray2D<float>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<float>(*arhs);
builder.Add(lhs, rhs);
@@ -93,7 +93,7 @@ TEST_F(BinopScalingTest, MatrixPlusPseudoMatrixColVector_129x257) {
}
TEST_F(BinopScalingTest, R0PlusR2F32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs = builder.ConstantR0<float>(42.0);
auto rhs = builder.ConstantR2<float>({
{1.0, 2.0}, {3.0, 4.0},
@@ -109,7 +109,7 @@ TEST_F(BinopScalingTest, R0PlusR2F32) {
}
TEST_F(BinopScalingTest, R4PlusR0S32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// clang-format off
Array4D<int> lhs_array({
{{{1, 2},
diff --git a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
index 97095f1cc4..34c86e007b 100644
--- a/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/broadcast_simple_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -33,10 +33,8 @@ namespace {
class BroadcastSimpleTest : public ClientLibraryTestBase {
public:
- ComputationDataHandle BuildBinOp(HloOpcode op,
- const ComputationDataHandle& lhs,
- const ComputationDataHandle& rhs,
- ComputationBuilder* builder) {
+ XlaOp BuildBinOp(HloOpcode op, const XlaOp& lhs, const XlaOp& rhs,
+ XlaBuilder* builder) {
switch (op) {
case HloOpcode::kMinimum: {
return builder->Min(lhs, rhs);
@@ -105,21 +103,21 @@ class BroadcastSimpleTest : public ClientLibraryTestBase {
using ::testing::HasSubstr;
XLA_TEST_F(BroadcastSimpleTest, ScalarNoOpBroadcast) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR0<float>(1.5), {});
ComputeAndCompareR0<float>(&b, 1.5, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x3) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR0<float>(2.25), {2, 3});
Array2D<float> expected(2, 3, 2.25);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
- ComputationBuilder b(client_, TestName());
- ComputationDataHandle src;
+ XlaBuilder b(TestName());
+ XlaOp src;
std::unique_ptr<GlobalData> param_data =
CreateR0Parameter<float>(2.25f, /*parameter_number=*/0, /*name=*/"src",
/*builder=*/&b, /*data_handle=*/&src);
@@ -131,21 +129,21 @@ XLA_TEST_F(BroadcastSimpleTest, ScalarParamTo2D_2x3) {
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_2x0) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR0<float>(2.25), {2, 0});
Array2D<float> expected(2, 0);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, ScalarTo2D_0x2) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR0<float>(2.25), {0, 2});
Array2D<float> expected(0, 2);
ComputeAndCompareR2<float>(&b, expected, {}, ErrorSpec(0.0001));
}
XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {2});
Array2D<float> expected(2, 3);
@@ -160,7 +158,7 @@ XLA_TEST_F(BroadcastSimpleTest, 1DTo2D) {
// Tests implicit broadcasting of PREDs.
XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
Array2D<bool> x_vals(2, 1);
x_vals(0, 0) = true;
@@ -171,7 +169,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
y_vals(1, 0, 0) = true;
y_vals(1, 1, 0) = true;
- ComputationDataHandle x, y;
+ XlaOp x, y;
auto x_data = CreateR2Parameter<bool>(x_vals, 0, "x", &b, &x);
auto y_data = CreateR3Parameter<bool>(y_vals, 1, "y", &b, &y);
b.And(x, y, /*broadcast_dimensions=*/{1, 2});
@@ -186,7 +184,7 @@ XLA_TEST_F(BroadcastSimpleTest, BooleanAnd2DTo3D_Pred) {
}
XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR1<float>({}), {2});
Array2D<float> expected(2, 0);
@@ -194,7 +192,7 @@ XLA_TEST_F(BroadcastSimpleTest, ZeroElement_1DTo2D) {
}
XLA_TEST_F(BroadcastSimpleTest, 1DToZeroElement2D) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Broadcast(b.ConstantR1<float>({1, 2, 3}), {0});
Array2D<float> expected(0, 3);
@@ -209,7 +207,7 @@ XLA_TEST_F(BroadcastSimpleTest, InDimensionAndDegenerateBroadcasting) {
// broadcasting (broadcast_dimensions {1, 2}), then is added to the rhs shape
// [2, 3, 1]. Degenerate dimension broadcasting then broadcasts the size one
// dimensions.
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Add(b.ConstantR2<float>({{1.0, 5.0}}),
b.ConstantLiteral(*Literal::CreateR3<float>(
@@ -247,7 +245,7 @@ class BroadcastR3ImplicitTest
XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
const R3ImplicitBroadcastSpec& spec = GetParam();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape r3_shape, r3_implicit_shape;
Array3D<float> r3_array(spec.output_bounds[0], spec.output_bounds[1],
@@ -264,8 +262,7 @@ XLA_TEST_P(BroadcastR3ImplicitTest, Doit) {
auto r3_implicit_parameter = builder.Parameter(0, r3_implicit_shape, "input");
auto r3_parameter = builder.Parameter(1, r3_shape, "input");
- ComputationDataHandle op =
- BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
+ XlaOp op = BuildBinOp(spec.op, r3_implicit_parameter, r3_parameter, &builder);
Array3D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1],
spec.output_bounds[2]);
@@ -300,9 +297,9 @@ INSTANTIATE_TEST_CASE_P(BroadcastR3ImplicitTestInstances,
// r1 and r3's dim0 matches, and r1's dim1 and dim2 have size 1:
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
- ComputationBuilder b(client_, TestName());
- ComputationDataHandle r1h;
- ComputationDataHandle r3h;
+ XlaBuilder b(TestName());
+ XlaOp r1h;
+ XlaOp r3h;
Array3D<float> r1d = {{{1}}, {{2}}};
Array3D<float> r3d = {{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}};
@@ -319,7 +316,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1_2) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}}));
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -332,7 +329,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}}));
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -345,7 +342,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_2) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}, {3, 4}}}));
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -358,7 +355,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1, 2}}, {{3, 4}}}));
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -371,7 +368,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_1) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 =
b.ConstantLiteral(*Literal::CreateR3<float>({{{1}, {2}}, {{3}, {4}}}));
auto r3 = b.ConstantLiteral(
@@ -385,7 +382,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_2) {
}
XLA_TEST_F(BroadcastSimpleTest, Add3DTo3DDegenerate_0_1_2) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR3<float>({{{1}}}));
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -491,7 +488,7 @@ class BroadcastR2ImplicitTest
XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
const R2ImplicitBroadcastSpec& spec = GetParam();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Operands with degenerate dimensions require implicit broadcasting:
Shape r2_shape, r2_implicit_shape1, r2_implicit_shape2;
@@ -517,10 +514,9 @@ XLA_TEST_P(BroadcastR2ImplicitTest, Doit) {
auto r2_implicit_parameter2 =
builder.Parameter(2, r2_implicit_shape2, "input2");
- ComputationDataHandle op1 =
+ XlaOp op1 =
BuildBinOp(spec.op1, r2_implicit_parameter1, r2_parameter, &builder);
- ComputationDataHandle op2 =
- BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
+ XlaOp op2 = BuildBinOp(spec.op2, op1, r2_implicit_parameter2, &builder);
Array2D<float> expected_array(spec.output_bounds[0], spec.output_bounds[1]);
@@ -547,7 +543,7 @@ INSTANTIATE_TEST_CASE_P(BroadcastR2ImplicitTestInstances,
::testing::ValuesIn(kR2ImplicitBroadcastTestCases));
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}}));
auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
b.Add(r2, r1);
@@ -558,7 +554,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_0) {
}
XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantLiteral(*Literal::CreateR2<float>({{1}, {2}}));
auto r2 = b.ConstantLiteral(*Literal::CreateR2<float>({{1, 2}, {3, 4}}));
b.Add(r2, r1);
@@ -569,7 +565,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add2DTo2DDegenerate_1) {
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantR1<float>({10, 20});
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -582,7 +578,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim0) {
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantR1<float>({10, 20});
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -595,7 +591,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim1) {
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1 = b.ConstantR1<float>({10, 20});
auto r3 = b.ConstantLiteral(
*Literal::CreateR3<float>({{{1, 2}, {3, 4}}, {{5, 6}, {7, 8}}}));
@@ -608,7 +604,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDim2) {
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1_0 = b.ConstantR1<float>({1000, 2000});
auto r1_1 = b.ConstantR1<float>({100, 200});
auto r1_2 = b.ConstantR1<float>({10, 20});
@@ -629,7 +625,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAll) {
}
XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
auto r1_0 = b.ConstantR1<float>({1000, 2000});
auto r1_1 = b.ConstantR1<float>({100, 200});
auto r1_2 = b.ConstantR1<float>({10, 20});
@@ -652,7 +648,7 @@ XLA_TEST_F(BroadcastSimpleTest, Add1DTo3DInDimAllWithScalarBroadcast) {
XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
// Binary dimension broadcasting of the smaller lhs ([2, 2] up to [2, 2, 2])
// results in a shape incompatible with the lhs [2, 3, 1].
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Add(b.ConstantR2<float>({{1.0, 5.0}, {1.0, 5.0}}),
b.ConstantLiteral(*Literal::CreateR3<float>(
@@ -667,7 +663,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidBinaryAndDegenerateBroadcasting) {
XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
// Test invalid broadcasting with [1, 2] and [2, 3] inputs.
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
@@ -680,7 +676,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidInDimensionBroadcasting) {
XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
// Test invalid broadcasting with [1, 2] and [2, 3] inputs.
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
b.Add(b.ConstantR2<float>({{1.0, 2.0}}),
b.ConstantR2<float>({{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}}));
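The change above is mechanical across this whole test file: XlaBuilder is constructed from a test name only (the client_ argument is gone), ops are recorded on the builder, and Build() now yields an XlaComputation instead of a Computation. A minimal free-standing sketch of that pattern, using the same in-dimension broadcast Add overload these tests exercise (the function name is illustrative only):

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"

xla::XlaComputation BuildAdd1DTo2DInDim1() {
  xla::XlaBuilder b("add_broadcast_sketch");         // name only; no Client* argument
  auto r1 = b.ConstantR1<float>({10, 20});           // f32[2]
  auto r2 = b.ConstantR2<float>({{1, 2}, {3, 4}});   // f32[2,2]
  b.Add(r2, r1, /*broadcast_dimensions=*/{1});       // broadcast r1 along dimension 1
  return b.Build().ConsumeValueOrDie();              // XlaComputation, not Computation
}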
diff --git a/tensorflow/compiler/xla/tests/build_defs.bzl b/tensorflow/compiler/xla/tests/build_defs.bzl
index eac2eb286c..53f2c3bfbf 100644
--- a/tensorflow/compiler/xla/tests/build_defs.bzl
+++ b/tensorflow/compiler/xla/tests/build_defs.bzl
@@ -4,7 +4,7 @@ load("@local_config_cuda//cuda:build_defs.bzl", "cuda_is_configured")
load("//tensorflow/compiler/xla/tests:plugin.bzl", "plugins")
load("//tensorflow:tensorflow.bzl", "tf_cc_test")
-all_backends = ["cpu", "cpu_parallel", "gpu"] + plugins.keys()
+all_backends = ["cpu", "gpu"] + plugins.keys()
def filter_backends(backends):
"""Removes "gpu" from a backend list if CUDA is not enabled.
@@ -39,10 +39,10 @@ def xla_test(name,
**kwargs):
"""Generates cc_test targets for the given XLA backends.
- This rule generates a cc_test target for one or more XLA backends and also
- a platform-agnostic cc_library rule. The arguments are identical to cc_test
- with two additions: 'backends' and 'backend_args'. 'backends' specifies the
- backends to generate tests for ("cpu", "cpu_parallel", "gpu"), and
+ This rule generates a cc_test target for one or more XLA backends and also a
+ platform-agnostic cc_library rule. The arguments are identical to cc_test with
+ two additions: 'backends' and 'backend_args'. 'backends' specifies the
+ backends to generate tests for ("cpu", "gpu"), and
'backend_args'/'backend_tags' specifies backend-specific args parameters to
use when generating the cc_test.
@@ -90,9 +90,9 @@ def xla_test(name,
deps: Dependencies of the target.
xla_test_library_deps: If set, the generated test targets will depend on the
respective cc_libraries generated by the xla_test_library rule.
- backends: A list of backends to generate tests for. Supported
- values: "cpu", "cpu_parallel", "gpu". If this list is empty, the test will
- be generated for all supported backends.
+ backends: A list of backends to generate tests for. Supported values: "cpu",
+ "gpu". If this list is empty, the test will be generated for all supported
+ backends.
blacklisted_backends: A list of backends to NOT generate tests for.
args: Test arguments for the target.
tags: Tags for the target.
@@ -128,10 +128,6 @@ def xla_test(name,
if backend == "cpu":
backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- elif backend == "cpu_parallel":
- backend_deps = ["//tensorflow/compiler/xla/service:cpu_plugin"]
- backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_cpu"]
- this_backend_args += ["--xla_backend_extra_options=\"xla_cpu_parallel\""]
elif backend == "gpu":
backend_deps = ["//tensorflow/compiler/xla/service:gpu_plugin"]
backend_deps += ["//tensorflow/compiler/xla/tests:test_macros_gpu"]
@@ -201,7 +197,7 @@ def xla_test_library(name,
hdrs: Headers for the target.
deps: Dependencies of the target.
backends: A list of backends to generate libraries for.
- Supported values: "cpu", "cpu_parallel", "gpu". If this list is empty, the
+ Supported values: "cpu", "gpu". If this list is empty, the
library will be generated for all supported backends.
"""
@@ -210,7 +206,7 @@ def xla_test_library(name,
for backend in filter_backends(backends):
this_backend_copts = []
- if backend in ["cpu", "cpu_parallel", "gpu"]:
+ if backend in ["cpu", "gpu"]:
backend_deps = ["//tensorflow/compiler/xla/tests:test_macros_%s" % backend]
elif backend in plugins:
backend_deps = plugins[backend]["deps"]
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.cc b/tensorflow/compiler/xla/tests/client_library_test_base.cc
index 69389dae3f..22660c35dc 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.cc
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.cc
@@ -61,6 +61,11 @@ ClientLibraryTestBase::ClientLibraryTestBase(
: client_(GetOrCreateLocalClientOrDie(client_options)),
execution_options_(CreateDefaultExecutionOptions()) {
CHECK_EQ(platform, client_options.platform());
+
+ LocalClientOptions ref_options;
+ ref_options.set_platform(GetReferencePlatform());
+ ref_client_ = GetOrCreateLocalClientOrDie(ref_options);
+
// Disabling constant_folding so that tests (usually written using Constants)
// will exercise the intended code paths, instead of being constant folded.
//
@@ -152,6 +157,7 @@ ClientLibraryTestBase::ExecuteAndTransferReference(
*execution_options.mutable_shape_with_output_layout() =
*shape_with_output_layout;
}
+ execution_options.clear_device_handles();
return ref_client_->ExecuteAndTransfer(computation, arguments,
&execution_options);
}
@@ -211,6 +217,14 @@ void ClientLibraryTestBase::ComputeAndCompareR1(
arguments);
}
+void ClientLibraryTestBase::ComputeAndCompareR1(
+ XlaBuilder* builder, const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
+ std::unique_ptr<Literal> expected_literal = Literal::CreateR1(expected);
+ ClientLibraryTestBase::ComputeAndCompareLiteral(builder, *expected_literal,
+ arguments);
+}
+
template <typename BuilderT>
void ClientLibraryTestBase::ComputeAndCompareLiteral(
BuilderT* builder, const Literal& expected,
@@ -452,7 +466,7 @@ tensorflow::Status ClientLibraryTestBase::ComputeAndCompareLiteralWithStatus(
}
void ClientLibraryTestBase::ComputeAndCompareR1U8(
- ComputationBuilder* builder, tensorflow::StringPiece expected,
+ XlaBuilder* builder, tensorflow::StringPiece expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
auto actual_status = ExecuteAndTransfer(builder, arguments);
EXPECT_IS_OK(actual_status.status());
@@ -613,8 +627,8 @@ ClientLibraryTestBase::ComputeValueAndReference(
return std::make_pair(std::move(reference), std::move(result));
}
-Computation ClientLibraryTestBase::CreateScalarRelu() {
- ComputationBuilder builder(client_, "relu");
+XlaComputation ClientLibraryTestBase::CreateScalarRelu() {
+ XlaBuilder builder("relu");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
auto z_value = builder.Parameter(0, shape, "z_value");
auto zero = use_bfloat16_
@@ -626,8 +640,8 @@ Computation ClientLibraryTestBase::CreateScalarRelu() {
return computation_status.ConsumeValueOrDie();
}
-Computation ClientLibraryTestBase::CreateScalarMax() {
- ComputationBuilder builder(client_, "max");
+XlaComputation ClientLibraryTestBase::CreateScalarMax() {
+ XlaBuilder builder("max");
auto shape = ShapeUtil::MakeShape(use_bfloat16_ ? BF16 : F32, {});
auto x = builder.Parameter(0, shape, "x");
auto y = builder.Parameter(1, shape, "y");
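A hedged usage sketch for the XlaComputation-returning helpers above: CreateScalarMax() can be fed straight into a reduction from a test body. It assumes ComputeAndCompareR0 has an XlaBuilder overload analogous to the R1 one added in this change:

XlaBuilder b(TestName());
auto input = b.ConstantR1<float>({1.0f, 7.0f, 3.0f});
auto init = b.ConstantR0<float>(0.0f);                    // identity for max over positives
b.Reduce(input, init, CreateScalarMax(), /*dimensions_to_reduce=*/{0});
ComputeAndCompareR0<float>(&b, 7.0f, /*arguments=*/{});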
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 481d7c5c25..32eea7c2f3 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -165,6 +165,9 @@ class ClientLibraryTestBase : public ::testing::Test {
void ComputeAndCompareR1(ComputationBuilder* builder,
const tensorflow::core::Bitmap& expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
+ void ComputeAndCompareR1(XlaBuilder* builder,
+ const tensorflow::core::Bitmap& expected,
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments);
template <typename NativeT, typename BuilderT>
void ComputeAndCompareR2(BuilderT* builder, const Array2D<NativeT>& expected,
@@ -219,7 +222,7 @@ class ClientLibraryTestBase : public ::testing::Test {
// Compare the result of the computation to a strings. In XLA strings are
// represented using rank-1 U8 shapes.
void ComputeAndCompareR1U8(
- ComputationBuilder* builder, tensorflow::StringPiece expected,
+ XlaBuilder* builder, tensorflow::StringPiece expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
// Convenience method for running a built computation, transferring the
@@ -252,8 +255,8 @@ class ClientLibraryTestBase : public ::testing::Test {
ErrorSpec error);
// Create scalar operations for use in reductions.
- Computation CreateScalarRelu();
- Computation CreateScalarMax();
+ XlaComputation CreateScalarRelu();
+ XlaComputation CreateScalarMax();
Computation CreateScalarReluSensitivity();
// Special case convenience functions for creating filled arrays.
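A usage sketch for the new XlaBuilder overload of ComputeAndCompareR1 declared above; the size_t constructor and set() calls reflect the tensorflow::core::Bitmap interface as assumed here:

XlaBuilder b(TestName());
b.ConstantR1<bool>({true, false, true});     // PRED[3] result
tensorflow::core::Bitmap expected(3);
expected.set(0);
expected.set(2);
ComputeAndCompareR1(&b, expected, /*arguments=*/{});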
diff --git a/tensorflow/compiler/xla/tests/client_test.cc b/tensorflow/compiler/xla/tests/client_test.cc
index 32e2f2c084..1e54471796 100644
--- a/tensorflow/compiler/xla/tests/client_test.cc
+++ b/tensorflow/compiler/xla/tests/client_test.cc
@@ -109,8 +109,7 @@ XLA_TEST_F(ClientTest, ExecuteWithTupleLayout) {
/*minor_to_major=*/{1, 0})));
}
-XLA_TEST_F(ClientTest,
- DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(ExecuteParallel))) {
+XLA_TEST_F(ClientTest, DISABLED_ON_GPU(ExecuteParallel)) {
XlaComputation add_with_one_arg, mul_with_two_args, dot_with_one_arg;
Shape shape = ShapeUtil::MakeShape(S32, {2, 2});
diff --git a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
index 896b34fb6e..b5a42e3059 100644
--- a/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_dimension_numbers_test.cc
@@ -18,9 +18,9 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/statusor.h"
@@ -34,13 +34,35 @@ limitations under the License.
namespace xla {
namespace {
+StatusOr<ConvolutionDimensionNumbers> CreateConvDimensionNumbers(
+ int64 input_batch, int64 input_feature, int64 input_first_spatial,
+ int64 input_second_spatial, int64 output_batch, int64 output_feature,
+ int64 output_first_spatial, int64 output_second_spatial,
+ int64 kernel_output_feature, int64 kernel_input_feature,
+ int64 kernel_first_spatial, int64 kernel_second_spatial) {
+ ConvolutionDimensionNumbers dimension_numbers;
+ dimension_numbers.set_input_batch_dimension(input_batch);
+ dimension_numbers.set_input_feature_dimension(input_feature);
+ dimension_numbers.add_input_spatial_dimensions(input_first_spatial);
+ dimension_numbers.add_input_spatial_dimensions(input_second_spatial);
+ dimension_numbers.set_kernel_output_feature_dimension(kernel_output_feature);
+ dimension_numbers.set_kernel_input_feature_dimension(kernel_input_feature);
+ dimension_numbers.add_kernel_spatial_dimensions(kernel_first_spatial);
+ dimension_numbers.add_kernel_spatial_dimensions(kernel_second_spatial);
+ dimension_numbers.set_output_batch_dimension(output_batch);
+ dimension_numbers.set_output_feature_dimension(output_feature);
+ dimension_numbers.add_output_spatial_dimensions(output_first_spatial);
+ dimension_numbers.add_output_spatial_dimensions(output_second_spatial);
+ TF_RETURN_IF_ERROR(XlaBuilder::Validate(dimension_numbers));
+ return dimension_numbers;
+}
+
class ConvolutionDimensionNumbersTest : public ClientLibraryTestBase {};
// Tests the convolution operation with invalid input dimension numbers.
TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
auto dimension_numbers_status =
- ComputationBuilder::CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0,
- 1, 2, 3);
+ CreateConvDimensionNumbers(0, 2, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3);
ASSERT_FALSE(dimension_numbers_status.ok());
ASSERT_THAT(dimension_numbers_status.status().error_message(),
::testing::HasSubstr("input are not unique"));
@@ -49,8 +71,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidInputDimensionNumbers) {
// Tests the convolution operation with invalid weight dimension numbers.
TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) {
auto dimension_numbers_status =
- ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0,
- 2, 2, 3);
+ CreateConvDimensionNumbers(0, 1, 2, 3, 0, 1, 2, 3, 0, 2, 2, 3);
ASSERT_FALSE(dimension_numbers_status.ok());
ASSERT_THAT(dimension_numbers_status.status().error_message(),
::testing::HasSubstr("weight are not unique"));
@@ -59,8 +80,7 @@ TEST_F(ConvolutionDimensionNumbersTest, InvalidWeightDimensionNumbers) {
// Tests the convolution operation with invalid output dimension numbers.
TEST_F(ConvolutionDimensionNumbersTest, InvalidOutputDimensionNumbers) {
auto dimension_numbers_status =
- ComputationBuilder::CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0,
- 1, 2, 3);
+ CreateConvDimensionNumbers(0, 1, 2, 3, 0, 2, 2, 3, 0, 1, 2, 3);
ASSERT_FALSE(dimension_numbers_status.ok());
ASSERT_THAT(dimension_numbers_status.status().error_message(),
::testing::HasSubstr("output are not unique"));
@@ -76,14 +96,14 @@ XLA_TEST_F(ConvolutionDimensionNumbersTest,
client_->TransferToServer(*Literal::CreateR4FromArray4D(*weight_array))
.ConsumeValueOrDie();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(*input_array);
auto weight =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {4, 3, 1, 1}), "weight");
auto conv1 = builder.Conv(input, weight, {1, 1}, Padding::kValid);
ConvolutionDimensionNumbers dim_nums =
- ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ XlaBuilder::CreateDefaultConvDimensionNumbers();
// Swap batch_dimension and feature_dimension.
int64 old_input_batch_dim = dim_nums.input_batch_dimension();
int64 old_output_batch_dim = dim_nums.output_batch_dimension();
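For contrast with the invalid cases above, a sketch of calling the file-local CreateConvDimensionNumbers helper with a valid assignment, where each group of dimension numbers is unique so XlaBuilder::Validate accepts it:

auto dnums_or = CreateConvDimensionNumbers(
    /*input_batch=*/0, /*input_feature=*/1,
    /*input_first_spatial=*/2, /*input_second_spatial=*/3,
    /*output_batch=*/0, /*output_feature=*/1,
    /*output_first_spatial=*/2, /*output_second_spatial=*/3,
    /*kernel_output_feature=*/0, /*kernel_input_feature=*/1,
    /*kernel_first_spatial=*/2, /*kernel_second_spatial=*/3);
ASSERT_TRUE(dnums_or.ok());
ConvolutionDimensionNumbers dnums = dnums_or.ConsumeValueOrDie();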
diff --git a/tensorflow/compiler/xla/tests/convolution_variants_test.cc b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
index 9c1145def8..50d6e25d86 100644
--- a/tensorflow/compiler/xla/tests/convolution_variants_test.cc
+++ b/tensorflow/compiler/xla/tests/convolution_variants_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/client/padding.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -52,7 +53,7 @@ class ConvolutionVariantsTest : public ClientLibraryTestBase {
};
XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const Array4D<float> input_array(1, 1, 1, 1, {2});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -67,7 +68,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Minimal) {
}
XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const Array4D<float> input_array(5, 1, 1, 1, {1, 2, 3, 4, 5});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -82,7 +83,7 @@ XLA_TEST_F(ConvolutionVariantsTest, MinimalWithBatch) {
}
XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(2, 1, 3, 4);
input_array.FillWithMultiples(1);
@@ -99,7 +100,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Flat1x1) {
}
XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 2, 1, 1, {10, 1});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -114,7 +115,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Deep1x1) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 2, {1, 2});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -129,7 +130,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -144,7 +145,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in1x3) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -159,7 +160,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -174,7 +175,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1in2x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -189,7 +190,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2in2x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(
2, 2, 2, 3, {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 0, // plane 0
@@ -210,7 +211,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2in2x3WithDepthAndBatch) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -225,7 +226,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x4) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -240,7 +241,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride1x2in1x5) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 4, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -255,7 +256,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x4) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 5, {1, 2, 3, 4, 5});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -270,7 +271,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x3stride1x2in1x5) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -285,7 +286,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1stride2x2in3x3) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 1, {1});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -300,7 +301,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x1in1x1Padded) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -315,7 +316,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter5x1in3x1Padded) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 2, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -333,7 +334,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter3x3in2x2Padded) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 2, 1, 2, {1, 2, 3, 4});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -348,7 +349,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1in2x1WithPaddingAndDepth) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 3, 3, {1, 2, 3, 4, 5, 6, 7, 8, 9});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -363,7 +364,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2Stride1x1Input3x3) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(1, 1, 1, 3, {1, 2, 3});
auto input = builder.ConstantR4FromArray4D<float>(input_array);
@@ -378,7 +379,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2Stride1x1Input1x3) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(64);
std::iota(input_data.begin(), input_data.end(), 0.0);
@@ -398,7 +399,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x1x8x8Input1x1x8x8) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(16 * 1 * 1 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -419,7 +420,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input16x1x1x1) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
constexpr int bs = 16;
constexpr int kx = 2;
@@ -450,7 +451,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input16x1x2x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
constexpr int kx = 2;
constexpr int ky = 2;
@@ -482,7 +483,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x2Input3x1x2x2) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(16, 1, 8, 8);
for (int i0 = 0; i0 < 16; ++i0) {
@@ -510,7 +511,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x8x8Input16x1x8x8) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
@@ -536,7 +537,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input1x2x8x8) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(2 * 2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
@@ -562,7 +563,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input2x2x8x8) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(32 * 2 * 8 * 8);
std::iota(input_data.begin(), input_data.end(), 0.0);
@@ -602,7 +603,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter2x2x8x8Input32x2x8x8) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> input_array(16, 16, 1, 1);
Array4D<float> filter_array(16, 16, 1, 1);
@@ -628,7 +629,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter16x16x1x1Input16x16x1x1) {
}
XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 4 * 6);
std::iota(input_data.begin(), input_data.end(), 0.0);
@@ -640,14 +641,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatRhsDilation) {
builder.ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{2, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 2, 2, {3924, 4257, 5922, 6255});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -659,14 +660,14 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation1D) {
builder.ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{}, /*padding=*/{},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 1, 8, {10, 2, 20, 3, 30, 4, 40, 5});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 3 * 4);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -682,8 +683,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
builder.ConvGeneralDilated(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{2, 1},
/*padding=*/{{1, 0}, {0, 0}}, /*lhs_dilation=*/{3, 2},
- /*rhs_dilation=*/{},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ /*rhs_dilation=*/{}, XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 3, 5,
{204, 40, 406, 60, 608, //
@@ -693,7 +693,7 @@ XLA_TEST_F(ConvolutionVariantsTest, FlatLhsDilation) {
}
XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -705,14 +705,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingOnBothEnds) {
builder.ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-1, -1}},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 1, 2, {23, 34});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -724,14 +724,14 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingLowAndPositivePaddingHigh) {
builder.ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-1, 2}},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 1, 5, {23, 34, 45, 50, 0});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -743,14 +743,14 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingLowAndNegativePaddingHigh) {
builder.ConvGeneral(
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {2, -1}},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
Array4D<float> expected(1, 1, 1, 5, {0, 1, 12, 23, 34});
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -763,7 +763,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {3, 2}},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
// input:
// [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
@@ -775,7 +775,7 @@ XLA_TEST_F(ConvolutionVariantsTest, PositivePaddingAndDilation) {
ComputeAndCompareR4<float>(&builder, expected, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 1 * 1 * 5);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -788,7 +788,7 @@ XLA_TEST_F(ConvolutionVariantsTest, NegativePaddingAndDilation) {
/*lhs=*/input, /*rhs=*/filter, /*window_strides=*/{},
/*padding=*/{{0, 0}, {-3, -2}},
/*lhs_dilation=*/{1, 2}, /*rhs_dilation=*/{1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
// input:
// [1, 2, 3, 4, 5] --dilate-> [1, 0, 2, 0, 3, 0, 4, 0, 5]
@@ -821,7 +821,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x1x2x3_Filter2x1x1x2) {
Array4D<float> input_array(bs, iz, iy, ix, input_data);
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kValid);
@@ -854,7 +854,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input1x16x1x1_Filter1x16x1x1) {
Array4D<float> input_array(bs, iz, iy, ix, input_data);
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kValid);
@@ -887,7 +887,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter1x16x1x1) {
Array4D<float> input_array(bs, iz, iy, ix, input_data);
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kValid);
@@ -920,7 +920,7 @@ XLA_TEST_F(ConvolutionVariantsTest, RandomData_Input16x16x1x1_Filter16x16x1x1) {
Array4D<float> input_array(bs, iz, iy, ix, input_data);
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kValid);
@@ -954,7 +954,7 @@ XLA_TEST_F(ConvolutionVariantsTest,
Array4D<float> input_array(bs, iz, iy, ix, input_data);
Array4D<float> filter_array(oz, iz, ky, kx, kernel_data);
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.ConstantR4FromArray4D<float>(input_array);
auto filter = builder.ConstantR4FromArray4D<float>(filter_array);
builder.Conv(input, filter, {1, 1}, Padding::kValid);
@@ -966,7 +966,7 @@ XLA_TEST_F(ConvolutionVariantsTest,
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -1010,7 +1010,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x2x1x1Input1x2x3x1GeneralPadding) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -1054,7 +1054,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1GeneralPadding) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 2 * 3 * 1);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -1095,7 +1095,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x1x1Input1x2x3x1NoPadding) {
}
XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> input_data(1 * 2 * 3 * 2);
std::iota(input_data.begin(), input_data.end(), 1.0);
@@ -1147,7 +1147,7 @@ XLA_TEST_F(ConvolutionVariantsTest, Filter1x1x2x3Input1x2x3x2NoPadding) {
// BackwardInputConv([1,2,3], [5,6], padding_low=0, padding_high=1)
XLA_TEST_F(ConvolutionVariantsTest,
BackwardInputLowPaddingLessThanHighPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
@@ -1166,19 +1166,18 @@ XLA_TEST_F(ConvolutionVariantsTest,
// BackwardInputConv([1], [1,10,100], stride=3, padding=(2,1))
XLA_TEST_F(ConvolutionVariantsTest,
BackwardInputLowPaddingGreaterThanHighPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
auto weights = builder.ConstantR4FromArray4D<float>(
Array4D<float>(1, 1, 1, 3, /*values=*/{1, 10, 100}));
auto mirrored_weights = builder.Rev(weights, {2, 3});
- builder.ConvGeneralDilated(
- gradients, mirrored_weights,
- /*window_strides=*/{1, 1},
- /*padding=*/{{0, 0}, {0, 3}},
- /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ builder.ConvGeneralDilated(gradients, mirrored_weights,
+ /*window_strides=*/{1, 1},
+ /*padding=*/{{0, 0}, {0, 3}},
+ /*lhs_dilation=*/{1, 3}, /*rhs_dilation=*/{},
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
ComputeAndCompareR4<float>(&builder, {{{{100, 0}}}}, {}, error_spec_);
}
@@ -1187,7 +1186,7 @@ XLA_TEST_F(ConvolutionVariantsTest,
// into
// BackwardInputConv([1], [1,10,100], padding=(1,1))
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
Array4D<float>(1, 1, 1, 1, /*values=*/{1}));
@@ -1208,7 +1207,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding) {
// However, XLA:GPU doesn't actually fuse it because PadInsertion doesn't
// support negative padding on backward convolution yet (b/32744257).
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients = builder.ConstantR4FromArray4D<float>(
Array4D<float>(1, 1, 1, 3, /*values=*/{1, 2, 3}));
@@ -1224,7 +1223,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputWithNegativePaddingHigh) {
XLA_TEST_F(ConvolutionVariantsTest,
BackwardFilterLowPaddingLessThanHighPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// activations: 1,2,3,4 ---pad--> 0,1,2,3,4,0,0
// gradients: 100,10,1 -dilate-> 100,0,10,0,1
@@ -1240,7 +1239,7 @@ XLA_TEST_F(ConvolutionVariantsTest,
/*window_strides=*/{1, 1},
/*padding=*/{{0, 0}, {1, 2}},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
builder.Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{24, 130, 240}}}}, {}, error_spec_);
@@ -1248,7 +1247,7 @@ XLA_TEST_F(ConvolutionVariantsTest,
XLA_TEST_F(ConvolutionVariantsTest,
BackwardFilterLowPaddingGreaterThanHighPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4
// gradients: 100,10,1 -dilate-> 100,0,10,0,1
@@ -1266,14 +1265,14 @@ XLA_TEST_F(ConvolutionVariantsTest,
/*window_strides=*/{1, 1},
/*padding=*/{{0, 0}, {2, 0}},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
builder.Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{13, 24}}}}, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// activations: 1,2,3,4 ---pad--> 0,0,1,2,3,4,0
// gradients: 100,10,1 -dilate-> 100,0,10,0,1
@@ -1293,14 +1292,14 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding) {
/*window_strides=*/{1, 1},
/*padding=*/{{0, 0}, {2, 1}},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers());
+ XlaBuilder::CreateDefaultConvDimensionNumbers());
builder.Transpose(forward_conv, {0, 1, 2, 3});
ComputeAndCompareR4<float>(&builder, {{{{13, 24, 130}}}}, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients = builder.ConstantR3FromArray3D<float>(
Array3D<float>(1, 1, 1, /*value=*/1));
@@ -1314,26 +1313,26 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding1D) {
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding1D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto activations =
builder.ConstantR3FromArray3D<float>(Array3D<float>({{{1, 2, 3, 4}}}));
auto gradients =
builder.ConstantR3FromArray3D<float>(Array3D<float>({{{100, 10, 1}}}));
- auto forward_conv = builder.ConvGeneralDilated(
- activations, gradients,
- /*window_strides=*/{1},
- /*padding=*/{{2, 1}},
- /*lhs_dilation=*/{}, /*rhs_dilation=*/{2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers(
- /*num_spatial_dims=*/1));
+ auto forward_conv =
+ builder.ConvGeneralDilated(activations, gradients,
+ /*window_strides=*/{1},
+ /*padding=*/{{2, 1}},
+ /*lhs_dilation=*/{}, /*rhs_dilation=*/{2},
+ XlaBuilder::CreateDefaultConvDimensionNumbers(
+ /*num_spatial_dims=*/1));
builder.Transpose(forward_conv, {0, 1, 2});
ComputeAndCompareR3<float>(&builder, {{{13, 24, 130}}}, {}, error_spec_);
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto gradients_flat = Literal::CreateR1<float>({1});
auto gradients_literal =
@@ -1357,7 +1356,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardInputEvenPadding3D) {
}
XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto activations_flat = Literal::CreateR1<float>({1, 2, 3, 4});
auto activations_literal =
@@ -1378,7 +1377,7 @@ XLA_TEST_F(ConvolutionVariantsTest, BackwardFilterEvenPadding3D) {
/*window_strides=*/{1, 1, 1},
/*padding=*/{{0, 0}, {0, 0}, {2, 1}},
/*lhs_dilation=*/{}, /*rhs_dilation=*/{1, 1, 2},
- ComputationBuilder::CreateDefaultConvDimensionNumbers(
+ XlaBuilder::CreateDefaultConvDimensionNumbers(
/*num_spatial_dims=*/3));
builder.Transpose(forward_conv, {0, 1, 2, 3, 4});
ComputeAndCompareLiteral(&builder, *expected_literal, {}, error_spec_);
diff --git a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
index 021fbcedb9..ff53a84588 100644
--- a/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
+++ b/tensorflow/compiler/xla/tests/dynamic_ops_test.cc
@@ -470,13 +470,6 @@ class DynamicUpdateSliceTest : public ClientLibraryTestBase {
template <class T>
void RunR3Contiguous(std::vector<int32> operand_shape, int32 index,
int32 size) {
-#ifdef XLA_TEST_BACKEND_CPU_PARALLEL
- // TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
- if (std::is_same<bfloat16, T>::value) {
- return;
- }
-#endif
-
const int32 kSeq = operand_shape[0];
const int32 kBatch = operand_shape[1];
const int32 kDim = operand_shape[2];
@@ -539,30 +532,22 @@ XLA_TEST_F(DynamicUpdateSliceTest, Int64R0) { TestR0<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R0) { TestR0<uint64, float>(); }
// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
-XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R1BF16)) {
- TestR1<int32, bfloat16>();
-}
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R1BF16) { TestR1<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R1) { TestR1<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R1) { TestR1<int64, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R1) { TestR1<uint64, float>(); }
-// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
-XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R2BF16)) {
- TestR2<int32, bfloat16>();
-}
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R2BF16) { TestR2<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R2) { TestR2<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R2) { TestR2<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R2) { TestR2<uint64, int32>(); }
-// TODO(b/71820067): The CPU parallel backend failed for this on 2018-01-10.
-XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32R3BF16)) {
- TestR3<int32, bfloat16>();
-}
+XLA_TEST_F(DynamicUpdateSliceTest, Int32R3BF16) { TestR3<int32, bfloat16>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int32R3) { TestR3<int32, float>(); }
XLA_TEST_F(DynamicUpdateSliceTest, Int64R3) { TestR3<int64, int64>(); }
XLA_TEST_F(DynamicUpdateSliceTest, UInt64R3) { TestR3<uint64, uint64>(); }
-XLA_TEST_F(DynamicUpdateSliceTest, DISABLED_ON_CPU_PARALLEL(Int32WrapBF16)) {
+XLA_TEST_F(DynamicUpdateSliceTest, Int32WrapBF16) {
TestWrap<int32, bfloat16>();
}
XLA_TEST_F(DynamicUpdateSliceTest, Int32Wrap) { TestWrap<int32, float>(); }
diff --git a/tensorflow/compiler/xla/tests/execution_profile_test.cc b/tensorflow/compiler/xla/tests/execution_profile_test.cc
index 644cbbf40f..c8cc8e40aa 100644
--- a/tensorflow/compiler/xla/tests/execution_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/execution_profile_test.cc
@@ -24,8 +24,7 @@ namespace {
class ExecutionProfileTest : public ClientLibraryTestBase {};
-XLA_TEST_F(ExecutionProfileTest,
- DISABLED_ON_CPU_PARALLEL(ExecuteWithExecutionProfile)) {
+XLA_TEST_F(ExecutionProfileTest, ExecuteWithExecutionProfile) {
Shape shape = ShapeUtil::MakeShape(F32, {256, 256});
TF_ASSERT_OK_AND_ASSIGN(
diff --git a/tensorflow/compiler/xla/tests/fusion_test.cc b/tensorflow/compiler/xla/tests/fusion_test.cc
index c7f64d8560..6f89e9164c 100644
--- a/tensorflow/compiler/xla/tests/fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/fusion_test.cc
@@ -794,19 +794,19 @@ void BM_ParallelFusion(int num_iters) {
// Transfer literals to device.
auto param0_literal =
Literal::CreateR2F32Linspace(1.0, 2.0, param0_dim0, param0_dim1);
- ShapedBuffer buffer0 =
+ ScopedShapedBuffer buffer0 =
client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
.ConsumeValueOrDie();
auto param1_literal =
Literal::CreateR2F32Linspace(1.0, 2.0, param1_dim0, param1_dim1);
- ShapedBuffer buffer1 =
+ ScopedShapedBuffer buffer1 =
client->LiteralToShapedBuffer(*param1_literal, device_ordinal)
.ConsumeValueOrDie();
auto param2_literal =
Literal::CreateR2F32Linspace(1.0, 2.0, param2_dim0, param2_dim1);
- ShapedBuffer buffer2 =
+ ScopedShapedBuffer buffer2 =
client->LiteralToShapedBuffer(*param2_literal, device_ordinal)
.ConsumeValueOrDie();
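The ShapedBuffer to ScopedShapedBuffer switch above is an ownership change: a ScopedShapedBuffer releases its device allocations when it is destroyed, so the benchmark no longer has to free the transferred parameters by hand. A hedged sketch of the effect:

{
  xla::ScopedShapedBuffer buffer =
      client->LiteralToShapedBuffer(*param0_literal, device_ordinal)
          .ConsumeValueOrDie();
  // ... pass &buffer as an argument when running the executable ...
}  // device memory owned by `buffer` is deallocated here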
diff --git a/tensorflow/compiler/xla/tests/gather_operation_test.cc b/tensorflow/compiler/xla/tests/gather_operation_test.cc
index 90496d55e6..4dd3acd9af 100644
--- a/tensorflow/compiler/xla/tests/gather_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/gather_operation_test.cc
@@ -401,10 +401,7 @@ ENTRY main {
class GatherClientLibraryTest : public ClientLibraryTestBase {};
-// TODO(b/30671675): Asynchronous execution on stream is not yet supported on
-// GPU and CPU_PARALLEL.
-XLA_TEST_F(GatherClientLibraryTest,
- DISABLED_ON_CPU_PARALLEL(DISABLED_ON_GPU(Basic))) {
+XLA_TEST_F(GatherClientLibraryTest, DISABLED_ON_GPU(Basic)) {
// We create this HLO, but using the XlaBuilder API.
//
// ENTRY main {
diff --git a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
index 7209f91639..f21f83992f 100644
--- a/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_allocation_test.cc
@@ -15,9 +15,8 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/shaped_buffer.h"
@@ -38,7 +37,7 @@ class LocalClientAllocationTest : public LocalClientTestBase {
};
XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
@@ -74,7 +73,7 @@ XLA_TEST_F(LocalClientAllocationTest, AddVectors) {
XLA_TEST_F(LocalClientAllocationTest, RunOnDevices) {
// Run a computation on every device on the system. Verify that allocation
// occurs on the proper device.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
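A sketch of how these local-client tests drive a computation end to end after the migration, assuming the ExecuteLocallyOrDie helper from LocalClientTestBase used elsewhere in this suite:

XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({0.0f, 1.0f, 2.0f});
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
ScopedShapedBuffer result =
    ExecuteLocallyOrDie(builder.Build().ValueOrDie(), /*arguments=*/{});
// Transfer the result back and expect {2, 4, 6}.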
diff --git a/tensorflow/compiler/xla/tests/local_client_execute_test.cc b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
index 7e14e77366..44c6811df8 100644
--- a/tensorflow/compiler/xla/tests/local_client_execute_test.cc
+++ b/tensorflow/compiler/xla/tests/local_client_execute_test.cc
@@ -18,9 +18,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
@@ -54,7 +53,7 @@ class LocalClientExecuteTest : public LocalClientTestBase {
};
XLA_TEST_F(LocalClientExecuteTest, Constant) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto y = builder.ConstantR0<float>(123.0f);
ScopedShapedBuffer result =
@@ -64,7 +63,7 @@ XLA_TEST_F(LocalClientExecuteTest, Constant) {
}
XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x");
auto y = builder.ConstantR0<float>(123.0f);
builder.Add(x, y);
@@ -77,7 +76,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddScalars) {
}
XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {0}), "x");
auto y = builder.ConstantR1<float>({});
builder.Add(x, y);
@@ -90,7 +89,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddZeroElementVectors) {
}
XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
@@ -104,7 +103,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectors) {
}
XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
@@ -122,7 +121,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddVectorsWithProfile) {
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
builder.Add(x, y);
@@ -155,7 +154,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentInputLayouts) {
}
XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
builder.Add(x, y);
@@ -192,7 +191,7 @@ XLA_TEST_F(LocalClientExecuteTest, AddArraysWithDifferentOutputLayouts) {
}
XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
builder.Tuple({x, y, x});
@@ -220,7 +219,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleResult) {
}
XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
auto inner_tuple = builder.Tuple({x, y, x});
@@ -254,7 +253,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleResult) {
XLA_TEST_F(LocalClientExecuteTest, TupleResultWithLayout) {
// Verify setting the result layout of a computation with a tuple output.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {2, 2}), "y");
builder.Tuple({x, y});
@@ -291,7 +290,7 @@ XLA_TEST_F(LocalClientExecuteTest, TupleArguments) {
// Computation adds the respective array and vector elements from each tuple
// argument and returns the results as a tuple.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, tuple_shape0, "x");
auto y = builder.Parameter(1, tuple_shape1, "y");
auto x_0 = builder.GetTupleElement(x, 0);
@@ -338,7 +337,7 @@ XLA_TEST_F(LocalClientExecuteTest, NestedTupleArgument) {
// Computation negates the array element and sums the two vector elements in
// the nested tuple. The resulting array and vector are returned as a tuple.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto param = builder.Parameter(0, nested_tuple_shape, "param");
auto inner_tuple = builder.GetTupleElement(param, 0);
auto inner_array = builder.GetTupleElement(inner_tuple, 0);
@@ -376,7 +375,7 @@ XLA_TEST_F(LocalClientExecuteTest, PassingTupleResultBackIntoComputation) {
const Shape tuple_shape =
ShapeUtil::MakeTupleShape({array_shape, array_shape});
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto param = builder.Parameter(0, tuple_shape, "param");
auto element_0 = builder.GetTupleElement(param, 0);
auto element_1 = builder.GetTupleElement(param, 1);
@@ -420,11 +419,11 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
std::vector<Shape> element_shapes(kElementCount, element_shape);
const Shape tuple_shape = ShapeUtil::MakeTupleShape(element_shapes);
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto param = builder.Parameter(0, tuple_shape, "param");
// Add each element's tuple index value to every element.
- std::vector<ComputationDataHandle> result_elements;
+ std::vector<XlaOp> result_elements;
for (int i = 0; i < kElementCount; ++i) {
auto element = builder.GetTupleElement(param, i);
result_elements.push_back(
@@ -453,9 +452,7 @@ XLA_TEST_F(LocalClientExecuteTest, LargeTuple) {
}
}
-// TODO(b/66968986): Test times out on CPU parallel backend. Disabled
-// 2017-09-26.
-XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) {
+XLA_TEST_F(LocalClientExecuteTest, LargeNestedTuple) {
// Construct and run a computation which takes a two-level nested tuple
// parameter with a large fanout.
const int kFanout = 40;
@@ -467,15 +464,15 @@ XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_CPU_PARALLEL(LargeNestedTuple)) {
std::vector<Shape> inner_tuple_shapes(kFanout, inner_tuple_shape);
const Shape tuple_shape = ShapeUtil::MakeTupleShape(inner_tuple_shapes);
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto param = builder.Parameter(0, tuple_shape, "param");
// The computation increments each leaf value by an amount equal to the leaf's
// ordinal position in a traversal of the tuple.
- std::vector<ComputationDataHandle> result_elements;
+ std::vector<XlaOp> result_elements;
for (int i = 0; i < kFanout; ++i) {
auto outer_element = builder.GetTupleElement(param, i);
- std::vector<ComputationDataHandle> inner_result_elements;
+ std::vector<XlaOp> inner_result_elements;
for (int j = 0; j < kFanout; ++j) {
auto inner_element = builder.GetTupleElement(outer_element, j);
inner_result_elements.push_back(builder.Add(
@@ -522,7 +519,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
shape = ShapeUtil::MakeTupleShape({shape});
}
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto element = builder.Parameter(0, shape, "param");
for (int i = 0; i < kTupleDepth; ++i) {
element = builder.GetTupleElement(element, 0);
@@ -556,7 +553,7 @@ XLA_TEST_F(LocalClientExecuteTest, DeepTuple) {
XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
// Test passing in an invalid number of arguments.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(F32, {3}), "y");
builder.Add(x, y);
@@ -573,7 +570,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidNumberOfArguments) {
XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
// Test passing in an argument with the wrong shape.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
builder.Neg(x);
@@ -590,7 +587,7 @@ XLA_TEST_F(LocalClientExecuteTest, IncorrectArgumentShape) {
XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
// Test passing in an invalid result layout parameter.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {2, 2}), "x");
builder.Neg(x);
@@ -613,7 +610,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidResultLayout) {
XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
// Try to run a trivial computation on every device on the system. If a
// specific device is not supported, check that the right error is returned.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
builder.ConstantR0<float>(42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
for (int d = 0; d < local_client_->device_count(); ++d) {
@@ -640,7 +637,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnAllDeviceOrdinals) {
XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
// Try running computations on devices with device ordinal values which do not
// exist.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
builder.ConstantR0<float>(42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
@@ -657,7 +654,7 @@ XLA_TEST_F(LocalClientExecuteTest, InvalidDeviceOrdinalValues) {
XLA_TEST_F(LocalClientExecuteTest, RunOnStream) {
// Run a computation on a specific stream on each device on the system.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
builder.ConstantR0<float>(42.0f);
auto computation = builder.Build().ConsumeValueOrDie();
@@ -693,7 +690,7 @@ XLA_TEST_F(LocalClientExecuteTest,
se::Stream wrong_stream(wrong_platform->ExecutorForDevice(0).ValueOrDie());
wrong_stream.Init();
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
builder.ConstantR0<float>(42.0f);
auto execute_status = ExecuteLocally(
builder.Build().ValueOrDie(), {}, DefaultExecutableBuildOptions(),
@@ -710,7 +707,7 @@ XLA_TEST_F(LocalClientExecuteTest,
.ValueOrDie();
TestAllocator allocator(wrong_platform);
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto y = builder.ConstantR0<float>(123.0f);
auto execute_status = ExecuteLocally(
@@ -723,7 +720,7 @@ XLA_TEST_F(LocalClientExecuteTest,
XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
// Try to run a computation on a stream that has not been initialized.
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
builder.ConstantR0<float>(42.0f);
LOG(INFO) << "default device = " << local_client_->default_device_ordinal();
@@ -743,7 +740,7 @@ XLA_TEST_F(LocalClientExecuteTest, RunOnUninitializedStream) {
}
XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
std::initializer_list<float> vec1 = {1.f, 2.f, 3.f};
std::initializer_list<float> vec2 = {2.f, 4.f, 6.f};
@@ -763,7 +760,7 @@ XLA_TEST_F(LocalClientExecuteTest, SelectBetweenTuples) {
}
XLA_TEST_F(LocalClientExecuteTest, CompileExecutable) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(F32, {3}), "x");
auto y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
builder.Add(x, y);
@@ -853,9 +850,8 @@ XLA_TEST_F(LocalClientExecuteTest, ShapeBufferToLiteralConversion64bit) {
// TODO(b/34359662): Support infeed/outfeed on GPU and CPU parallel.
// 2017-10-18.
-XLA_TEST_F(LocalClientExecuteTest,
- DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(InfeedOutfeedTest))) {
- ComputationBuilder builder(local_client_, TestName());
+XLA_TEST_F(LocalClientExecuteTest, DISABLED_ON_GPU(InfeedOutfeedTest)) {
+ XlaBuilder builder(TestName());
const Shape shape = ShapeUtil::MakeShape(F32, {3});
auto in = builder.Infeed(shape);
auto constant = builder.ConstantR1<float>({1.0f, 2.0f, 3.0f});
@@ -893,7 +889,7 @@ void BM_LocalClientOverhead(int num_iters) {
int device_ordinal = client->default_device_ordinal();
// Use a tiny add operation as the computation.
- ComputationBuilder builder(client, "Add");
+ XlaBuilder builder("Add");
auto shape = ShapeUtil::MakeShape(F32, {2, 3});
auto x = builder.Parameter(0, shape, "x");
builder.Add(x, x);
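
A minimal standalone sketch of the builder pattern the hunks above converge on: XlaBuilder is constructed from a name only (no LocalClient), ops are plain XlaOp handles, and Build() yields a StatusOr<XlaComputation>. The helper name and shapes below are illustrative, not part of the change.

#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"

// Hypothetical helper mirroring the CompileExecutable test body above.
xla::XlaComputation BuildAddConstant() {
  xla::XlaBuilder builder("add_constant");  // no client argument anymore
  xla::XlaOp x =
      builder.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {3}), "x");
  xla::XlaOp y = builder.ConstantR1<float>({2.0f, 3.0f, 4.0f});
  builder.Add(x, y);  // the last op built becomes the computation's root
  return builder.Build().ConsumeValueOrDie();  // StatusOr<XlaComputation>
}
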
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.cc b/tensorflow/compiler/xla/tests/local_client_test_base.cc
index bb5aabb214..ca8e4cdbdb 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.cc
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.cc
@@ -27,7 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/test_helpers.h"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
@@ -157,7 +157,7 @@ ExecutableRunOptions LocalClientTestBase::DefaultExecutableRunOptions() const {
}
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions())
@@ -165,7 +165,7 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
}
ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
@@ -174,14 +174,14 @@ ScopedShapedBuffer LocalClientTestBase::ExecuteLocallyOrDie(
}
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments) {
return ExecuteLocally(computation, arguments, DefaultExecutableBuildOptions(),
DefaultExecutableRunOptions());
}
StatusOr<ScopedShapedBuffer> LocalClientTestBase::ExecuteLocally(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options) {
diff --git a/tensorflow/compiler/xla/tests/local_client_test_base.h b/tensorflow/compiler/xla/tests/local_client_test_base.h
index 4ee56a05ec..3bbb760c80 100644
--- a/tensorflow/compiler/xla/tests/local_client_test_base.h
+++ b/tensorflow/compiler/xla/tests/local_client_test_base.h
@@ -21,8 +21,8 @@ limitations under the License.
#include <vector>
#include "tensorflow/compiler/xla/client/client_library.h"
-#include "tensorflow/compiler/xla/client/computation.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/service/device_memory_allocator.h"
#include "tensorflow/compiler/xla/service/local_service.h"
#include "tensorflow/compiler/xla/service/platform_util.h"
@@ -93,19 +93,19 @@ class LocalClientTestBase : public ::testing::Test {
  // Execute the given computation on the local client, with and without
  // options.
StatusOr<ScopedShapedBuffer> ExecuteLocally(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
StatusOr<ScopedShapedBuffer> ExecuteLocally(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
ScopedShapedBuffer ExecuteLocallyOrDie(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments);
ScopedShapedBuffer ExecuteLocallyOrDie(
- const Computation& computation,
+ const XlaComputation& computation,
tensorflow::gtl::ArraySlice<const ShapedBuffer*> arguments,
const ExecutableBuildOptions& build_options,
const ExecutableRunOptions& run_options);
diff --git a/tensorflow/compiler/xla/tests/map_test.cc b/tensorflow/compiler/xla/tests/map_test.cc
index 8fabcaca1b..7df45bebeb 100644
--- a/tensorflow/compiler/xla/tests/map_test.cc
+++ b/tensorflow/compiler/xla/tests/map_test.cc
@@ -16,8 +16,6 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -341,48 +339,6 @@ XLA_TEST_F(MapTest, ComplexNestedMaps) {
ComputeAndCompareR0<float>(&builder, 73.0, {}, ErrorSpec(0.01f));
}
-TEST_F(MapTest, VersionedEmbeddedComputation) {
- // Build a computation X, use it in a map, then add an additional operation to
- // computation X and use it again in a different map. Verify that the proper
- // versions of computation X are used in each of the maps.
-
- // Create a (embedded) computation which adds one to its parameter argument.
- ComputationBuilder embedded_builder(client_, "EmbeddedComputation");
- auto param_0 =
- embedded_builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "param0");
- auto constant_one = embedded_builder.ConstantR0<float>(1.0);
- auto adder_to_one = embedded_builder.Add(param_0, constant_one);
- auto computation_status = embedded_builder.Build();
- ASSERT_IS_OK(computation_status.status());
- auto embedded_computation = computation_status.ConsumeValueOrDie();
-
- ComputationBuilder builder(client_, TestName());
- auto constant_vector = builder.ConstantR1<float>({1.0, 2.0, 3.0, 4.0});
- auto map_plus_1 = builder.Map({constant_vector}, embedded_computation, {0});
-
- // Add another Add(1) operation to the existing embedded computation. This
- // requires using the stub interface because the ComputationBuilder does not
- // allow modification to the XlaComputation objects after they have been
- // built.
- BinaryOpRequest request;
- request.set_binop(BINOP_ADD);
- *request.mutable_lhs() = adder_to_one;
- *request.mutable_rhs() = constant_one;
- OpRequest op_request;
- *op_request.mutable_computation() = embedded_computation.handle();
- *op_request.mutable_binary_op_request() = request;
- OpResponse response;
- tensorflow::Status s = client_->stub()->Op(&op_request, &response);
- ASSERT_TRUE(s.ok());
-
- auto map_plus_2 = builder.Map({map_plus_1}, embedded_computation, {0});
-
- // The original vector has Add(1) applied to it with a map, followed by
- // Add(1+1) resulting in a net Add(3).
- ComputeAndCompareR1<float>(&builder, {4.0, 5.0, 6.0, 7.0}, {},
- ErrorSpec(0.01f));
-}
-
TEST_F(MapTest, MapBinaryAdder) {
// Maps (lambda (x y) (+ x y)) onto two R1F32 vectors.
XlaBuilder builder(TestName());
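
The removed VersionedEmbeddedComputation test relied on mutating an embedded computation after it was built, which the XlaComputation path above does not allow. Under that assumption, each variant is built as its own computation and passed to Map separately; a brief sketch with illustrative helper names:

// Hypothetical helper: a scalar computation that adds `n` to its parameter.
xla::XlaComputation BuildAddN(float n) {
  xla::XlaBuilder b("add_n");
  auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "x");
  b.Add(x, b.ConstantR0<float>(n));
  return b.Build().ConsumeValueOrDie();
}

void MapWithTwoVariants(xla::XlaBuilder* builder) {
  xla::XlaComputation add_one = BuildAddN(1.0f);
  xla::XlaComputation add_two = BuildAddN(2.0f);
  auto v = builder->ConstantR1<float>({1.0f, 2.0f, 3.0f, 4.0f});
  auto plus_1 = builder->Map({v}, add_one, /*dimensions=*/{0});
  builder->Map({plus_1}, add_two, /*dimensions=*/{0});  // net effect: +3 per element
}
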
diff --git a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
index c42f71388b..7fa61eb33c 100644
--- a/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/matrix_ops_simple_test.cc
@@ -19,8 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -60,7 +61,7 @@ TYPED_TEST_CASE(MatOpsSimpleTest_F16F32, TypesF16F32);
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
using T = TypeParam;
- ComputationBuilder builder(this->client_, "exp_2x2");
+ XlaBuilder builder("exp_2x2");
auto data = builder.ConstantR2FromArray2D<T>({
{1.0f, 0.0f}, // row 0
{-1.0f, 0.5f}, // row 1
@@ -77,10 +78,10 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, ExpTwoByTwoValues) {
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
using T = TypeParam;
- Computation add_half;
+ XlaComputation add_half;
{
// add_half(x) = x + 0.5
- ComputationBuilder builder(this->client_, "add_half");
+ XlaBuilder builder("add_half");
auto x_value =
builder.Parameter(0, ShapeUtil::MakeShapeWithType<T>({}), "x_value");
auto half = builder.ConstantR0<T>(static_cast<T>(0.5));
@@ -90,7 +91,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
add_half = computation_status.ConsumeValueOrDie();
}
- ComputationBuilder builder(this->client_, "map_2x2");
+ XlaBuilder builder("map_2x2");
auto data = builder.ConstantR2FromArray2D<T>({
{1.0f, 0.0f}, // row 0
{-1.0f, 0.5f}, // row 1
@@ -106,7 +107,7 @@ XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MapTwoByTwo) {
XLA_TYPED_TEST(MatOpsSimpleTest_F16F32, MaxTwoByTwoValues) {
using T = TypeParam;
- ComputationBuilder builder(this->client_, "max_2x2");
+ XlaBuilder builder("max_2x2");
auto lhs = builder.ConstantR2FromArray2D<T>({
{7.0f, 2.0f}, // row 0
{3.0f, -4.0f}, // row 1
@@ -143,8 +144,7 @@ class TestLinspaceMaxParametric
MakeLinspaceArray2D<T>(from, to, rows, cols);
auto arhs = MakeUnique<Array2D<T>>(rows, cols, static_cast<T>(1.0f));
- ComputationBuilder builder(
- client_,
+ XlaBuilder builder(
tensorflow::strings::Printf("max_%lldx%lld_linspace", rows, cols));
auto lhs = builder.ConstantR2FromArray2D<T>(*alhs);
auto rhs = builder.ConstantR2FromArray2D<T>(*arhs);
@@ -219,7 +219,7 @@ class MatOpsDotAddTest
client_->TransferToServer(*Literal::CreateR2FromArray2DWithLayout<T>(
rhs, LayoutUtil::MakeLayout(minor_to_major(row_major)))));
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto lhs_arg = builder.Parameter(0, lhs_shape, "lhs");
auto lhs_mat_arg = lhs_arg;
if (transpose) {
diff --git a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
index 11c0bf7a5a..0791a71aac 100644
--- a/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
+++ b/tensorflow/compiler/xla/tests/multidimensional_slice_test.cc
@@ -19,8 +19,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -32,7 +32,7 @@ namespace {
class SliceTest : public ClientLibraryTestBase {};
XLA_TEST_F(SliceTest, Slice2D) {
- ComputationBuilder builder(client_, "slice_2d");
+ XlaBuilder builder("slice_2d");
auto original = builder.ConstantR2<float>(
{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}, {10.0, 11.0, 12.0}});
builder.Slice(original, {2, 1}, {4, 3}, {1, 1});
@@ -42,7 +42,7 @@ XLA_TEST_F(SliceTest, Slice2D) {
}
XLA_TEST_F(SliceTest, Slice3D) {
- ComputationBuilder builder(client_, "slice_3d");
+ XlaBuilder builder("slice_3d");
Array3D<float> array_3d(
{{{1.0f, 2.0f}, {3.0f, 4.0f}}, {{5.0f, 6.0f}, {7.0f, 8.0f}}});
auto original = builder.ConstantR3FromArray3D<float>(array_3d);
diff --git a/tensorflow/compiler/xla/tests/params_test.cc b/tensorflow/compiler/xla/tests/params_test.cc
index bb7e800df8..97dab860c0 100644
--- a/tensorflow/compiler/xla/tests/params_test.cc
+++ b/tensorflow/compiler/xla/tests/params_test.cc
@@ -20,9 +20,10 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -41,7 +42,7 @@ namespace {
class ParamsTest : public ClientLibraryTestBase {};
XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal = Literal::CreateR0<float>(3.14159f);
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -53,7 +54,7 @@ XLA_TEST_F(ParamsTest, ConstantR0F32Param) {
}
XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal = Literal::CreateR1<float>({});
std::unique_ptr<GlobalData> param0_data =
client_->TransferToServer(*param0_literal).ConsumeValueOrDie();
@@ -65,7 +66,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S0F32Param) {
}
XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal::CreateR1<float>({3.14f, -100.25f});
std::unique_ptr<GlobalData> param0_data =
@@ -78,7 +79,7 @@ XLA_TEST_F(ParamsTest, ConstantR1S2F32Param) {
}
XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
string str("hello world");
std::unique_ptr<Literal> param0_literal = Literal::CreateR1U8(str);
std::unique_ptr<GlobalData> param0_data =
@@ -91,7 +92,7 @@ XLA_TEST_F(ParamsTest, ConstantR1U8Param) {
}
XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal::CreateR2FromArray2D<float>(Array2D<float>(3, 0));
std::unique_ptr<GlobalData> param0_data =
@@ -104,7 +105,7 @@ XLA_TEST_F(ParamsTest, ConstantR2_3x0_F32Param) {
}
XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal = Literal::CreateR2<float>(
{{3.14f, -100.25f}, {7e8f, 7e-9f}, {30.3f, -100.0f}});
std::unique_ptr<GlobalData> param0_data =
@@ -119,7 +120,7 @@ XLA_TEST_F(ParamsTest, ConstantR2F32Param) {
}
XLA_TEST_F(ParamsTest, TwoParameters) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
@@ -156,19 +157,15 @@ XLA_TEST_F(ParamsTest, MissingParameter) {
std::unique_ptr<GlobalData> data =
client_->TransferToServer(*literal).ConsumeValueOrDie();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto p = builder.Parameter(2, ShapeUtil::MakeShape(F32, {}), "param2");
- auto computation = builder.Build().ConsumeValueOrDie();
+ auto computation_status = builder.Build();
- auto execute_status = client_->Execute(computation, {data.get(), data.get()},
- /*execution_options=*/nullptr,
- /*execution_profile=*/nullptr);
- ASSERT_EQ(execute_status.status().code(),
- tensorflow::error::FAILED_PRECONDITION);
+ ASSERT_NE(computation_status.status(), tensorflow::Status::OK());
}
XLA_TEST_F(ParamsTest, UnusedParameter) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
@@ -188,7 +185,7 @@ XLA_TEST_F(ParamsTest, UnusedParameter) {
XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
// Build a computation with a couple unused parameters which are used in an
// unused expression.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> literal0 = Literal::CreateR1<float>({1, 2});
std::unique_ptr<GlobalData> param0_data =
@@ -214,12 +211,12 @@ XLA_TEST_F(ParamsTest, UnusedParametersInUnusedExpression) {
}
XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
constexpr int size = 8 * 128 * 2;
std::vector<float> init_value = {{0, 1}};
init_value.resize(size);
- ComputationDataHandle sum_handle = builder.ConstantR1<float>(init_value);
+ XlaOp sum_handle = builder.ConstantR1<float>(init_value);
std::vector<float> sum = {{0, 1}};
sum.resize(size);
@@ -237,8 +234,7 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
std::unique_ptr<Literal> literal = Literal::CreateR1<float>(sum_value);
param_data_owner.push_back(
client_->TransferToServer(*literal).ConsumeValueOrDie());
- ComputationDataHandle param =
- builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = builder.Parameter(i, literal->shape(), "param");
sum_handle = builder.Add(sum_handle, param);
}
@@ -262,10 +258,10 @@ XLA_TEST_F(ParamsTest, HundredLargeR1Parameters) {
// compilation.
XLA_TEST_F(ParamsTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(ThreeThousandParameters))) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
- ComputationDataHandle sum_handle = builder.ConstantR0<float>(0.0f);
+ XlaOp sum_handle = builder.ConstantR0<float>(0.0f);
float target = 0.0;
constexpr int kParamCount = 3000;
for (int i = 0; i < kParamCount; ++i) {
@@ -273,8 +269,7 @@ XLA_TEST_F(ParamsTest,
std::unique_ptr<Literal> literal = Literal::CreateR0<float>(i);
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- ComputationDataHandle param =
- builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = builder.Parameter(i, literal->shape(), "param");
sum_handle = builder.Add(sum_handle, param);
}
@@ -294,25 +289,24 @@ XLA_TEST_F(ParamsTest,
// compilation.
XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
ThreeThousandParametersAndOutputElements))) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
- ComputationDataHandle sum_handle = builder.ConstantR1<int32>({0, 0});
+ XlaOp sum_handle = builder.ConstantR1<int32>({0, 0});
int32 target = 0;
constexpr int kParamCount = 3000;
- std::vector<ComputationDataHandle> params;
+ std::vector<XlaOp> params;
for (int i = 0; i < kParamCount; ++i) {
target += i;
std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- ComputationDataHandle param =
- builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = builder.Parameter(i, literal->shape(), "param");
params.push_back(param);
sum_handle = builder.Add(sum_handle, param);
}
- std::vector<ComputationDataHandle> outputs;
+ std::vector<XlaOp> outputs;
for (int i = 0; i < kParamCount; ++i) {
outputs.push_back(builder.Add(params[i], sum_handle));
}
@@ -353,18 +347,17 @@ XLA_TEST_F(ParamsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(
// 2017-12-12.
XLA_TEST_F(ParamsTest,
DISABLED_ON_CPU(DISABLED_ON_GPU(ManyParametersIntoWhileLoop))) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<std::unique_ptr<GlobalData>> param_data_owner;
constexpr int kParamCount = 1900;
- std::vector<ComputationDataHandle> params;
+ std::vector<XlaOp> params;
std::vector<Shape> parameter_shapes;
for (int i = 0; i < kParamCount; ++i) {
std::unique_ptr<Literal> literal = Literal::CreateR1<int32>({i, i});
param_data_owner.push_back(
std::move(client_->TransferToServer(*literal)).ValueOrDie());
- ComputationDataHandle param =
- builder.Parameter(i, literal->shape(), "param");
+ XlaOp param = builder.Parameter(i, literal->shape(), "param");
params.push_back(param);
parameter_shapes.push_back(literal->shape());
}
@@ -374,7 +367,7 @@ XLA_TEST_F(ParamsTest,
std::unique_ptr<Literal> bool_literal = Literal::CreateR0<bool>(false);
param_data_owner.push_back(
std::move(client_->TransferToServer(*bool_literal)).ValueOrDie());
- ComputationDataHandle bool_param =
+ XlaOp bool_param =
builder.Parameter(kParamCount, bool_literal->shape(), "bool_param");
params.push_back(bool_param);
parameter_shapes.push_back(bool_literal->shape());
@@ -383,9 +376,9 @@ XLA_TEST_F(ParamsTest,
// Create a computation for the condition: while(bool_param).
Shape while_shape = ShapeUtil::MakeTupleShape(parameter_shapes);
- Computation condition;
+ XlaComputation condition;
{
- ComputationBuilder builder(client_, "condition");
+ XlaBuilder builder("condition");
auto condition_parameter =
builder.Parameter(0, while_shape, "condition_parameter");
builder.GetTupleElement(condition_parameter, kParamCount);
@@ -394,11 +387,11 @@ XLA_TEST_F(ParamsTest,
// Create a computation for the body.
// Add {1, 1} to the each tuple element.
- Computation body;
+ XlaComputation body;
{
- ComputationBuilder builder(client_, "body");
+ XlaBuilder builder("body");
auto body_parameter = builder.Parameter(0, while_shape, "body_parameter");
- std::vector<ComputationDataHandle> updates;
+ std::vector<XlaOp> updates;
for (int i = 0; i < kParamCount; ++i) {
auto add = builder.Add(builder.GetTupleElement(body_parameter, i),
builder.ConstantR1<int32>({1, 1}));
@@ -413,7 +406,7 @@ XLA_TEST_F(ParamsTest,
auto loop = builder.While(condition, body, init);
- std::vector<ComputationDataHandle> outputs;
+ std::vector<XlaOp> outputs;
for (int i = 0; i < kParamCount; ++i) {
outputs.push_back(builder.GetTupleElement(loop, i));
}
@@ -437,7 +430,7 @@ XLA_TEST_F(ParamsTest,
#endif
XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape r1f32_3 = ShapeUtil::MakeShape(F32, {3});
Shape tuple_shape = ShapeUtil::MakeTupleShape({r1f32_3, r1f32_3});
@@ -464,7 +457,7 @@ XLA_TEST_F(ParamsTest, TupleOfR1ParametersAddedTogether) {
XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
{{1, 2}, {3, 4}}, LayoutUtil::MakeLayout({0, 1}));
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.Parameter(0, literal->shape(), "input");
std::unique_ptr<GlobalData> data =
@@ -476,7 +469,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_Layout_01) {
XLA_TEST_F(ParamsTest, R2_2x2_Layout_10) {
std::unique_ptr<Literal> literal = Literal::CreateR2WithLayout<float>(
{{1, 3}, {2, 4}}, LayoutUtil::MakeLayout({1, 0}));
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.Parameter(0, literal->shape(), "input");
std::unique_ptr<GlobalData> data =
@@ -501,7 +494,7 @@ XLA_TEST_F(ParamsTest, R2_2x2_TryToPassReverseLayoutToParameter) {
ASSERT_EQ(2, literal->Get<float>({0, 1}));
}
// Use the original shape in building the computation.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto input = builder.Parameter(0, original, "input");
// Use the slice operator to get an off-diagonal element.
builder.Slice(input, {0, 1}, {1, 2}, {1, 1});
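
The while-loop rewrite above keeps the same structure with the new types: the condition and body are each built in their own XlaBuilder scope, assigned to XlaComputation variables, and then handed to While on the outer builder. A reduced sketch with a scalar loop state (illustrative values, not taken from the test):

void BuildCountingLoop() {
  xla::XlaBuilder builder("counting_loop");
  const xla::Shape state_shape = xla::ShapeUtil::MakeShape(xla::F32, {});

  xla::XlaComputation condition;
  {
    xla::XlaBuilder b("condition");
    auto state = b.Parameter(0, state_shape, "state");
    b.Lt(state, b.ConstantR0<float>(10.0f));  // loop while state < 10
    condition = b.Build().ConsumeValueOrDie();
  }

  xla::XlaComputation body;
  {
    xla::XlaBuilder b("body");
    auto state = b.Parameter(0, state_shape, "state");
    b.Add(state, b.ConstantR0<float>(1.0f));  // state = state + 1
    body = b.Build().ConsumeValueOrDie();
  }

  auto init = builder.ConstantR0<float>(0.0f);
  builder.While(condition, body, init);
}
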
diff --git a/tensorflow/compiler/xla/tests/pred_test.cc b/tensorflow/compiler/xla/tests/pred_test.cc
index 10e44b274a..77159efb26 100644
--- a/tensorflow/compiler/xla/tests/pred_test.cc
+++ b/tensorflow/compiler/xla/tests/pred_test.cc
@@ -17,9 +17,9 @@ limitations under the License.
#include <memory>
#include "tensorflow/compiler/xla/array2d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -29,63 +29,62 @@ namespace {
class PredTest : public ClientLibraryTestBase {
protected:
- void TestCompare(bool lhs, bool rhs, bool expected,
- ComputationDataHandle (ComputationBuilder::*op)(
- const ComputationDataHandle&,
- const ComputationDataHandle&,
- tensorflow::gtl::ArraySlice<int64>)) {
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle lhs_op = builder.ConstantR0<bool>(lhs);
- ComputationDataHandle rhs_op = builder.ConstantR0<bool>(rhs);
- ComputationDataHandle result = (builder.*op)(lhs_op, rhs_op, {});
+ void TestCompare(
+ bool lhs, bool rhs, bool expected,
+ XlaOp (XlaBuilder::*op)(const xla::XlaOp&, const xla::XlaOp&,
+ tensorflow::gtl::ArraySlice<int64>)) {
+ XlaBuilder builder(TestName());
+ XlaOp lhs_op = builder.ConstantR0<bool>(lhs);
+ XlaOp rhs_op = builder.ConstantR0<bool>(rhs);
+ XlaOp result = (builder.*op)(lhs_op, rhs_op, {});
ComputeAndCompareR0<bool>(&builder, expected, {});
}
};
TEST_F(PredTest, ConstantR0PredTrue) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR0<bool>(true);
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, ConstantR0PredFalse) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR0<bool>(false);
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, ConstantR0PredCompareEq) {
- TestCompare(true, false, false, &ComputationBuilder::Eq);
+ TestCompare(true, false, false, &XlaBuilder::Eq);
}
TEST_F(PredTest, ConstantR0PredCompareNe) {
- TestCompare(true, false, true, &ComputationBuilder::Ne);
+ TestCompare(true, false, true, &XlaBuilder::Ne);
}
TEST_F(PredTest, ConstantR0PredCompareLe) {
- TestCompare(true, false, false, &ComputationBuilder::Le);
+ TestCompare(true, false, false, &XlaBuilder::Le);
}
TEST_F(PredTest, ConstantR0PredCompareLt) {
- TestCompare(true, false, false, &ComputationBuilder::Lt);
+ TestCompare(true, false, false, &XlaBuilder::Lt);
}
TEST_F(PredTest, ConstantR0PredCompareGe) {
- TestCompare(true, false, true, &ComputationBuilder::Ge);
+ TestCompare(true, false, true, &XlaBuilder::Ge);
}
TEST_F(PredTest, ConstantR0PredCompareGt) {
- TestCompare(true, false, true, &ComputationBuilder::Gt);
+ TestCompare(true, false, true, &XlaBuilder::Gt);
}
TEST_F(PredTest, ConstantR1Pred) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({true, false, false, true});
ComputeAndCompareR1<bool>(&builder, {true, false, false, true}, {});
}
TEST_F(PredTest, ConstantR2Pred) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a =
builder.ConstantR2<bool>({{false, true, true}, {true, false, false}});
const string expected = R"(pred[2,3] {
@@ -96,28 +95,28 @@ TEST_F(PredTest, ConstantR2Pred) {
}
TEST_F(PredTest, AnyR1True) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({true, false});
TF_ASSERT_OK(Any(a, &builder).status());
ComputeAndCompareR0<bool>(&builder, true, {});
}
TEST_F(PredTest, AnyR1False) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({false, false});
TF_ASSERT_OK(Any(a, &builder).status());
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, AnyR1VacuouslyFalse) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR1<bool>({});
TF_ASSERT_OK(Any(a, &builder).status());
ComputeAndCompareR0<bool>(&builder, false, {});
}
TEST_F(PredTest, AnyR2True) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2<bool>({
{false, false, false},
{false, false, false},
@@ -128,7 +127,7 @@ TEST_F(PredTest, AnyR2True) {
}
TEST_F(PredTest, AnyR2False) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2<bool>({
{false, false, false},
{false, false, false},
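
The rewritten TestCompare above dispatches through a pointer-to-member function on XlaBuilder. A stripped-down sketch of that idiom, with the member-pointer type copied from the signature in the hunk (the helper name is illustrative):

using CompareOp = xla::XlaOp (xla::XlaBuilder::*)(
    const xla::XlaOp&, const xla::XlaOp&,
    tensorflow::gtl::ArraySlice<tensorflow::int64>);

void BuildScalarCompare(CompareOp op) {
  xla::XlaBuilder builder("compare");
  xla::XlaOp lhs = builder.ConstantR0<bool>(true);
  xla::XlaOp rhs = builder.ConstantR0<bool>(false);
  (builder.*op)(lhs, rhs, {});  // e.g. BuildScalarCompare(&xla::XlaBuilder::Eq)
}
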
diff --git a/tensorflow/compiler/xla/tests/prng_test.cc b/tensorflow/compiler/xla/tests/prng_test.cc
index 6aafb9fa6c..29a4f75001 100644
--- a/tensorflow/compiler/xla/tests/prng_test.cc
+++ b/tensorflow/compiler/xla/tests/prng_test.cc
@@ -16,8 +16,8 @@ limitations under the License.
#include <limits>
#include <memory>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
@@ -52,13 +52,14 @@ class PrngTest : public ClientLibraryTestBase {
template <typename T>
std::unique_ptr<Literal> PrngTest::UniformTest(
T a, T b, tensorflow::gtl::ArraySlice<int64> dims, int64 seed) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.RngUniform(
builder.ConstantR0<T>(a), builder.ConstantR0<T>(b),
ShapeUtil::MakeShape(primitive_util::NativeToPrimitiveType<T>(), dims));
SetSeed(seed);
- auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ auto actual =
+ ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
EXPECT_THAT(dims, ::testing::ElementsAreArray(actual->shape().dimensions()));
actual->EachCell<T>([=](tensorflow::gtl::ArraySlice<int64>, T value) {
EXPECT_LE(a, value);
@@ -81,8 +82,7 @@ XLA_TEST_F(PrngTest, LargeU01) { UniformTest<float>(0, 1, {0x100, 0x100}); }
XLA_TEST_F(PrngTest, TwelveValuesU524) { UniformTest<int32>(5, 24, {12}); }
// TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
- DISABLED_ON_CPU(ScalarBF16Tests)))) {
+XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16Tests))) {
for (int64 seed = 0; seed < 100; ++seed) {
// The largest negative number smaller than zero in bf16 that's not
// denormalized.
@@ -105,8 +105,7 @@ XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(
}
// TODO(b/71543667): Fix Rng ops on LLVM backends.
-XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(
- DISABLED_ON_CPU_PARALLEL(ScalarBF16CountTests)))) {
+XLA_TEST_F(PrngTest, DISABLED_ON_GPU(DISABLED_ON_CPU(ScalarBF16CountTests))) {
  // There are 3 BF16 values in the range of [32.25, 33): 32.25, 32.5, 32.75;
// they should get similar counts.
bfloat16 low = static_cast<bfloat16>(32.25);
@@ -141,13 +140,14 @@ double PrngTest::UniformChiSquared(int32 range_size, int32 expected_count,
int64 seed) {
int32 sample_size = range_size * expected_count;
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.RngUniform(builder.ConstantR0<int32>(0),
builder.ConstantR0<int32>(range_size),
ShapeUtil::MakeShape(S32, {sample_size}));
SetSeed(seed);
- auto actual = ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ auto actual =
+ ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
std::vector<int32> counts(range_size, 0);
actual->EachCell<int32>([&counts](tensorflow::gtl::ArraySlice<int64>,
int32 value) { ++counts[value]; });
@@ -182,16 +182,15 @@ XLA_TEST_F(PrngTest, Uniformity256) {
XLA_TEST_F(PrngTest, MapUsingRng) {
// Build a x -> (x + U[0,1)) computation.
- auto build_sum_rng = [this](ComputationBuilder& builder) {
+ auto build_sum_rng = [this](XlaBuilder& builder) {
auto b = builder.CreateSubBuilder("sum_with_rng");
auto x = b->Parameter(0, ShapeUtil::MakeShape(F32, {}), "input");
- b->Add(x,
- b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
- ShapeUtil::MakeShape(F32, {})));
+ b->Add(x, b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
+ ShapeUtil::MakeShape(F32, {})));
return b->BuildAndNoteError();
};
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::unique_ptr<Literal> param0_literal =
Literal::CreateR1<float>({2.2f, 5.3f, 4.4f, 5.5f});
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> param0_data,
@@ -226,7 +225,7 @@ XLA_TEST_F(PrngTest, MapUsingRng) {
XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
// Build a U[0,1) computation.
auto build_computation = [this]() {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.RngUniform(builder.ConstantR0<float>(0),
builder.ConstantR0<float>(1),
ShapeUtil::MakeShape(F32, {10}));
@@ -282,24 +281,24 @@ XLA_TEST_F(PrngTest, PassInGlobalRngSeed) {
}
XLA_TEST_F(PrngTest, TenValuesN01) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
builder.RngNormal(builder.ConstantR0<float>(0), builder.ConstantR0<float>(1),
ShapeUtil::MakeShape(F32, {10}));
SetSeed(42);
- ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
// TODO(b/25995601): Test that resultant values are reasonable
}
XLA_TEST_F(PrngTest, RngUniformCrash) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// This used to crash XLA during LLVM IR generation for CPUs.
auto rng_uniform = builder.RngUniform(builder.ConstantR0<int32>(0),
builder.ConstantR0<int32>(1000 * 1000),
ShapeUtil::MakeShape(S32, {}));
SetSeed(0);
- ExecuteAndTransferOrDie(&builder, /*arguments=*/{});
+ ExecuteAndTransfer(&builder, /*arguments=*/{}).ConsumeValueOrDie();
}
} // namespace
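
The MapUsingRng hunk above keeps the sub-builder pattern: the embedded computation is built from a CreateSubBuilder() child of the outer XlaBuilder so that any build error is noted on the parent, and BuildAndNoteError() returns the XlaComputation directly. A trimmed sketch of just that piece (names are illustrative):

xla::XlaComputation BuildSumWithRng(xla::XlaBuilder* builder) {
  auto b = builder->CreateSubBuilder("sum_with_rng");
  auto x = b->Parameter(0, xla::ShapeUtil::MakeShape(xla::F32, {}), "input");
  b->Add(x, b->RngUniform(b->ConstantR0<float>(0), b->ConstantR0<float>(1),
                          xla::ShapeUtil::MakeShape(xla::F32, {})));
  return b->BuildAndNoteError();  // errors propagate to the parent builder
}
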
diff --git a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
index 212512207c..f95e756483 100644
--- a/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
+++ b/tensorflow/compiler/xla/tests/query_inferred_shape_test.cc
@@ -15,8 +15,8 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -30,13 +30,13 @@ namespace {
class QueryInferredShapeTest : public ClientLibraryTestBase {};
TEST_F(QueryInferredShapeTest, OnePlusOneShape) {
- ComputationBuilder builder(client_, "one_plus_one");
+ XlaBuilder builder("one_plus_one");
auto one = builder.ConstantR0<float>(1.0);
auto result = builder.Add(one, one);
- StatusOr<std::unique_ptr<Shape>> shape_status = builder.GetShape(result);
+ StatusOr<Shape> shape_status = builder.GetShape(result);
ASSERT_IS_OK(shape_status.status());
auto shape = shape_status.ConsumeValueOrDie();
- ASSERT_TRUE(ShapeUtil::Equal(*shape, ShapeUtil::MakeShape(F32, {})));
+ ASSERT_TRUE(ShapeUtil::Equal(shape, ShapeUtil::MakeShape(F32, {})));
}
} // namespace
diff --git a/tensorflow/compiler/xla/tests/reduce_test.cc b/tensorflow/compiler/xla/tests/reduce_test.cc
index 423ccadb5b..bcc05c2d41 100644
--- a/tensorflow/compiler/xla/tests/reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_test.cc
@@ -35,7 +35,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
@@ -60,10 +59,9 @@ limitations under the License.
namespace xla {
namespace {
-using FuncGeneratorForType = Computation (*)(PrimitiveType,
- ComputationBuilder*);
+using FuncGeneratorForType = XlaComputation (*)(PrimitiveType, XlaBuilder*);
-using FuncGenerator = Computation (*)(ComputationBuilder*);
+using FuncGenerator = XlaComputation (*)(XlaBuilder*);
class ReduceTest : public ClientLibraryTestBase {
protected:
@@ -89,8 +87,8 @@ class ReduceTest : public ClientLibraryTestBase {
// Runs an R1 => R0 reduction test with the given number of elements.
void RunR1ToR0Test(int64 element_count) {
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {element_count});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -119,13 +117,13 @@ class ReduceTest : public ClientLibraryTestBase {
void RunR1ToR0PredTest(bool and_reduce,
tensorflow::gtl::ArraySlice<int> input_data) {
const int element_count = input_data.size();
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(S32, {element_count});
auto input_par = builder.Parameter(0, input_shape, "input");
auto pred_values =
builder.Eq(input_par, builder.ConstantR1<int>(element_count, 1));
- ComputationDataHandle init_value;
- Computation reduce;
+ XlaOp init_value;
+ XlaComputation reduce;
if (and_reduce) {
init_value = builder.ConstantR0<bool>(true);
reduce = CreateScalarAndComputation(&builder);
@@ -157,13 +155,13 @@ class ReduceTest : public ClientLibraryTestBase {
template <int64 cols>
void RunR2ToR1PredTest(bool and_reduce, int64 rows, int64 minor = 1,
int64 major = 0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const Shape input_shape = ShapeUtil::MakeShape(U8, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto input_pred = builder.Eq(input, builder.ConstantR0<uint8>(1));
- ComputationDataHandle init_value;
- Computation reduce_op;
+ XlaOp init_value;
+ XlaComputation reduce_op;
if (and_reduce) {
init_value = builder.ConstantR0<bool>(true);
reduce_op = CreateScalarAndComputation(&builder);
@@ -202,8 +200,8 @@ class ReduceTest : public ClientLibraryTestBase {
// Runs an R2 => R0 reduction test with the given number of (rows, cols).
void RunR2ToR0Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -230,8 +228,8 @@ class ReduceTest : public ClientLibraryTestBase {
// Runs an R2 => R1 reduction test with the given number of (rows, cols).
void RunR2ToR1Test(int64 rows, int64 cols, int64 minor = 1, int64 major = 0) {
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -261,7 +259,7 @@ class ReduceTest : public ClientLibraryTestBase {
template <typename NativeT>
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_floating_point<NativeT>::value,
- ComputationBuilder>::type* builder,
+ XlaBuilder>::type* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments,
@@ -271,7 +269,7 @@ class ReduceTest : public ClientLibraryTestBase {
template <typename NativeT>
void ComputeAndCompareGeneric(
typename std::enable_if<std::is_integral<NativeT>::value,
- ComputationBuilder>::type* builder,
+ XlaBuilder>::type* builder,
tensorflow::gtl::ArraySlice<NativeT> expected,
tensorflow::gtl::ArraySlice<GlobalData*> arguments) {
ComputeAndCompareR1<NativeT>(builder, expected, arguments);
@@ -279,15 +277,15 @@ class ReduceTest : public ClientLibraryTestBase {
template <typename NativeT>
void RunVectorizedReduceTestForType(
- const std::function<Computation(ComputationBuilder*)>&
+ const std::function<XlaComputation(XlaBuilder*)>&
reduction_function_generator,
const std::function<NativeT(NativeT, NativeT)>&
reference_reduction_function,
const NativeT& initial_value) {
const int rows = 64, cols = 128;
const int minor = 1, major = 0;
- ComputationBuilder builder(client_, TestName());
- Computation reduction_function = reduction_function_generator(&builder);
+ XlaBuilder builder(TestName());
+ XlaComputation reduction_function = reduction_function_generator(&builder);
const Shape input_shape = ShapeUtil::MakeShape(
xla::primitive_util::NativeToPrimitiveType<NativeT>(), {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
@@ -322,7 +320,7 @@ class ReduceTest : public ClientLibraryTestBase {
}
void RunVectorizedReduceTest(
- const std::function<Computation(PrimitiveType, ComputationBuilder*)>&
+ const std::function<XlaComputation(PrimitiveType, XlaBuilder*)>&
reduction_function_generator_for_type,
const std::function<float(float, float)>&
reference_reduction_function_for_floats,
@@ -334,21 +332,21 @@ class ReduceTest : public ClientLibraryTestBase {
uint32 unsigned_int_identity) {
// Float version
RunVectorizedReduceTestForType<float>(
- [&](ComputationBuilder* builder) {
+ [&](XlaBuilder* builder) {
return reduction_function_generator_for_type(F32, builder);
},
reference_reduction_function_for_floats, floating_point_identity);
// Signed int version
RunVectorizedReduceTestForType<int32>(
- [&](ComputationBuilder* builder) {
+ [&](XlaBuilder* builder) {
return reduction_function_generator_for_type(S32, builder);
},
reference_reduction_function_for_ints, signed_int_identity);
// Unsigned int version
RunVectorizedReduceTestForType<uint32>(
- [&](ComputationBuilder* builder) {
+ [&](XlaBuilder* builder) {
return reduction_function_generator_for_type(U32, builder);
},
reference_reduction_function_for_uints, unsigned_int_identity);
@@ -442,8 +440,8 @@ XLA_TEST_F(ReduceTest, OrReduceOnesAndZerosR1_10_Pred) {
XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
const int64 rows = 111, cols = 50;
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -473,8 +471,8 @@ XLA_TEST_F(ReduceTest, ReduceElementwiseR2_111x50_To_R1) {
XLA_TEST_F(ReduceTest, TransposeAndReduceElementwiseR2_111x50_To_R1) {
const int64 rows = 111, cols = 50;
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, cols});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -522,8 +520,8 @@ XLA_TEST_F(ReduceTest, TransposeAndReduceR3_12x111x50_To_R2) {
XLA_TEST_F(ReduceTest, Reshape_111x2x25Reduce_111x50_To_R1) {
const int64 rows = 111, cols = 50;
- ComputationBuilder builder(client_, TestName());
- Computation add_f32 = CreateScalarAddComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation add_f32 = CreateScalarAddComputation(F32, &builder);
const Shape input_shape = ShapeUtil::MakeShape(F32, {rows, 2, cols / 2});
auto input = builder.Parameter(0, input_shape, "input");
auto zero = builder.ConstantR0<float>(0.0);
@@ -569,7 +567,7 @@ void PrintTo(const BoundsLayout& spec, std::ostream* os) {
// Add-reduces a broadcasted scalar matrix among dimensions 1 and 0.
XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto add = CreateScalarAddComputation(F32, &builder);
auto scalar = builder.ConstantR0<float>(42.0);
auto broadcasted = builder.Broadcast(scalar, {500, 500});
@@ -581,7 +579,7 @@ XLA_TEST_F(ReduceTest, AddReduce2DScalarToR0) {
// Max-reduces a broadcasted scalar matrix among dimensions 1 and 0.
XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto max = CreateScalarMaxComputation(F32, &builder);
auto scalar = builder.ConstantR0<float>(42.0);
auto broadcasted = builder.Broadcast(scalar, {500, 500});
@@ -593,7 +591,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DScalarToR0) {
// Max-reduces a matrix among dimensions 1 and 0.
XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto max = CreateScalarMaxComputation(F32, &builder);
Array2D<float> input(300, 250);
input.FillRandom(214.0f);
@@ -608,7 +606,7 @@ XLA_TEST_F(ReduceTest, MaxReduce2DToR0) {
// Min-reduces a matrix among dimensions 1 and 0.
XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto min = CreateScalarMinComputation(F32, &builder);
Array2D<float> input(150, 130);
input.FillRandom(214.0f);
@@ -623,7 +621,7 @@ XLA_TEST_F(ReduceTest, MinReduce2DToR0) {
}
XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto min = CreateScalarMinComputation(U32, &builder);
auto input_literal = Literal::CreateR2FromArray2D(input);
@@ -636,7 +634,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MinReduce) {
}
XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array2D<uint32> input({{1}, {2}});
auto max = CreateScalarMaxComputation(U32, &builder);
auto input_literal = Literal::CreateR2FromArray2D(input);
@@ -650,7 +648,7 @@ XLA_TEST_F(ReduceTest, UnsignedInt_MaxReduce) {
// Reduces a matrix among dimension 1.
XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
@@ -661,7 +659,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong1) {
XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Reduce a matrix among dimensions 0 and 1 (sum it up to a scalar).
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
@@ -671,7 +669,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmong0and1) {
// Tests 2D matrix ReduceToRow operation.
XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
- ComputationBuilder builder(client_, "reduce_among_y");
+ XlaBuilder builder("reduce_among_y");
auto m = builder.ConstantLiteral(*literal_2d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
@@ -681,7 +679,7 @@ XLA_TEST_F(ReduceTest, Reduce2DAmongY) {
}
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1, 2});
@@ -691,7 +689,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_1_2) {
}
XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1});
@@ -701,7 +699,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDims_0_1) {
}
XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0, 1, 2});
@@ -711,7 +709,7 @@ XLA_TEST_F(ReduceTest, ReduceR3ToR0) {
}
XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {0});
@@ -726,7 +724,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim0) {
}
XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {1});
@@ -743,7 +741,7 @@ XLA_TEST_F(ReduceTest, ReduceR3AmongDim1) {
}
XLA_TEST_F(ReduceTest, ReduceR3AmongDim2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto m = builder.ConstantLiteral(*literal_3d_);
auto add = CreateScalarAddComputation(F32, &builder);
builder.Reduce(m, builder.ConstantR0<float>(0.0f), add, {2});
@@ -817,7 +815,7 @@ class ReduceR3ToR2Test : public ReduceTest,
public ::testing::WithParamInterface<BoundsLayout> {};
XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
const auto& bounds = GetParam().bounds;
Array3D<float> input_array(bounds[0], bounds[1], bounds[2]);
// input_array.FillRandom(3.14f, 0.05);
@@ -831,7 +829,7 @@ XLA_TEST_P(ReduceR3ToR2Test, ReduceR3ToR2) {
auto input_activations =
builder.Parameter(0, input_literal->shape(), "input");
- Computation add = CreateScalarAddComputation(F32, &builder);
+ XlaComputation add = CreateScalarAddComputation(F32, &builder);
auto sum = builder.Reduce(input_activations, builder.ConstantR0<float>(0.0f),
add, GetParam().reduce_dims);
@@ -871,8 +869,8 @@ INSTANTIATE_TEST_CASE_P(
// IrEmitterUnnested::EmitInitializer() for the Reduce operator. Failed on
// 2017-07-26.
XLA_TEST_F(ReduceTest, DISABLED_ON_GPU(OperationOnConstantAsInitValue)) {
- ComputationBuilder builder(client_, TestName());
- Computation max_f32 = CreateScalarMaxComputation(F32, &builder);
+ XlaBuilder builder(TestName());
+ XlaComputation max_f32 = CreateScalarMaxComputation(F32, &builder);
auto a = builder.ConstantR0<float>(2.0f);
auto a2 = builder.Abs(a);
@@ -899,8 +897,8 @@ class ReduceInitializerTest : public ReduceTest {
protected:
template <typename T>
void DoTest(T initializer, int num_elems) {
- ComputationBuilder builder(client_, TestName());
- Computation max_fn = CreateScalarMaxComputation(
+ XlaBuilder builder(TestName());
+ XlaComputation max_fn = CreateScalarMaxComputation(
primitive_util::NativeToPrimitiveType<T>(), &builder);
auto init = builder.ConstantR0<T>(initializer);
@@ -940,7 +938,7 @@ XLA_TEST_F(ReduceInitializerTest, U64InitializerBigValue) {
// returns one of the parameters). In this case, we return the rhs, which for
// a 1D array with one element, should not be the init value.
XLA_TEST_F(ReduceTest, ReduceIdentity) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Shape single_float = ShapeUtil::MakeShape(F32, {});
builder.Parameter(0, single_float, "lhs-unused");
builder.Parameter(1, single_float, "rhs-used");
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 0a09766722..10a3da3a38 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -861,8 +861,7 @@ INSTANTIATE_TEST_CASE_P(
class R4ReduceWindowAnyDimsTest : public R4ReduceWindowTest {};
// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
-XLA_TEST_P(R4ReduceWindowAnyDimsTest,
- DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+XLA_TEST_P(R4ReduceWindowAnyDimsTest, DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) {
DoIt();
}
@@ -1151,7 +1150,7 @@ class R2ReduceWindowFailingCpuGpuBf16Test : public R2ReduceWindowTest {};
// TODO(b/72234705): Fix the test cases failed on CPU and GPU.
XLA_TEST_P(R2ReduceWindowFailingCpuGpuBf16Test,
- DISABLED_ON_CPU_PARALLEL(DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt)))) {
+ DISABLED_ON_CPU(DISABLED_ON_GPU(DoIt))) {
DoIt();
}
diff --git a/tensorflow/compiler/xla/tests/replay_test.cc b/tensorflow/compiler/xla/tests/replay_test.cc
index 6d063ffc36..36d763b0f7 100644
--- a/tensorflow/compiler/xla/tests/replay_test.cc
+++ b/tensorflow/compiler/xla/tests/replay_test.cc
@@ -15,13 +15,13 @@ limitations under the License.
#include <memory>
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/protobuf_util.h"
-#include "tensorflow/compiler/xla/service/session.pb.h"
+#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
@@ -38,17 +38,17 @@ class ReplayTest : public ClientLibraryTestBase {};
TEST_F(ReplayTest, TwoPlusTwoReplay) {
// Make 2+2 computation.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto two = builder.ConstantR0<int32>(2);
builder.Add(two, two);
- Computation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().ConsumeValueOrDie();
// Serialize it out.
- std::unique_ptr<SessionModule> module =
+ std::unique_ptr<HloSnapshot> module =
computation.Snapshot().ConsumeValueOrDie();
// Replay it.
- Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
@@ -69,18 +69,18 @@ TEST_F(ReplayTest, TwoPlusTwoReplay) {
XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
// Make computation.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "x");
auto y = builder.Parameter(1, ShapeUtil::MakeShape(S32, {}), "y");
builder.Add(x, y);
- Computation computation = builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = builder.Build().ConsumeValueOrDie();
// Serialize it out.
- std::unique_ptr<SessionModule> module =
+ std::unique_ptr<HloSnapshot> module =
computation.Snapshot().ConsumeValueOrDie();
// Replay it.
- Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
@@ -109,24 +109,24 @@ XLA_TEST_F(ReplayTest, XPlusYReplayWithParameters) {
TEST_F(ReplayTest, MapPlusTwoOverR1) {
// As above, but with map(+2) over some constant array.
- ComputationBuilder plus_two_builder(client_, "plus two");
+ XlaBuilder plus_two_builder("plus two");
auto input =
plus_two_builder.Parameter(0, ShapeUtil::MakeShape(S32, {}), "input");
plus_two_builder.Add(input, plus_two_builder.ConstantR0<int32>(2));
- Computation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
+ XlaComputation plus_two = plus_two_builder.Build().ConsumeValueOrDie();
- ComputationBuilder mapper_builder(client_, TestName());
+ XlaBuilder mapper_builder(TestName());
auto original = mapper_builder.ConstantR1<int32>({1, 2, 3});
mapper_builder.Map({original}, plus_two, {0});
- Computation computation = mapper_builder.Build().ConsumeValueOrDie();
+ XlaComputation computation = mapper_builder.Build().ConsumeValueOrDie();
// Serialize it out.
- std::unique_ptr<SessionModule> module =
+ std::unique_ptr<HloSnapshot> module =
computation.Snapshot().ConsumeValueOrDie();
// Replay it.
- Computation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
+ XlaComputation replayed = client_->LoadSnapshot(*module).ConsumeValueOrDie();
// Check signature is the same.
std::unique_ptr<ProgramShape> original_shape =
@@ -135,10 +135,6 @@ TEST_F(ReplayTest, MapPlusTwoOverR1) {
client_->GetComputationShape(replayed).ConsumeValueOrDie();
ASSERT_TRUE(protobuf_util::ProtobufEquals(*original_shape, *replayed_shape));
- // Destroy the originals.
- computation.Reset();
- plus_two.Reset();
-
// Run it.
std::unique_ptr<Literal> literal =
client_
diff --git a/tensorflow/compiler/xla/tests/reshape_motion_test.cc b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
index e045e164e2..5ebd526899 100644
--- a/tensorflow/compiler/xla/tests/reshape_motion_test.cc
+++ b/tensorflow/compiler/xla/tests/reshape_motion_test.cc
@@ -20,10 +20,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array4d.h"
-#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/reference_util.h"
@@ -45,7 +44,7 @@ namespace {
using ReshapeMotionTest = ClientLibraryTestBase;
TEST_F(ReshapeMotionTest, ElementwiseOfReshapesWithNonSameInputShapes) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto a = builder.ConstantR2<int32>({{2, 3, 5}, {7, 11, 13}});
auto b = builder.ConstantR2<int32>({{17, 19}, {23, 29}, {31, 37}});
auto c = builder.Reshape(a, {6});
diff --git a/tensorflow/compiler/xla/tests/reverse_test.cc b/tensorflow/compiler/xla/tests/reverse_test.cc
index 6959c95502..e7bd142dc9 100644
--- a/tensorflow/compiler/xla/tests/reverse_test.cc
+++ b/tensorflow/compiler/xla/tests/reverse_test.cc
@@ -114,7 +114,7 @@ class ReverseTest : public ClientLibraryTestBase {};
// Tests the reverse operation on a 4D U8 array on dimension 0 and 3.
XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Input shape is U8[1x2x3x4].
// clang-format off
Array4D<uint8> input({{
@@ -144,7 +144,7 @@ XLA_TEST_F(ReverseTest, Reverse4DU8ArrayOnDim23) {
// Tests the reverse operation on a 4D float array on dimension 0 and 1.
TEST_F(ReverseTest, Reverse4DFloatArrayOnDim01) {
- ComputationBuilder b(client_, TestName());
+ XlaBuilder b(TestName());
// Input shape is float[4x3x2x1].
// clang-format off
Array4D<float> input({
diff --git a/tensorflow/compiler/xla/tests/test_macros.h b/tensorflow/compiler/xla/tests/test_macros.h
index e2d406f66d..7ca99a9163 100644
--- a/tensorflow/compiler/xla/tests/test_macros.h
+++ b/tensorflow/compiler/xla/tests/test_macros.h
@@ -34,7 +34,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#define DISABLED_ON_CPU(X) X
-#define DISABLED_ON_CPU_PARALLEL(X) X
#define DISABLED_ON_GPU(X) X
#define DISABLED_ON_INTERPRETER(X) X
@@ -51,13 +50,6 @@ limitations under the License.
# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
#endif // XLA_TEST_BACKEND_CPU
-#ifdef XLA_TEST_BACKEND_CPU_PARALLEL
-# undef DISABLED_ON_CPU
-# define DISABLED_ON_CPU(X) XLA_TEST_PASTE(DISABLED_, X)
-# undef DISABLED_ON_CPU_PARALLEL
-# define DISABLED_ON_CPU_PARALLEL(X) XLA_TEST_PASTE(DISABLED_, X)
-#endif // XLA_TEST_BACKEND_CPU_PARALLEL
-
#ifdef XLA_TEST_BACKEND_GPU
# undef DISABLED_ON_GPU
# define DISABLED_ON_GPU(X) XLA_TEST_PASTE(DISABLED_, X)
diff --git a/tensorflow/compiler/xla/tests/test_utils_test.cc b/tensorflow/compiler/xla/tests/test_utils_test.cc
index e8efc6e2a8..59afd28a80 100644
--- a/tensorflow/compiler/xla/tests/test_utils_test.cc
+++ b/tensorflow/compiler/xla/tests/test_utils_test.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/test_utils.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/tests/local_client_test_base.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -28,7 +28,7 @@ namespace {
class TestUtilsTest : public LocalClientTestBase {};
XLA_TEST_F(TestUtilsTest, UnusedParam) {
- ComputationBuilder builder(local_client_, TestName());
+ XlaBuilder builder(TestName());
// Make the reduction lambda.
Shape single_float = ShapeUtil::MakeShape(F32, {});
builder.Parameter(0, single_float, "unused");
diff --git a/tensorflow/compiler/xla/tests/tuple_test.cc b/tensorflow/compiler/xla/tests/tuple_test.cc
index 61d0fa02ab..61be174653 100644
--- a/tensorflow/compiler/xla/tests/tuple_test.cc
+++ b/tensorflow/compiler/xla/tests/tuple_test.cc
@@ -269,7 +269,7 @@ XLA_TEST_F(TupleTest, TupleGTEToTupleToGTEAdd) {
ComputeAndCompareR2<float>(&builder, expected, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnFalse)) {
+XLA_TEST_F(TupleTest, SelectBetweenTuplesOnFalse) {
// Tests a selection between tuples with "false" path taken.
XlaBuilder builder(TestName());
@@ -313,7 +313,7 @@ XLA_TEST_F(TupleTest, TuplesInAMap) {
ComputeAndCompareR1<float>(&b, {-99.0f, 101.0f, 214.41f}, {}, error_spec_);
}
-XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesOnTrue)) {
+XLA_TEST_F(TupleTest, SelectBetweenTuplesOnTrue) {
// Tests a selection between tuples with "true" path taken.
XlaBuilder builder(TestName());
@@ -350,7 +350,7 @@ XLA_TEST_F(TupleTest, SelectBetweenTuplesElementResult) {
}
// Cascaded selects between tuple types.
-XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
+XLA_TEST_F(TupleTest, SelectBetweenTuplesCascaded) {
//
// vec1 vec2 vec2 vec1
// | | | |
@@ -390,8 +390,7 @@ XLA_TEST_F(TupleTest, DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesCascaded)) {
ComputeAndCompareR1<float>(&builder, {3.f, 6.f, 9.f}, {}, error_spec_);
}
-XLA_TEST_F(TupleTest,
- DISABLED_ON_CPU_PARALLEL(SelectBetweenTuplesReuseConstants)) {
+XLA_TEST_F(TupleTest, SelectBetweenTuplesReuseConstants) {
// Similar to SelectBetweenTuples, but the constants are shared between the
// input tuples.
XlaBuilder builder(TestName());
@@ -516,10 +515,8 @@ XLA_TEST_F(TupleTest, ComplexTuples) {
class TupleHloTest : public HloTestBase {};
-// Disabled on CPU parallel because that's broken and will be removed soon.
// Disabled on the interpreter because bitcast doesn't exist on the interpreter.
-TEST_F(TupleHloTest,
- DISABLED_ON_INTERPRETER(DISABLED_ON_CPU_PARALLEL(BitcastAfterGTE))) {
+TEST_F(TupleHloTest, DISABLED_ON_INTERPRETER(BitcastAfterGTE)) {
const char* testcase = R"(
HloModule m
diff --git a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
index 32ba067a10..82d301983f 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_reduce_test.cc
@@ -19,9 +19,9 @@ limitations under the License.
#include "tensorflow/compiler/xla/array2d.h"
#include "tensorflow/compiler/xla/array3d.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include "tensorflow/compiler/xla/tests/client_library_test_base.h"
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
@@ -33,9 +33,9 @@ namespace {
class VecOpsReduceTest : public ClientLibraryTestBase {
public:
- VecOpsReduceTest() : builder_(client_, TestName()) {}
+ VecOpsReduceTest() : builder_(TestName()) {}
- ComputationDataHandle BuildSampleConstantCube() {
+ XlaOp BuildSampleConstantCube() {
// clang-format off
Array3D<float> x3d({
{{1.0, 2.0, 3.0}, // | dim 1 // } plane 0 in dim 0
@@ -49,7 +49,7 @@ class VecOpsReduceTest : public ClientLibraryTestBase {
return builder_.ConstantR3FromArray3D<float>(x3d);
}
- ComputationBuilder builder_;
+ XlaBuilder builder_;
ErrorSpec errspec_{1e-3, 0};
};
diff --git a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
index 697d78fe6e..3dded3f715 100644
--- a/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
+++ b/tensorflow/compiler/xla/tests/vector_ops_simple_test.cc
@@ -19,10 +19,11 @@ limitations under the License.
#include "tensorflow/compiler/xla/array4d.h"
#include "tensorflow/compiler/xla/client/computation.h"
-#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/global_data.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/local_client.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
+#include "tensorflow/compiler/xla/client/xla_client/xla_computation.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test_helpers.h"
@@ -49,7 +50,7 @@ class VecOpsSimpleTest : public ClientLibraryTestBase {
};
XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto exp = builder.Exp(x);
@@ -63,7 +64,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpTenValues) {
XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
for (int count : {63, 64, 65, 127, 128, 129, 17 * 4096}) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
std::vector<float> exponents;
exponents.reserve(count);
for (int i = 0; i < count; ++i) {
@@ -84,7 +85,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpManyValues) {
}
XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
Array4D<float> exponents(2, 2, 2, 2);
std::vector<float> exponents_vector;
@@ -106,7 +107,7 @@ XLA_TEST_F(VecOpsSimpleTest, ExpIn4D) {
}
XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
builder.Neg(x);
@@ -117,7 +118,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenFloatValues) {
}
XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<int32>({2, -2, 12, -4, 5, 20, -15, 0, -2, 1});
builder.Neg(x);
@@ -126,7 +127,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateTenInt32Values) {
}
XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<uint32>(
{0, 1, 42, static_cast<uint32>(-1), static_cast<uint32>(-12)});
builder.Neg(x);
@@ -136,7 +137,7 @@ XLA_TEST_F(VecOpsSimpleTest, NegateUint32Values) {
}
XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
builder.SquareF32(x);
@@ -147,7 +148,7 @@ XLA_TEST_F(VecOpsSimpleTest, SquareTenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
builder.ReciprocalF32(x);
@@ -159,7 +160,7 @@ XLA_TEST_F(VecOpsSimpleTest, ReciprocalTenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({0.0, -0.0});
auto exp = builder.SqrtF32(x);
@@ -167,7 +168,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtZeroes) {
}
XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345});
auto exp = builder.SqrtF32(x);
@@ -176,7 +177,7 @@ XLA_TEST_F(VecOpsSimpleTest, SqrtSixValues) {
}
XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x =
builder.ConstantR1<float>({16.0, 1.0, 1024.0, 0.16, 0.2, 12345, 1.2345});
auto exp = builder.Pow(x, builder.ConstantR0<float>(-.5f));
@@ -188,7 +189,7 @@ XLA_TEST_F(VecOpsSimpleTest, InvSqrtSevenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto add = CreateScalarAddComputation(F32, &builder);
auto x = builder.ConstantR1<float>(
@@ -203,7 +204,7 @@ XLA_TEST_F(VecOpsSimpleTest, AddTenValuesViaMap) {
}
XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
@@ -218,8 +219,8 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValues) {
XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
// Similar to MaxTenValues, except that the inputs come from params rather
// than constants.
- ComputationBuilder builder(client_, TestName());
- ComputationDataHandle v1, v2;
+ XlaBuilder builder(TestName());
+ XlaOp v1, v2;
std::unique_ptr<GlobalData> param0_data = CreateR1Parameter<float>(
{41.0f, 2.0f, 3.0f, 84.0f}, /*parameter_number=*/0, /*name=*/"v1",
/*builder=*/&builder, /*data_handle=*/&v1);
@@ -236,7 +237,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesFromParams) {
XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
// Similar to MaxTenValuesFromParams, except that the data size passed in and
// out is large.
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
// Number of floats in the data passed into and out of the computation.
constexpr int datalen = 15 * 1000;
@@ -259,7 +260,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
expected_vec.push_back(larger);
}
- ComputationDataHandle v1, v2;
+ XlaOp v1, v2;
std::unique_ptr<GlobalData> param0_data =
CreateR1Parameter<float>(v1vec, /*parameter_number=*/0, /*name=*/"v1",
/*builder=*/&builder, /*data_handle=*/&v1);
@@ -274,7 +275,7 @@ XLA_TEST_F(VecOpsSimpleTest, Max15000ValuesFromParams) {
}
XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR0<float>(0);
@@ -286,7 +287,7 @@ XLA_TEST_F(VecOpsSimpleTest, MaxTenValuesWithScalar) {
}
XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<float>(
{2.1, -2.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
auto y = builder.ConstantR1<float>(
@@ -299,7 +300,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinTenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR0<float>(0);
auto one = builder.ConstantR0<float>(1);
auto x = builder.ConstantR1<float>(
@@ -312,7 +313,7 @@ XLA_TEST_F(VecOpsSimpleTest, MinMaxTenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR0<float>(0);
auto one = builder.ConstantR0<float>(1);
auto x = builder.ConstantR1<float>(
@@ -325,7 +326,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstant) {
}
XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto zero = builder.ConstantR1<float>({0.0f, 0.0f});
auto one = builder.ConstantR1<float>({1.0f, 1.0f});
auto x = builder.ConstantR1<float>({2.1, -2.6});
@@ -336,7 +337,7 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTwoValuesConstant) {
}
XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto one = builder.ConstantR0<float>(1);
auto two = builder.ConstantR0<float>(2);
auto x = builder.ConstantR1<float>(
@@ -348,11 +349,22 @@ XLA_TEST_F(VecOpsSimpleTest, ClampTenValuesConstantNonzeroLower) {
ComputeAndCompareR1<float>(&builder, expected, {});
}
+XLA_TEST_F(VecOpsSimpleTest, ClampValuesConstantS64) {
+  XlaBuilder builder(TestName());
+ auto zero = builder.ConstantR0<int64>(0);
+ auto one = builder.ConstantR0<int64>(10);
+ auto x = builder.ConstantR1<int64>({-3, 3, 9, 13});
+ auto clamp = builder.Clamp(zero, x, one);
+
+ std::vector<int64> expected = {0, 3, 9, 10};
+ ComputeAndCompareR1<int64>(&builder, expected, {});
+}
+
XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
- Computation add_half;
+ XlaComputation add_half;
{
// add_half(x) = x + 0.5
- ComputationBuilder builder(client_, "add_half");
+ XlaBuilder builder("add_half");
auto x_value =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "x_value");
auto half = builder.ConstantR0<float>(0.5);
@@ -362,10 +374,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
add_half = computation_status.ConsumeValueOrDie();
}
- Computation clamp;
+ XlaComputation clamp;
{
// clamp(y) = clamp<0,5>(y)
- ComputationBuilder builder(client_, "clamp");
+ XlaBuilder builder("clamp");
auto y_value =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "y_value");
auto zero = builder.ConstantR0<float>(0.0);
@@ -375,10 +387,10 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
clamp = computation_status.ConsumeValueOrDie();
}
- Computation mult_relu_add;
+ XlaComputation mult_relu_add;
{
// mult_relu_add(z) = clamp(add_half(2 * max(z, 0)))
- ComputationBuilder builder(client_, "mult_relu_add");
+ XlaBuilder builder("mult_relu_add");
auto z_value =
builder.Parameter(0, ShapeUtil::MakeShape(F32, {}), "z_value");
auto zero = builder.ConstantR0<float>(0.0);
@@ -392,7 +404,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
mult_relu_add = computation_status.ConsumeValueOrDie();
}
- ComputationBuilder builder(client_, "map10");
+ XlaBuilder builder("map10");
{
auto x = builder.ConstantR1<float>(
{2.1, -21.6, 2.6, -4.0, 2.1, 2.3, -5.0, -0.9, -2.4, 1.6});
@@ -405,7 +417,7 @@ XLA_TEST_F(VecOpsSimpleTest, MapTenValues) {
}
XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<int32>({-5, -4, -3, -2, -1, 0, 1, 2, 3, 4});
auto y = builder.ConstantR0<int32>(3);
builder.Rem(x, y);
@@ -415,7 +427,7 @@ XLA_TEST_F(VecOpsSimpleTest, RemainderTenValuesS32) {
}
XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<bool>({false, true});
auto y = builder.ConstantR1<bool>({true, false});
builder.Eq(x, y);
@@ -425,7 +437,7 @@ XLA_TEST_F(VecOpsSimpleTest, VectorPredicateEqual) {
}
XLA_TEST_F(VecOpsSimpleTest, VectorPredicateNotEqual) {
- ComputationBuilder builder(client_, TestName());
+ XlaBuilder builder(TestName());
auto x = builder.ConstantR1<bool>({false, true});
auto y = builder.ConstantR1<bool>({true, false});
builder.Ne(x, y);
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 1e18b56799..336fed27c6 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -1321,10 +1321,6 @@ void BM_WhileLoop(int num_iters) {
}
}
-// TODO(b/32470510): Benchmark fails on parallel CPU backend.
-#ifndef XLA_TEST_BACKEND_CPU_PARALLEL
BENCHMARK(BM_WhileLoop);
-#endif
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
index 837a01e873..8354bb71cb 100644
--- a/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
+++ b/tensorflow/compiler/xla/tests/xla_hlo_profile_test.cc
@@ -175,8 +175,7 @@ void ExecuteAndFetchProfile(string* profile_output, LocalClient* client,
XLA_VLOG_LINES(4, *profile_output);
}
-// TODO(b/71364943): This test exposes a bug in the parallel CPU backend.
-XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
+XLA_TEST_F(HloProfileTest, ProfileSingleComputation) {
const int64 m = 256, k = 256, n = 256;
Shape lhs_shape = ShapeUtil::MakeShape(F32, {m, k});
Shape rhs_shape = ShapeUtil::MakeShape(F32, {m, k});
@@ -239,12 +238,9 @@ XLA_TEST_F(HloProfileTest, DISABLED_ON_CPU_PARALLEL(ProfileSingleComputation)) {
EXPECT_TRUE(HasTrops(tanh_profile));
}
-// TODO(b/71364943): This test exposes a bug in the parallel CPU backend.
-//
// TODO(b/71544591): The GPU backend does not record cycles spent in on Hlo
// instructions "interior" to while nodes.
-XLA_TEST_F(HloProfileTest,
- DISABLED_ON_GPU(DISABLED_ON_CPU_PARALLEL(ProfileWhileComputation))) {
+XLA_TEST_F(HloProfileTest, DISABLED_ON_GPU(ProfileWhileComputation)) {
const int64 size = 256;
Shape matrix_shape = ShapeUtil::MakeShape(F32, {size, size});
Shape while_result_shape =
diff --git a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
index 95d3fd28b3..fdbfc0210e 100644
--- a/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
+++ b/tensorflow/compiler/xla/tools/parser/hlo_parser.cc
@@ -303,18 +303,14 @@ bool HloParser::ParseComputations() {
// set the layouts to what the hlo text says.
for (int p = 0; p < computation->num_parameters(); p++) {
const Shape& param_shape = computation->parameter_instruction(p)->shape();
- if (param_shape.has_layout()) {
- module_->mutable_entry_computation_layout()
- ->mutable_parameter_layout(p)
- ->ResetLayout(param_shape.layout());
- }
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ ->mutable_parameter_layout(p)
+ ->CopyLayoutFromShape(param_shape));
}
const Shape& result_shape = computation->root_instruction()->shape();
- if (result_shape.has_layout()) {
- module_->mutable_entry_computation_layout()
- ->mutable_result_layout()
- ->ResetLayout(result_shape.layout());
- }
+ TF_CHECK_OK(module_->mutable_entry_computation_layout()
+ ->mutable_result_layout()
+ ->CopyLayoutFromShape(result_shape));
}
return true;
diff --git a/tensorflow/compiler/xla/window_util.cc b/tensorflow/compiler/xla/window_util.cc
index 93284b80f9..f11123ca24 100644
--- a/tensorflow/compiler/xla/window_util.cc
+++ b/tensorflow/compiler/xla/window_util.cc
@@ -199,6 +199,9 @@ bool IsInactiveWindowDimension(const Window& window, int64 logical_dim) {
int64 DilatedBound(int64 bound, int64 dilation) {
CHECK_GE(bound, 0);
CHECK_GE(dilation, 1);
+ if (bound == 0) {
+ return 0;
+ }
// Suppose the array has three entries 123 and the dilation factor is 4. Then
// the dilated array has 9 entries 1xxx2xxx3. Here, each original entry except
@@ -212,7 +215,7 @@ int64 StridedBound(int64 bound, int64 window_size, int64 stride) {
CHECK_GE(bound, 0);
CHECK_GE(stride, 1);
- if (window_size > bound) {
+ if (bound == 0 || window_size > bound) {
return 0;
}
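
The window_util.cc change above guards `DilatedBound` and `StridedBound` against zero-sized dimensions: without the early return, a bound of 0 makes the closed-form expressions report a spurious non-zero (or negative) output size. Below is a minimal Python sketch of the same arithmetic, with hypothetical helper names, illustrating the formulas and the degenerate cases the new guards cover:

```python
# Minimal sketch (not the XLA source) of the bound arithmetic the new
# zero-size guards protect. Helper names are illustrative only.

def dilated_bound(bound, dilation):
    """Length of a dimension of size `bound` after dilation by `dilation`."""
    assert bound >= 0 and dilation >= 1
    if bound == 0:
        return 0  # without this guard the formula below yields 1 - dilation
    # Each of the `bound` entries except the last is followed by
    # `dilation - 1` holes, e.g. bound=3, dilation=4 -> 1xxx2xxx3 -> 9.
    return (bound - 1) * dilation + 1

def strided_bound(bound, window_size, stride):
    """Number of positions when sliding `window_size` over `bound` by `stride`."""
    assert bound >= 0 and window_size >= 0 and stride >= 1
    if bound == 0 or window_size > bound:
        return 0  # without the bound == 0 check, window_size == 0 would give 1
    return (bound - window_size) // stride + 1

assert dilated_bound(3, 4) == 9
assert dilated_bound(0, 4) == 0
assert strided_bound(0, 0, 1) == 0
```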
diff --git a/tensorflow/contrib/BUILD b/tensorflow/contrib/BUILD
index d28392a62c..abdbdb4cd2 100644
--- a/tensorflow/contrib/BUILD
+++ b/tensorflow/contrib/BUILD
@@ -29,8 +29,9 @@ py_library(
"//tensorflow/contrib/cloud:cloud_py",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_py",
- "//tensorflow/contrib/coder:coder_ops_py",
+ "//tensorflow/contrib/coder:coder_py",
"//tensorflow/contrib/compiler:compiler_py",
+ "//tensorflow/contrib/constrained_optimization",
"//tensorflow/contrib/copy_graph:copy_graph_py",
"//tensorflow/contrib/crf:crf_py",
"//tensorflow/contrib/cudnn_rnn:cudnn_rnn_py",
diff --git a/tensorflow/contrib/__init__.py b/tensorflow/contrib/__init__.py
index 0d163daa6e..7f33d460dc 100644
--- a/tensorflow/contrib/__init__.py
+++ b/tensorflow/contrib/__init__.py
@@ -29,6 +29,7 @@ from tensorflow.contrib import cloud
from tensorflow.contrib import cluster_resolver
from tensorflow.contrib import coder
from tensorflow.contrib import compiler
+from tensorflow.contrib import constrained_optimization
from tensorflow.contrib import copy_graph
from tensorflow.contrib import crf
from tensorflow.contrib import cudnn_rnn
diff --git a/tensorflow/contrib/all_reduce/python/all_reduce.py b/tensorflow/contrib/all_reduce/python/all_reduce.py
index 8add2aacff..159d985db5 100644
--- a/tensorflow/contrib/all_reduce/python/all_reduce.py
+++ b/tensorflow/contrib/all_reduce/python/all_reduce.py
@@ -18,10 +18,11 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import math
-import re
from tensorflow.contrib import nccl
+from tensorflow.python.framework import device as device_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
@@ -659,21 +660,20 @@ def _split_by_task(devices, values):
num_devices = len(devices)
if num_devices != len(values):
raise ValueError("len(devices) must equal len(values)")
- pattern = re.compile(r"/task:(\d+)/")
- per_task_devices = []
- per_task_values = []
+ per_task_devices = collections.OrderedDict()
+ per_task_values = collections.OrderedDict()
for d in range(num_devices):
- m = pattern.search(devices[d])
- if m:
- index = int(m.group(1))
- while index >= len(per_task_devices):
- per_task_devices.append([])
- per_task_values.append([])
- per_task_devices[index].append(devices[d])
- per_task_values[index].append(values[d])
- else:
+ d_spec = device_lib.DeviceSpec.from_string(devices[d])
+ if not hasattr(d_spec, "task") or d_spec.task is None:
assert False, "failed to parse device %s" % devices[d]
- return (per_task_devices, per_task_values)
+ index = (d_spec.job or "localhost", d_spec.replica or 0, d_spec.task)
+ if index not in per_task_devices:
+ per_task_devices[index] = []
+ per_task_values[index] = []
+ per_task_devices[index].append(devices[d])
+ per_task_values[index].append(values[d])
+
+ return (list(per_task_devices.values()), list(per_task_values.values()))
def build_nccl_all_reduce(input_tensors, red_op, un_op=None):
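
The `_split_by_task` rewrite above stops regex-matching `/task:N/` and instead keys each device on its full `(job, replica, task)` triple parsed from the device string, so devices from different jobs no longer collide on the same task index. A rough standalone sketch of that grouping idea, assuming the public TF 1.x `tf.DeviceSpec` behaves like the internal `device_lib.DeviceSpec` used in the patch:

```python
# Illustrative sketch only; assumes TF 1.x where tf.DeviceSpec exposes
# .job/.replica/.task like the internal device_lib.DeviceSpec.
import collections
import tensorflow as tf

def split_by_task(devices, values):
    """Groups parallel device/value lists by (job, replica, task)."""
    per_task = collections.OrderedDict()
    for dev, val in zip(devices, values):
        spec = tf.DeviceSpec.from_string(dev)
        assert spec.task is not None, "failed to parse device %s" % dev
        key = (spec.job or "localhost", spec.replica or 0, spec.task)
        per_task.setdefault(key, []).append((dev, val))
    devices_per_task = [[d for d, _ in group] for group in per_task.values()]
    values_per_task = [[v for _, v in group] for group in per_task.values()]
    return devices_per_task, values_per_task

devs = ["/job:worker/task:0/device:GPU:0", "/job:worker/task:1/device:GPU:0"]
print(split_by_task(devs, ["grad0", "grad1"]))
# -> ([[task:0 device], [task:1 device]], [['grad0'], ['grad1']])
```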
diff --git a/tensorflow/contrib/autograph/README.md b/tensorflow/contrib/autograph/README.md
index 7e84f237dc..0fcbf5dd59 100644
--- a/tensorflow/contrib/autograph/README.md
+++ b/tensorflow/contrib/autograph/README.md
@@ -1,4 +1,117 @@
-# Autograph
+# AutoGraph
-A compiler for generating TensorFlow numeric and control flow ops from Python
-code.
+IMPORTANT: AutoGraph is pre-alpha, under active development. Expect rough edges and bugs, but if you try it, we appreciate early feedback!
+
+AutoGraph is a Python to TensorFlow compiler.
+
+With AutoGraph, you can write [Eager style](https://www.tensorflow.org/programmers_guide/eager) code in a concise manner, and run it as a TensorFlow graph. AutoGraph uses source code transformation and partial evaluation to generate Python code that builds an equivalent TensorFlow subgraph. The result is code that behaves like ops and can be freely combined with other TensorFlow ops.
+
+For example, this Python function:
+
+```
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+would be converted to this:
+
+```
+def graph_mode_f(x):
+ with tf.name_scope('f'):
+
+ def if_true():
+ with tf.name_scope('if_true'):
+ x_1, = x,
+ x_1 = tf.negative(x_1)
+ return x_1,
+
+ def if_false():
+ with tf.name_scope('if_false'):
+ x_1, = x,
+ return x_1,
+      x = ag__.utils.run_cond(tf.less(x, 0), if_true, if_false)
+ return x
+```
+
+so you can use it like an op:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1.0)
+
+ converted_f = autograph.to_graph(f)
+ y = converted_f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+# Getting started
+
+Use AutoGraph in one of the following ways, described below:
+
+ 1. Annotations (simpler)
+ 2. Functional API (more flexible)
+
+NOTE: You can find more examples in this [interactive notebook](https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/autograph/examples/notebooks/dev_summit_2018_demo.ipynb).
+
+To get started, install the latest nightly TensorFlow build:
+
+```shell
+pip install -U tf-nightly
+```
+
+Then import the `autograph` module from `tf.contrib`:
+
+```
+from tensorflow.contrib import autograph as ag
+```
+
+## Using with annotations
+
+Annotating a function or class with `@convert` converts it in place:
+
+```
+@ag.convert()
+def f(x):
+ if x < 0:
+ x = -x
+ return x
+```
+
+... so that it always outputs TensorFlow code:
+
+```
+with tf.Graph().as_default():
+ x = tf.constant(-1)
+
+ y = f(x)
+
+ with tf.Session() as sess:
+ print(sess.run(y))
+ # Output: 1
+```
+
+## Using the functional API
+
+The functional API allows you to convert an existing function, class or object after it was defined:
+
+```
+converted_f = ag.to_graph(f)
+
+print(converted_f(tf.constant(-1)))
+# Output: Tensor
+
+print(f(-1))
+# Output: 1
+```
+
+You can use the functional API to inspect the generated code as well:
+
+```
+print(ag.to_code(f))
+# Output: <Python and TensorFlow code>
+```
diff --git a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
index 7df514cd20..9d6cc9245a 100644
--- a/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
+++ b/tensorflow/contrib/boosted_trees/lib/learner/batch/ordinal_split_handler.py
@@ -417,9 +417,18 @@ class SparseSplitHandler(InequalitySplitHandler):
return (are_splits_ready, partition_ids, gains, split_infos)
-@function.Defun(dtypes.bool, dtypes.bool, dtypes.float32, dtypes.float32,
- dtypes.int32, dtypes.float32, dtypes.float32, dtypes.float32,
- dtypes.float32, dtypes.float32)
+@function.Defun(
+ dtypes.bool,
+ dtypes.bool,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ noinline=True)
def dense_make_stats_update(is_active, are_buckets_ready, float_column,
quantile_buckets, example_partition_ids, gradients,
hessians, weights, empty_gradients, empty_hessians):
@@ -452,9 +461,20 @@ def dense_make_stats_update(is_active, are_buckets_ready, float_column,
gradients, hessians)
-@function.Defun(dtypes.bool, dtypes.bool, dtypes.int64, dtypes.float32,
- dtypes.int64, dtypes.float32, dtypes.int32, dtypes.float32,
- dtypes.float32, dtypes.float32, dtypes.float32, dtypes.float32)
+@function.Defun(
+ dtypes.bool,
+ dtypes.bool,
+ dtypes.int64,
+ dtypes.float32,
+ dtypes.int64,
+ dtypes.float32,
+ dtypes.int32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ dtypes.float32,
+ noinline=True)
def sparse_make_stats_update(
is_active, are_buckets_ready, sparse_column_indices, sparse_column_values,
sparse_column_shape, quantile_buckets, example_partition_ids, gradients,
diff --git a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
index 4bde7f3e33..08c1dcdd02 100644
--- a/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
+++ b/tensorflow/contrib/boosted_trees/python/training/functions/gbdt_batch.py
@@ -970,10 +970,8 @@ class GradientBoostedDecisionTreeModel(object):
# Stack all the inputs to one tensor per type.
# This is a workaround for the slowness of graph building in tf.cond.
# See (b/36554864).
- split_sizes = array_ops.stack([
- array_ops.shape(partition_id)[0]
- for partition_id in partition_ids_list
- ])
+ split_sizes = array_ops.reshape(
+ array_ops.shape_n(partition_ids_list), [-1])
partition_ids = array_ops.concat(partition_ids_list, axis=0)
gains = array_ops.concat(gains_list, axis=0)
split_infos = array_ops.concat(split_info_list, axis=0)
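
The gbdt_batch.py change above replaces the per-element `shape(...)[0]` + `stack` pattern with a single `shape_n` followed by a flattening reshape; for the rank-1 `partition_ids` tensors involved, both forms produce the same vector of split sizes while emitting far fewer graph ops. A small graph-mode sketch (TF 1.x assumed) showing the equivalence:

```python
# Hedged sketch: compares the old per-element stacking with the new
# shape_n-based form for rank-1 inputs (TF 1.x graph mode assumed).
import tensorflow as tf

partition_ids_list = [tf.constant([1, 2, 3]), tf.constant([4, 5])]

# Old form: one Shape op and one slice per list element, then a Stack.
old_split_sizes = tf.stack([tf.shape(p)[0] for p in partition_ids_list])

# New form: a single ShapeN op; the [1]-shaped results auto-pack, then flatten.
new_split_sizes = tf.reshape(tf.shape_n(partition_ids_list), [-1])

with tf.Session() as sess:
    print(sess.run([old_split_sizes, new_split_sizes]))
    # [array([3, 2], dtype=int32), array([3, 2], dtype=int32)]
```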
diff --git a/tensorflow/contrib/cmake/python_modules.txt b/tensorflow/contrib/cmake/python_modules.txt
index fbcdf7e753..6468bed497 100644
--- a/tensorflow/contrib/cmake/python_modules.txt
+++ b/tensorflow/contrib/cmake/python_modules.txt
@@ -129,7 +129,11 @@ tensorflow/contrib/boosted_trees/kernels
tensorflow/contrib/boosted_trees/ops
tensorflow/contrib/boosted_trees/proto
tensorflow/contrib/boosted_trees/python
+tensorflow/contrib/boosted_trees/python/kernel_tests
tensorflow/contrib/boosted_trees/python/ops
+tensorflow/contrib/boosted_trees/python/training
+tensorflow/contrib/boosted_trees/python/training/functions
+tensorflow/contrib/boosted_trees/python/utils
tensorflow/contrib/checkpoint
tensorflow/contrib/checkpoint/python
tensorflow/contrib/cloud
@@ -144,8 +148,11 @@ tensorflow/contrib/coder
tensorflow/contrib/coder/kernels
tensorflow/contrib/coder/ops
tensorflow/contrib/coder/python
+tensorflow/contrib/coder/python/layers
tensorflow/contrib/coder/python/ops
tensorflow/contrib/compiler
+tensorflow/contrib/constrained_optimization
+tensorflow/contrib/constrained_optimization/python
tensorflow/contrib/copy_graph
tensorflow/contrib/copy_graph/python
tensorflow/contrib/copy_graph/python/util
diff --git a/tensorflow/contrib/cmake/tf_core_kernels.cmake b/tensorflow/contrib/cmake/tf_core_kernels.cmake
index ed018b4fed..376496b33f 100644
--- a/tensorflow/contrib/cmake/tf_core_kernels.cmake
+++ b/tensorflow/contrib/cmake/tf_core_kernels.cmake
@@ -63,6 +63,7 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/split_handler_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/stats_accumulator_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/boosted_trees/ops/training_ops.cc"
+ "${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/pmf_to_cdf_op.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops.cc"
"${tensorflow_source_dir}/tensorflow/contrib/coder/kernels/range_coder_ops_util.cc"
diff --git a/tensorflow/contrib/coder/BUILD b/tensorflow/contrib/coder/BUILD
index 9ca4ce8a9c..a2c6e41303 100644
--- a/tensorflow/contrib/coder/BUILD
+++ b/tensorflow/contrib/coder/BUILD
@@ -1,5 +1,5 @@
# Description:
-# Contains entropy coding related modules.
+# Contains tools related to data compression.
package(default_visibility = [
"//learning/brain:__subpackages__",
@@ -54,19 +54,27 @@ tf_gen_op_libs(
],
)
+cc_library(
+ name = "range_coder_ops_util",
+ srcs = ["kernels/range_coder_ops_util.cc"],
+ hdrs = ["kernels/range_coder_ops_util.h"],
+ visibility = ["//visibility:public"],
+ deps = [
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ ],
+)
+
tf_kernel_library(
name = "range_coder_ops",
srcs = [
"kernels/range_coder_ops.cc",
- "kernels/range_coder_ops_util.cc",
- ],
- hdrs = [
- "kernels/range_coder_ops_util.h",
],
visibility = ["//visibility:public"],
deps = [
":coder_ops_op_lib",
":range_coder",
+ ":range_coder_ops_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
],
@@ -152,10 +160,21 @@ tf_gen_op_wrapper_py(
deps = [":coder_ops_op_lib"],
)
+py_library(
+ name = "coder_py",
+ srcs = [
+ "__init__.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":coder_ops_py",
+ ":entropybottleneck_py",
+ ],
+)
+
tf_custom_op_py_library(
name = "coder_ops_py",
srcs = [
- "__init__.py",
"python/ops/coder_ops.py",
],
dso = [
@@ -186,3 +205,44 @@ tf_py_test(
],
main = "python/ops/coder_ops_test.py",
)
+
+py_library(
+ name = "entropybottleneck_py",
+ srcs = [
+ "python/layers/entropybottleneck.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":coder_ops_py",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:functional_ops",
+ "//tensorflow/python:init_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:nn",
+ "//tensorflow/python:ops",
+ "//tensorflow/python:random_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:summary_ops",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python/eager:context",
+ "//tensorflow/python/keras:engine",
+ "//third_party/py/numpy",
+ ],
+)
+
+tf_py_test(
+ name = "entropybottleneck_py_test",
+ srcs = [
+ "python/layers/entropybottleneck_test.py",
+ ],
+ additional_deps = [
+ ":entropybottleneck_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:variables",
+ "//tensorflow/python:training",
+ ],
+ main = "python/layers/entropybottleneck_test.py",
+)
diff --git a/tensorflow/contrib/coder/__init__.py b/tensorflow/contrib/coder/__init__.py
index b7e663e6f1..99b8ac7595 100644
--- a/tensorflow/contrib/coder/__init__.py
+++ b/tensorflow/contrib/coder/__init__.py
@@ -12,13 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""Entropy code operations."""
+"""Data compression tools."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: disable=wildcard-import
+from tensorflow.contrib.coder.python.layers.entropybottleneck import *
from tensorflow.contrib.coder.python.ops.coder_ops import *
# pylint: enable=wildcard-import
diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck.py b/tensorflow/contrib/coder/python/layers/entropybottleneck.py
new file mode 100644
index 0000000000..f039cb0f52
--- /dev/null
+++ b/tensorflow/contrib/coder/python/layers/entropybottleneck.py
@@ -0,0 +1,697 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Entropy bottleneck layer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.coder.python.ops import coder_ops
+
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.keras._impl.keras import engine
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import functional_ops
+from tensorflow.python.ops import init_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import nn
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.summary import summary
+
+
+class EntropyBottleneck(engine.Layer):
+ """Entropy bottleneck layer.
+
+ This layer can be used to model the entropy (the amount of information
+ conveyed) of the tensor passing through it. During training, this can be used
+ to impose a (soft) entropy constraint on its activations, limiting the amount
+ of information flowing through the layer. Note that this is distinct from
+ other types of bottlenecks, which reduce the dimensionality of the space, for
+ example. Dimensionality reduction does not limit the amount of information,
+ and does not enable efficient data compression per se.
+
+ After training, this layer can be used to compress any input tensor to a
+ string, which may be written to a file, and to decompress a file which it
+ previously generated back to a reconstructed tensor (possibly on a different
+ machine having access to the same model checkpoint). The entropies estimated
+ during training or evaluation are approximately equal to the average length of
+ the strings in bits.
+
+ The layer implements a flexible probability density model to estimate entropy,
+ which is described in the appendix of the paper (please cite the paper if you
+ use this code for scientific work):
+
+ "Variational image compression with a scale hyperprior"
+
+ Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston
+
+ https://arxiv.org/abs/1802.01436
+
+ The layer assumes that the input tensor is at least 2D, with a batch dimension
+ at the beginning and a channel dimension as specified by `data_format`. The
+ layer trains an independent probability density model for each channel, but
+ assumes that across all other dimensions, the inputs are i.i.d. (independent
+ and identically distributed). Because the entropy (and hence, average
+ codelength) is a function of the densities, this assumption may have a direct
+ effect on the compression performance.
+
+ Because data compression always involves discretization, the outputs of the
+ layer are generally only approximations of its inputs. During training,
+ discretization is modeled using additive uniform noise to ensure
+ differentiability. The entropies computed during training are differential
+ entropies. During evaluation, the data is actually quantized, and the
+ entropies are discrete (Shannon entropies). To make sure the approximated
+ tensor values are good enough for practical purposes, the training phase must
+ be used to balance the quality of the approximation with the entropy, by
+ adding an entropy term to the training loss, as in the following example.
+
+ Here, we use the entropy bottleneck to compress the latent representation of
+ an autoencoder. The data vectors `x` in this case are 4D tensors in
+ `'channels_last'` format (for example, 16x16 pixel grayscale images).
+
+ The layer always produces exactly one auxiliary loss and one update op which
+ are only significant for compression and decompression. To use the compression
+ feature, the auxiliary loss must be minimized during or after training. After
+ that, the update op must be executed at least once. Here, we simply attach
+ them to the main training step.
+
+ Training:
+ ```
+ # Build autoencoder.
+ x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
+ y = forward_transform(x)
+ entropy_bottleneck = EntropyBottleneck()
+ y_, likelihoods = entropy_bottleneck(y, training=True)
+ x_ = backward_transform(y_)
+
+ # Information content (= predicted codelength) in bits of each batch element
+ # (note that taking the natural logarithm and dividing by `log(2)` is
+ # equivalent to taking base-2 logarithms):
+ bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2)
+
+ # Squared difference of each batch element:
+ squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3))
+
+ # The loss is a weighted sum of mean squared error and entropy (average
+ # information content), where the weight controls the trade-off between
+ # approximation error and entropy.
+ main_loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits)
+
+ # Minimize loss and auxiliary loss, and execute update op.
+ main_optimizer = tf.train.AdamOptimizer(learning_rate=1e-4)
+  main_step = main_optimizer.minimize(main_loss)
+ # 1e-2 is a good starting point for the learning rate of the auxiliary loss,
+ # assuming Adam is used.
+ aux_optimizer = tf.train.AdamOptimizer(learning_rate=1e-2)
+  aux_step = aux_optimizer.minimize(entropy_bottleneck.losses[0])
+ step = tf.group(main_step, aux_step, entropy_bottleneck.updates[0])
+ ```
+
+ Evaluation:
+ ```
+ # Build autoencoder.
+ x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
+ y = forward_transform(x)
+ y_, likelihoods = EntropyBottleneck()(y, training=False)
+ x_ = backward_transform(y_)
+
+ # Information content (= predicted codelength) in bits of each batch element:
+ bits = tf.reduce_sum(tf.log(likelihoods), axis=(1, 2, 3)) / -np.log(2)
+
+ # Squared difference of each batch element:
+ squared_error = tf.reduce_sum(tf.squared_difference(x, x_), axis=(1, 2, 3))
+
+ # The loss is a weighted sum of mean squared error and entropy (average
+ # information content), where the weight controls the trade-off between
+ # approximation error and entropy.
+ loss = 0.5 * tf.reduce_mean(squared_error) + tf.reduce_mean(bits)
+ ```
+
+ To be able to compress the bottleneck tensor and decompress it in a different
+ session, or on a different machine, you need three items:
+ - The compressed representations stored as strings.
+ - The shape of the bottleneck for these string representations as a `Tensor`,
+ as well as the number of channels of the bottleneck at graph construction
+ time.
+ - The checkpoint of the trained model that was used for compression. Note:
+ It is crucial that the auxiliary loss produced by this layer is minimized
+ during or after training, and that the update op is run after training and
+ minimization of the auxiliary loss, but *before* the checkpoint is saved.
+
+ Compression:
+ ```
+ x = tf.placeholder(tf.float32, shape=[None, 16, 16, 1])
+ y = forward_transform(x)
+ strings = EntropyBottleneck().compress(y)
+ shape = tf.shape(y)[1:]
+ ```
+
+ Decompression:
+ ```
+ strings = tf.placeholder(tf.string, shape=[None])
+ shape = tf.placeholder(tf.int32, shape=[3])
+ entropy_bottleneck = EntropyBottleneck(dtype=tf.float32)
+ y_ = entropy_bottleneck.decompress(strings, shape, channels=5)
+ x_ = backward_transform(y_)
+ ```
+ Here, we assumed that the tensor produced by the forward transform has 5
+ channels.
+
+ The above four use cases can also be implemented within the same session (i.e.
+ on the same `EntropyBottleneck` instance), for testing purposes, etc., by
+ calling the object more than once.
+
+ Arguments:
+ init_scale: Float. A scaling factor determining the initial width of the
+ probability densities. This should be chosen big enough so that the
+ range of values of the layer inputs roughly falls within the interval
+ [`-init_scale`, `init_scale`] at the beginning of training.
+ filters: An iterable of ints, giving the number of filters at each layer of
+ the density model. Generally, the more filters and layers, the more
+ expressive is the density model in terms of modeling more complicated
+ distributions of the layer inputs. For details, refer to the paper
+ referenced above. The default is `[3, 3, 3]`, which should be sufficient
+ for most practical purposes.
+ tail_mass: Float, between 0 and 1. The bottleneck layer automatically
+ determines the range of input values that should be represented based on
+ their frequency of occurrence. Values occurring in the tails of the
+ distributions will be clipped to that range during compression.
+ `tail_mass` determines the amount of probability mass in the tails which
+ is cut off in the worst case. For example, the default value of `1e-9`
+ means that at most 1 in a billion input samples will be clipped to the
+ range.
+ optimize_integer_offset: Boolean. Typically, the input values of this layer
+ are floats, which means that quantization during evaluation can be
+ performed with an arbitrary offset. By default, the layer determines that
+ offset automatically. In special situations, such as when it is known that
+ the layer will receive only full integer values during evaluation, it can
+ be desirable to set this argument to `False` instead, in order to always
+ quantize to full integer values.
+ likelihood_bound: Float. If positive, the returned likelihood values are
+ ensured to be greater than or equal to this value. This prevents very
+ large gradients with a typical entropy loss (defaults to 1e-9).
+ range_coder_precision: Integer, between 1 and 16. The precision of the range
+ coder used for compression and decompression. This trades off computation
+ speed with compression efficiency, where 16 is the slowest but most
+ efficient setting. Choosing lower values may increase the average
+ codelength slightly compared to the estimated entropies.
+ data_format: Either `'channels_first'` or `'channels_last'` (default).
+ trainable: Boolean. Whether the layer should be trained.
+ name: String. The name of the layer.
+ dtype: Default dtype of the layer's parameters (default of `None` means use
+ the type of the first input).
+
+ Read-only properties:
+ init_scale: See above.
+ filters: See above.
+ tail_mass: See above.
+ optimize_integer_offset: See above.
+ likelihood_bound: See above.
+ range_coder_precision: See above.
+ data_format: See above.
+ name: String. See above.
+ dtype: See above.
+ trainable_variables: List of trainable variables.
+ non_trainable_variables: List of non-trainable variables.
+ variables: List of all variables of this layer, trainable and non-trainable.
+ updates: List of update ops of this layer. Always contains exactly one
+ update op, which must be run once after the last training step, before
+ `compress` or `decompress` is used.
+ losses: List of losses added by this layer. Always contains exactly one
+ auxiliary loss, which must be added to the training loss.
+
+ Mutable properties:
+ trainable: Boolean. Whether the layer should be trained.
+ input_spec: Optional `InputSpec` object specifying the constraints on inputs
+ that can be accepted by the layer.
+ """
+
+ def __init__(self, init_scale=10, filters=(3, 3, 3), tail_mass=1e-9,
+ optimize_integer_offset=True, likelihood_bound=1e-9,
+ range_coder_precision=16, data_format="channels_last", **kwargs):
+ super(EntropyBottleneck, self).__init__(**kwargs)
+ self._init_scale = float(init_scale)
+ self._filters = tuple(int(f) for f in filters)
+ self._tail_mass = float(tail_mass)
+ if not 0 < self.tail_mass < 1:
+ raise ValueError(
+ "`tail_mass` must be between 0 and 1, got {}.".format(self.tail_mass))
+ self._optimize_integer_offset = bool(optimize_integer_offset)
+ self._likelihood_bound = float(likelihood_bound)
+ self._range_coder_precision = int(range_coder_precision)
+ self._data_format = data_format
+ self._channel_axis(2) # trigger ValueError early
+ self.input_spec = engine.InputSpec(min_ndim=2)
+
+ @property
+ def init_scale(self):
+ return self._init_scale
+
+ @property
+ def filters(self):
+ return self._filters
+
+ @property
+ def tail_mass(self):
+ return self._tail_mass
+
+ @property
+ def optimize_integer_offset(self):
+ return self._optimize_integer_offset
+
+ @property
+ def likelihood_bound(self):
+ return self._likelihood_bound
+
+ @property
+ def range_coder_precision(self):
+ return self._range_coder_precision
+
+ @property
+ def data_format(self):
+ return self._data_format
+
+ def _channel_axis(self, ndim):
+ try:
+ return {"channels_first": 1, "channels_last": ndim - 1}[self.data_format]
+ except KeyError:
+ raise ValueError("Unsupported `data_format` for {} layer: {}.".format(
+ self.__class__.__name__, self.data_format))
+
+ def _logits_cumulative(self, inputs, stop_gradient):
+ """Evaluate logits of the cumulative densities.
+
+ Args:
+ inputs: The values at which to evaluate the cumulative densities, expected
+ to be a `Tensor` of shape `(channels, 1, batch)`.
+ stop_gradient: Boolean. Whether to add `array_ops.stop_gradient` calls so
+ that the gradient of the output with respect to the density model
+ parameters is disconnected (the gradient with respect to `inputs` is
+ left untouched).
+
+ Returns:
+ A `Tensor` of the same shape as `inputs`, containing the logits of the
+ cumulative densities evaluated at the given inputs.
+ """
+ logits = inputs
+
+ for i in range(len(self.filters) + 1):
+ matrix = self._matrices[i]
+ if stop_gradient:
+ matrix = array_ops.stop_gradient(matrix)
+ logits = math_ops.matmul(matrix, logits)
+
+ bias = self._biases[i]
+ if stop_gradient:
+ bias = array_ops.stop_gradient(bias)
+ logits += bias
+
+ if i < len(self._factors):
+ factor = self._factors[i]
+ if stop_gradient:
+ factor = array_ops.stop_gradient(factor)
+ logits += factor * math_ops.tanh(logits)
+
+ return logits
+
+ def build(self, input_shape):
+ """Builds the layer.
+
+ Creates the variables for the network modeling the densities, creates the
+ auxiliary loss estimating the median and tail quantiles of the densities,
+ and then uses that to create the probability mass functions and the update
+ op that produces the discrete cumulative density functions used by the range
+ coder.
+
+ Args:
+ input_shape: Shape of the input tensor, used to get the number of
+ channels.
+
+ Raises:
+ ValueError: if `input_shape` doesn't specify the length of the channel
+ dimension.
+ """
+ input_shape = tensor_shape.TensorShape(input_shape)
+ channel_axis = self._channel_axis(input_shape.ndims)
+ channels = input_shape[channel_axis].value
+ if channels is None:
+ raise ValueError("The channel dimension of the inputs must be defined.")
+ self.input_spec = engine.InputSpec(
+ ndim=input_shape.ndims, axes={channel_axis: channels})
+ filters = (1,) + self.filters + (1,)
+ scale = self.init_scale ** (1 / (len(self.filters) + 1))
+
+ # Create variables.
+ self._matrices = []
+ self._biases = []
+ self._factors = []
+ for i in range(len(self.filters) + 1):
+ init = np.log(np.expm1(1 / scale / filters[i + 1]))
+ matrix = self.add_variable(
+ "matrix_{}".format(i), dtype=self.dtype,
+ shape=(channels, filters[i + 1], filters[i]),
+ initializer=init_ops.Constant(init))
+ matrix = nn.softplus(matrix)
+ self._matrices.append(matrix)
+
+ bias = self.add_variable(
+ "bias_{}".format(i), dtype=self.dtype,
+ shape=(channels, filters[i + 1], 1),
+ initializer=init_ops.RandomUniform(-.5, .5))
+ self._biases.append(bias)
+
+ if i < len(self.filters):
+ factor = self.add_variable(
+ "factor_{}".format(i), dtype=self.dtype,
+ shape=(channels, filters[i + 1], 1),
+ initializer=init_ops.Zeros())
+ factor = math_ops.tanh(factor)
+ self._factors.append(factor)
+
+ # To figure out what range of the densities to sample, we need to compute
+ # the quantiles given by `tail_mass / 2` and `1 - tail_mass / 2`. Since we
+ # can't take inverses of the cumulative directly, we make it an optimization
+ # problem:
+ # `quantiles = argmin(|logit(cumulative) - target|)`
+ # where `target` is `logit(tail_mass / 2)` or `logit(1 - tail_mass / 2)`.
+ # Taking the logit (inverse of sigmoid) of the cumulative makes the
+ # representation of the right target more numerically stable.
+
+ # Numerically stable way of computing logits of `tail_mass / 2`
+ # and `1 - tail_mass / 2`.
+ target = np.log(2 / self.tail_mass - 1)
+ # Compute lower and upper tail quantile as well as median.
+ target = constant_op.constant([-target, 0, target], dtype=self.dtype)
+
+ def quantiles_initializer(shape, dtype=None, partition_info=None):
+ del partition_info # unused
+ assert tuple(shape[1:]) == (1, 3)
+ init = constant_op.constant(
+ [[[-self.init_scale, 0, self.init_scale]]], dtype=dtype)
+ return array_ops.tile(init, (shape[0], 1, 1))
+
+ quantiles = self.add_variable(
+ "quantiles", shape=(channels, 1, 3), dtype=self.dtype,
+ initializer=quantiles_initializer)
+ logits = self._logits_cumulative(quantiles, stop_gradient=True)
+ loss = math_ops.reduce_sum(abs(logits - target))
+ self.add_loss(loss, inputs=None)
+
+ # Save medians for `call`, `compress`, and `decompress`.
+ self._medians = quantiles[:, :, 1:2]
+ if not self.optimize_integer_offset:
+ self._medians = math_ops.round(self._medians)
+
+ # Largest distance observed between lower tail quantile and median,
+ # or between median and upper tail quantile.
+ minima = math_ops.reduce_max(self._medians - quantiles[:, :, 0:1])
+ maxima = math_ops.reduce_max(quantiles[:, :, 2:3] - self._medians)
+ minmax = math_ops.maximum(minima, maxima)
+ minmax = math_ops.ceil(minmax)
+ minmax = math_ops.maximum(minmax, 1)
+
+ # Sample the density up to `minmax` around the median.
+ samples = math_ops.range(-minmax, minmax + 1, dtype=self.dtype)
+ samples += self._medians
+
+ half = constant_op.constant(.5, dtype=self.dtype)
+ # We strip the sigmoid from the end here, so we can use the special rule
+ # below to only compute differences in the left tail of the sigmoid.
+ # This increases numerical stability (see explanation in `call`).
+ lower = self._logits_cumulative(samples - half, stop_gradient=True)
+ upper = self._logits_cumulative(samples + half, stop_gradient=True)
+ # Flip signs if we can move more towards the left tail of the sigmoid.
+ sign = -math_ops.sign(math_ops.add_n([lower, upper]))
+ pmf = abs(math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower))
+ # Add tail masses to first and last bin of pmf, as we clip values for
+ # compression, meaning that out-of-range values get mapped to these bins.
+ pmf = array_ops.concat([
+ math_ops.add_n([pmf[:, 0, :1], math_ops.sigmoid(lower[:, 0, :1])]),
+ pmf[:, 0, 1:-1],
+ math_ops.add_n([pmf[:, 0, -1:], math_ops.sigmoid(-upper[:, 0, -1:])]),
+ ], axis=-1)
+ self._pmf = pmf
+
+ cdf = coder_ops.pmf_to_quantized_cdf(
+ pmf, precision=self.range_coder_precision)
+ def cdf_getter(*args, **kwargs):
+ del args, kwargs # ignored
+ return variable_scope.get_variable(
+ "quantized_cdf", dtype=dtypes.int32, initializer=cdf,
+ trainable=False, validate_shape=False, collections=())
+ # Need to provide a fake shape here since add_variable insists on it.
+ self._quantized_cdf = self.add_variable(
+ "quantized_cdf", shape=(channels, 1), dtype=dtypes.int32,
+ getter=cdf_getter, trainable=False)
+
+ update_op = state_ops.assign(
+ self._quantized_cdf, cdf, validate_shape=False)
+ self.add_update(update_op, inputs=None)
+
+ super(EntropyBottleneck, self).build(input_shape)
+
+ def call(self, inputs, training):
+ """Pass a tensor through the bottleneck.
+
+ Args:
+ inputs: The tensor to be passed through the bottleneck.
+ training: Boolean. If `True`, returns a differentiable approximation of
+ the inputs, and their likelihoods under the modeled probability
+ densities. If `False`, returns the quantized inputs and their
+ likelihoods under the corresponding probability mass function. These
+ quantities can't be used for training, as they are not differentiable,
+ but represent actual compression more closely.
+
+ Returns:
+ values: `Tensor` with the same shape as `inputs` containing the perturbed
+ or quantized input values.
+ likelihood: `Tensor` with the same shape as `inputs` containing the
+ likelihood of `values` under the modeled probability distributions.
+
+ Raises:
+      ValueError: if `inputs` has a different `dtype` or number of channels
+        than a previous set of inputs the model was invoked with.
+ """
+ inputs = ops.convert_to_tensor(inputs)
+ ndim = self.input_spec.ndim
+ channel_axis = self._channel_axis(ndim)
+ half = constant_op.constant(.5, dtype=self.dtype)
+
+ # Convert to (channels, 1, batch) format by commuting channels to front
+ # and then collapsing.
+ order = list(range(ndim))
+ order.pop(channel_axis)
+ order.insert(0, channel_axis)
+ values = array_ops.transpose(inputs, order)
+ shape = array_ops.shape(values)
+ values = array_ops.reshape(values, (shape[0], 1, -1))
+
+ # Add noise or quantize.
+ if training:
+ noise = random_ops.random_uniform(array_ops.shape(values), -half, half)
+ values = math_ops.add_n([values, noise])
+ elif self.optimize_integer_offset:
+ values = math_ops.round(values - self._medians) + self._medians
+ else:
+ values = math_ops.round(values)
+
+ # Evaluate densities.
+ # We can use the special rule below to only compute differences in the left
+ # tail of the sigmoid. This increases numerical stability: sigmoid(x) is 1
+ # for large x, 0 for small x. Subtracting two numbers close to 0 can be done
+ # with much higher precision than subtracting two numbers close to 1.
+ lower = self._logits_cumulative(values - half, stop_gradient=False)
+ upper = self._logits_cumulative(values + half, stop_gradient=False)
+ # Flip signs if we can move more towards the left tail of the sigmoid.
+ sign = -math_ops.sign(math_ops.add_n([lower, upper]))
+ sign = array_ops.stop_gradient(sign)
+ likelihood = abs(
+ math_ops.sigmoid(sign * upper) - math_ops.sigmoid(sign * lower))
+ if self.likelihood_bound > 0:
+ likelihood_bound = constant_op.constant(
+ self.likelihood_bound, dtype=self.dtype)
+ # TODO(jballe): Override gradients.
+ likelihood = math_ops.maximum(likelihood, likelihood_bound)
+
+ # Convert back to input tensor shape.
+ order = list(range(1, ndim))
+ order.insert(channel_axis, 0)
+ values = array_ops.reshape(values, shape)
+ values = array_ops.transpose(values, order)
+ likelihood = array_ops.reshape(likelihood, shape)
+ likelihood = array_ops.transpose(likelihood, order)
+
+ if not context.executing_eagerly():
+ values_shape, likelihood_shape = self.compute_output_shape(inputs.shape)
+ values.set_shape(values_shape)
+ likelihood.set_shape(likelihood_shape)
+
+ return values, likelihood
+
+ def compress(self, inputs):
+ """Compress inputs and store their binary representations into strings.
+
+ Args:
+ inputs: `Tensor` with values to be compressed.
+
+ Returns:
+ String `Tensor` vector containing the compressed representation of each
+ batch element of `inputs`.
+ """
+ with ops.name_scope(self._name_scope()):
+ inputs = ops.convert_to_tensor(inputs)
+ if not self.built:
+ # Check input assumptions set before layer building, e.g. input rank.
+ self._assert_input_compatibility(inputs)
+ if self.dtype is None:
+ self._dtype = inputs.dtype.base_dtype.name
+ self.build(inputs.shape)
+
+ # Check input assumptions set after layer building, e.g. input shape.
+ if not context.executing_eagerly():
+ self._assert_input_compatibility(inputs)
+
+ ndim = self.input_spec.ndim
+ channel_axis = self._channel_axis(ndim)
+ # Tuple of slices for expanding dimensions of tensors below.
+ slices = ndim * [None] + [slice(None)]
+ slices[channel_axis] = slice(None)
+ slices = tuple(slices)
+
+ # Expand dimensions of CDF to input dimensions, keeping the channels along
+ # the right dimension.
+ cdf = self._quantized_cdf[slices[1:]]
+ num_levels = array_ops.shape(cdf)[-1] - 1
+
+ # Bring inputs to the right range by centering the range on the medians.
+ half = constant_op.constant(.5, dtype=self.dtype)
+ medians = array_ops.squeeze(self._medians, [1, 2])
+ offsets = (math_ops.cast(num_levels // 2, self.dtype) + half) - medians
+ # Expand offsets to input dimensions and add to inputs.
+ values = inputs + offsets[slices[:-1]]
+
+ # Clip to range and cast to integers. Because we have added .5 above, and
+ # all values are positive, the cast effectively implements rounding.
+ values = math_ops.maximum(values, half)
+ values = math_ops.minimum(
+ values, math_ops.cast(num_levels, self.dtype) - half)
+ values = math_ops.cast(values, dtypes.int16)
+
+ def loop_body(tensor):
+ return coder_ops.range_encode(
+ tensor, cdf, precision=self.range_coder_precision)
+ strings = functional_ops.map_fn(
+ loop_body, values, dtype=dtypes.string, back_prop=False)
+
+ if not context.executing_eagerly():
+ strings.set_shape(inputs.shape[:1])
+
+ return strings
+
+ def decompress(self, strings, shape, channels=None):
+ """Decompress values from their compressed string representations.
+
+ Args:
+ strings: A string `Tensor` vector containing the compressed data.
+ shape: A `Tensor` vector of int32 type. Contains the shape of the tensor
+ to be decompressed, excluding the batch dimension.
+      channels: Integer. Specifies the number of channels statically. Only
+        needs to be set if the layer hasn't been built yet (i.e., this is the
+        first input it receives).
+
+ Returns:
+ The decompressed `Tensor`. Its shape will be equal to `shape` prepended
+ with the batch dimension from `strings`.
+
+ Raises:
+ ValueError: If the length of `shape` isn't available at graph construction
+ time.
+ """
+ with ops.name_scope(self._name_scope()):
+ strings = ops.convert_to_tensor(strings)
+ shape = ops.convert_to_tensor(shape)
+ if self.built:
+ ndim = self.input_spec.ndim
+ channel_axis = self._channel_axis(ndim)
+ if channels is None:
+ channels = self.input_spec.axes[channel_axis]
+ else:
+ if not (shape.shape.is_fully_defined() and shape.shape.ndims == 1):
+ raise ValueError("`shape` must be a vector with known length.")
+ ndim = shape.shape[0].value + 1
+ channel_axis = self._channel_axis(ndim)
+ input_shape = ndim * [None]
+ input_shape[channel_axis] = channels
+ self.build(input_shape)
+
+ # Tuple of slices for expanding dimensions of tensors below.
+ slices = ndim * [None] + [slice(None)]
+ slices[channel_axis] = slice(None)
+ slices = tuple(slices)
+
+ # Expand dimensions of CDF to input dimensions, keeping the channels along
+ # the right dimension.
+ cdf = self._quantized_cdf[slices[1:]]
+ num_levels = array_ops.shape(cdf)[-1] - 1
+
+ def loop_body(string):
+ return coder_ops.range_decode(
+ string, shape, cdf, precision=self.range_coder_precision)
+ outputs = functional_ops.map_fn(
+ loop_body, strings, dtype=dtypes.int16, back_prop=False)
+ outputs = math_ops.cast(outputs, self.dtype)
+
+ medians = array_ops.squeeze(self._medians, [1, 2])
+ offsets = math_ops.cast(num_levels // 2, self.dtype) - medians
+ outputs -= offsets[slices[:-1]]
+
+ if not context.executing_eagerly():
+ outputs_shape = ndim * [None]
+ outputs_shape[0] = strings.shape[0]
+ outputs_shape[channel_axis] = channels
+ outputs.set_shape(outputs_shape)
+
+ return outputs
+
+ def visualize(self):
+ """Multi-channel visualization of densities as images.
+
+    Creates and returns an image summary visualizing the current probability
+ density estimates. The image contains one row for each channel. Within each
+ row, the pixel intensities are proportional to probability values, and each
+ row is centered on the median of the corresponding distribution.
+
+ Returns:
+ The created image summary.
+ """
+ with ops.name_scope(self._name_scope()):
+ image = self._pmf
+ image *= 255 / math_ops.reduce_max(image, axis=1, keepdims=True)
+ image = math_ops.cast(image + .5, dtypes.uint8)
+ image = image[None, :, :, None]
+ return summary.image("pmf", image, max_outputs=1)
+
+ def compute_output_shape(self, input_shape):
+ input_shape = tensor_shape.TensorShape(input_shape)
+ return input_shape, input_shape
diff --git a/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py
new file mode 100644
index 0000000000..798b0234eb
--- /dev/null
+++ b/tensorflow/contrib/coder/python/layers/entropybottleneck_test.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests of EntropyBottleneck class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.coder.python.layers import entropybottleneck
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+class EntropyBottleneckTest(test.TestCase):
+
+ def test_noise(self):
+ # Tests that the noise added is uniform noise between -0.5 and 0.5.
+ inputs = array_ops.placeholder(dtypes.float32, (None, 1))
+ layer = entropybottleneck.EntropyBottleneck()
+ noisy, _ = layer(inputs, training=True)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ values = np.linspace(-50, 50, 100)[:, None]
+ noisy, = sess.run([noisy], {inputs: values})
+ self.assertFalse(np.allclose(values, noisy, rtol=0, atol=.49))
+ self.assertAllClose(values, noisy, rtol=0, atol=.5)
+
+ def test_quantization(self):
+ # Tests that inputs are quantized to full integer values, even after
+ # quantiles have been updated.
+ inputs = array_ops.placeholder(dtypes.float32, (None, 1))
+ layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=False)
+ quantized, _ = layer(inputs, training=False)
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
+ self.assertTrue(len(layer.losses) == 1)
+ step = opt.minimize(layer.losses[0])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(step)
+ values = np.linspace(-50, 50, 100)[:, None]
+ quantized, = sess.run([quantized], {inputs: values})
+ self.assertAllClose(np.around(values), quantized, rtol=0, atol=1e-6)
+
+ def test_quantization_optimized_offset(self):
+ # Tests that inputs are not quantized to full integer values after quantiles
+ # have been updated. However, the difference between input and output should
+ # be between -0.5 and 0.5, and the offset must be consistent.
+ inputs = array_ops.placeholder(dtypes.float32, (None, 1))
+ layer = entropybottleneck.EntropyBottleneck(optimize_integer_offset=True)
+ quantized, _ = layer(inputs, training=False)
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
+ self.assertTrue(len(layer.losses) == 1)
+ step = opt.minimize(layer.losses[0])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(step)
+ values = np.linspace(-50, 50, 100)[:, None]
+ quantized, = sess.run([quantized], {inputs: values})
+ self.assertAllClose(values, quantized, rtol=0, atol=.5)
+ diff = np.ravel(np.around(values) - quantized) % 1
+ self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6)
+ self.assertNotEqual(diff[0], 0)
+
+ def test_codec(self):
+ # Tests that inputs are compressed and decompressed correctly, and quantized
+ # to full integer values, even after quantiles have been updated.
+ inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_last", init_scale=60,
+ optimize_integer_offset=False)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
+ self.assertTrue(len(layer.losses) == 1)
+ step = opt.minimize(layer.losses[0])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(step)
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = np.linspace(-50, 50, 100)[None, :, None]
+ decoded, = sess.run([decoded], {inputs: values})
+ self.assertAllClose(np.around(values), decoded, rtol=0, atol=1e-6)
+
+ def test_codec_optimized_offset(self):
+ # Tests that inputs are compressed and decompressed correctly, and not
+ # quantized to full integer values after quantiles have been updated.
+ # However, the difference between input and output should be between -0.5
+ # and 0.5, and the offset must be consistent.
+ inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_last", init_scale=60,
+ optimize_integer_offset=True)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ opt = gradient_descent.GradientDescentOptimizer(learning_rate=1)
+ self.assertTrue(len(layer.losses) == 1)
+ step = opt.minimize(layer.losses[0])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run(step)
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = np.linspace(-50, 50, 100)[None, :, None]
+ decoded, = sess.run([decoded], {inputs: values})
+ self.assertAllClose(values, decoded, rtol=0, atol=.5)
+ diff = np.ravel(np.around(values) - decoded) % 1
+ self.assertAllClose(diff, np.full_like(diff, diff[0]), rtol=0, atol=5e-6)
+ self.assertNotEqual(diff[0], 0)
+
+ def test_codec_clipping(self):
+ # Tests that inputs are compressed and decompressed correctly, and clipped
+ # to the expected range.
+ inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_last", init_scale=40)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = np.linspace(-50, 50, 100)[None, :, None]
+ decoded, = sess.run([decoded], {inputs: values})
+ expected = np.clip(np.around(values), -40, 40)
+ self.assertAllClose(expected, decoded, rtol=0, atol=1e-6)
+
+ def test_channels_last(self):
+ # Test the layer with more than one channel and multiple input dimensions,
+ # with the channels in the last dimension.
+ inputs = array_ops.placeholder(dtypes.float32, (None, None, None, 2))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_last", init_scale=50)
+ noisy, _ = layer(inputs, training=True)
+ quantized, _ = layer(inputs, training=False)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = 5 * np.random.normal(size=(7, 5, 3, 2))
+ noisy, quantized, decoded = sess.run(
+ [noisy, quantized, decoded], {inputs: values})
+ self.assertAllClose(values, noisy, rtol=0, atol=.5)
+ self.assertAllClose(values, quantized, rtol=0, atol=.5)
+ self.assertAllClose(values, decoded, rtol=0, atol=.5)
+
+ def test_channels_first(self):
+ # Test the layer with more than one channel and multiple input dimensions,
+ # with the channel dimension right after the batch dimension.
+ inputs = array_ops.placeholder(dtypes.float32, (None, 3, None, None))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_first", init_scale=50)
+ noisy, _ = layer(inputs, training=True)
+ quantized, _ = layer(inputs, training=False)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = 5 * np.random.normal(size=(2, 3, 5, 7))
+ noisy, quantized, decoded = sess.run(
+ [noisy, quantized, decoded], {inputs: values})
+ self.assertAllClose(values, noisy, rtol=0, atol=.5)
+ self.assertAllClose(values, quantized, rtol=0, atol=.5)
+ self.assertAllClose(values, decoded, rtol=0, atol=.5)
+
+ def test_compress(self):
+ # Test compression and decompression, and produce test data for
+ # `test_decompress`. If you set the constant at the end to `True`, this test
+ # will fail and the log will contain the new test data.
+ inputs = array_ops.placeholder(dtypes.float32, (2, 3, 10))
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_first", filters=(), init_scale=2)
+ bitstrings = layer.compress(inputs)
+ decoded = layer.decompress(bitstrings, array_ops.shape(inputs)[1:])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ values = 5 * np.random.uniform(size=(2, 3, 10)) - 2.5
+ bitstrings, quantized_cdf, decoded = sess.run(
+ [bitstrings, layer._quantized_cdf, decoded], {inputs: values})
+ self.assertAllClose(values, decoded, rtol=0, atol=.5)
+ # Set this constant to `True` to log new test data for `test_decompress`.
+ if False: # pylint:disable=using-constant-test
+ assert False, (bitstrings, quantized_cdf, decoded)
+
+ # Data generated by `test_compress`.
+ # pylint:disable=g-inconsistent-quotes,bad-whitespace
+ bitstrings = np.array([
+ b'\x1e\xbag}\xc2\xdaN\x8b\xbd.',
+ b'\x8dF\xf0%\x1cv\xccllW'
+ ], dtype=object)
+
+ quantized_cdf = np.array([
+ [ 0, 15636, 22324, 30145, 38278, 65536],
+ [ 0, 19482, 26927, 35052, 42904, 65535],
+ [ 0, 21093, 28769, 36919, 44578, 65536]
+ ], dtype=np.int32)
+
+ expected = np.array([
+ [[-2., 1., 0., -2., -1., -2., -2., -2., 2., -1.],
+ [ 1., 2., 1., 0., -2., -2., 1., 2., 0., 1.],
+ [ 2., 0., -2., 2., 0., -1., -2., 0., 2., 0.]],
+ [[ 1., 2., 0., -1., 1., 2., 1., 1., 2., -2.],
+ [ 2., -1., -1., 0., -1., 2., 0., 2., -2., 2.],
+ [ 2., -2., -2., -1., -2., 1., -2., 0., 0., 0.]]
+ ], dtype=np.float32)
+ # pylint:enable=g-inconsistent-quotes,bad-whitespace
+
+ def test_decompress(self):
+ # Test that decompression of values compressed with a previous version
+ # works, i.e. that the file format doesn't change across revisions.
+ bitstrings = array_ops.placeholder(dtypes.string)
+ input_shape = array_ops.placeholder(dtypes.int32)
+ quantized_cdf = array_ops.placeholder(dtypes.int32)
+ layer = entropybottleneck.EntropyBottleneck(
+ data_format="channels_first", filters=(), dtype=dtypes.float32)
+ layer.build(self.expected.shape)
+ layer._quantized_cdf = quantized_cdf
+ decoded = layer.decompress(bitstrings, input_shape[1:])
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ decoded, = sess.run([decoded], {
+ bitstrings: self.bitstrings, input_shape: self.expected.shape,
+ quantized_cdf: self.quantized_cdf})
+ self.assertAllClose(self.expected, decoded, rtol=0, atol=1e-6)
+
+ def test_build_decompress(self):
+ # Test that layer can be built when `decompress` is the first call to it.
+ bitstrings = array_ops.placeholder(dtypes.string)
+ input_shape = array_ops.placeholder(dtypes.int32, shape=[3])
+ layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
+ layer.decompress(bitstrings, input_shape[1:], channels=5)
+ self.assertTrue(layer.built)
+
+ def test_pmf_normalization(self):
+ # Test that probability mass functions are normalized correctly.
+ layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
+ layer.build((None, 10))
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ pmf, = sess.run([layer._pmf])
+ self.assertAllClose(np.ones(10), np.sum(pmf, axis=-1), rtol=0, atol=1e-6)
+
+ def test_visualize(self):
+ # Test that summary op can be constructed.
+ layer = entropybottleneck.EntropyBottleneck(dtype=dtypes.float32)
+ layer.build((None, 10))
+ summary = layer.visualize()
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ sess.run([summary])
+
+ def test_normalization(self):
+ # Test that densities are normalized correctly.
+ inputs = array_ops.placeholder(dtypes.float32, (None, 1))
+ layer = entropybottleneck.EntropyBottleneck(filters=(2,))
+ _, likelihood = layer(inputs, training=True)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ x = np.repeat(np.arange(-200, 201), 1000)[:, None]
+ likelihood, = sess.run([likelihood], {inputs: x})
+ self.assertEqual(x.shape, likelihood.shape)
+ integral = np.sum(likelihood) * .001
+ self.assertAllClose(1, integral, rtol=0, atol=1e-4)
+
+ def test_entropy_estimates(self):
+ # Test that entropy estimates match actual range coding.
+ inputs = array_ops.placeholder(dtypes.float32, (1, None, 1))
+ layer = entropybottleneck.EntropyBottleneck(
+ filters=(2, 3), data_format="channels_last")
+ _, likelihood = layer(inputs, training=True)
+ diff_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2)
+ _, likelihood = layer(inputs, training=False)
+ disc_entropy = math_ops.reduce_sum(math_ops.log(likelihood)) / -np.log(2)
+ bitstrings = layer.compress(inputs)
+ with self.test_session() as sess:
+ sess.run(variables.global_variables_initializer())
+ self.assertTrue(len(layer.updates) == 1)
+ sess.run(layer.updates[0])
+ diff_entropy, disc_entropy, bitstrings = sess.run(
+ [diff_entropy, disc_entropy, bitstrings],
+ {inputs: np.random.normal(size=(1, 10000, 1))})
+ codelength = 8 * sum(len(bitstring) for bitstring in bitstrings)
+ self.assertAllClose(diff_entropy, disc_entropy, rtol=5e-3, atol=0)
+ self.assertAllClose(disc_entropy, codelength, rtol=5e-3, atol=0)
+ self.assertGreater(codelength, disc_entropy)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/BUILD b/tensorflow/contrib/constrained_optimization/BUILD
new file mode 100644
index 0000000000..619153df67
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/BUILD
@@ -0,0 +1,91 @@
+package(default_visibility = ["//visibility:public"])
+
+licenses(["notice"]) # Apache 2.0
+
+exports_files(["LICENSE"])
+
+load("//tensorflow:tensorflow.bzl", "py_test")
+
+# Transitive dependencies of this target will be included in the pip package.
+py_library(
+ name = "constrained_optimization_pip",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ ],
+)
+
+py_library(
+ name = "constrained_optimization",
+ srcs = [
+ "__init__.py",
+ "python/candidates.py",
+ "python/constrained_minimization_problem.py",
+ "python/constrained_optimizer.py",
+ "python/external_regret_optimizer.py",
+ "python/swap_regret_optimizer.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:state_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ "@six_archive//:six",
+ ],
+)
+
+py_test(
+ name = "candidates_test",
+ srcs = ["python/candidates_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ "//tensorflow/python:client_testlib",
+ "//third_party/py/numpy",
+ ],
+)
+
+# NOTE: This library can't be "testonly" since it needs to be included in the
+# pip package.
+py_library(
+ name = "test_util",
+ srcs = ["python/test_util.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:standard_ops",
+ ],
+)
+
+py_test(
+ name = "external_regret_optimizer_test",
+ srcs = ["python/external_regret_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
+ name = "swap_regret_optimizer_test",
+ srcs = ["python/swap_regret_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":constrained_optimization",
+ ":test_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:standard_ops",
+ "//tensorflow/python:training",
+ "//third_party/py/numpy",
+ ],
+)
diff --git a/tensorflow/contrib/constrained_optimization/README.md b/tensorflow/contrib/constrained_optimization/README.md
new file mode 100644
index 0000000000..c65a150464
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/README.md
@@ -0,0 +1,345 @@
+<!-- TODO(acotter): Add usage example of non-convex optimization and stochastic classification. -->
+
+# ConstrainedOptimization (TFCO)
+
+TFCO is a library for optimizing inequality-constrained problems in TensorFlow.
+Both the objective function and the constraints are represented as Tensors,
+giving users the maximum amount of flexibility in specifying their optimization
+problems.
+
+This flexibility makes optimization considerably more difficult: on a non-convex
+problem, if one uses the "standard" approach of introducing a Lagrange
+multiplier for each constraint, and then jointly maximizing over the Lagrange
+multipliers and minimizing over the model parameters, then a stable stationary
+point might not even *exist*. Hence, in some cases, oscillation, instead of
+convergence, is inevitable.
+
+Thankfully, it turns out that even if, over the course of optimization, no
+*particular* iterate does a good job of minimizing the objective while
+satisfying the constraints, the *sequence* of iterates, on average, usually
+will. This observation suggests the following approach: at training time, we'll
+periodically snapshot the model state during optimization; then, at evaluation
+time, each time we're given a new example to evaluate, we'll sample one of the
+saved snapshots uniformly at random, and apply it to the example. This
+*stochastic model* will generally perform well, both with respect to the
+objective function, and the constraints.
+
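+As a rough illustration (not part of this library), evaluation of such a
+stochastic model for a linear classifier could look like the following sketch,
+where `snapshot_weights` and `snapshot_thresholds` are hypothetical lists of
+parameters saved periodically during training:
+
+```python
+import numpy as np
+
+def stochastic_predict(snapshot_weights, snapshot_thresholds, features):
+  """Applies a uniformly sampled snapshot to each evaluation example."""
+  num_snapshots = len(snapshot_weights)
+  predictions = np.zeros(len(features))
+  for i, x in enumerate(features):
+    # Sample one of the saved snapshots uniformly at random for this example.
+    j = np.random.randint(num_snapshots)
+    predictions[i] = np.dot(snapshot_weights[j], x) - snapshot_thresholds[j]
+  return predictions
+```
+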
+In fact, we can do better: it's possible to post-process the set of snapshots to
+find a distribution over at most $$m+1$$ snapshots, where $$m$$ is the number of
+constraints, that will be at least as good (and will usually be much better)
+than the (much larger) uniform distribution described above. If you're unable or
+unwilling to use a stochastic model at all, then you can instead use a heuristic
+to choose the single best snapshot.
+
+For full details, motivation, and theoretical results on the approach taken by
+this library, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+which will be referred to as [CoJiSr18] throughout the remainder of this
+document.
+
+### Proxy Constraints
+
+Imagine that we want to constrain the recall of a binary classifier to be at
+least 90%. Since the recall is proportional to the number of true positive
+classifications, which itself is a sum of indicator functions, this constraint
+is non-differentiable, and therefore cannot be used in a problem that will be
+optimized using a (stochastic) gradient-based algorithm.
+
+For this and similar problems, TFCO supports so-called *proxy constraints*,
+which are (at least semi-differentiable) approximations of the original
+constraints. For example, one could create a proxy recall function by replacing
+the indicator functions with sigmoids. During optimization, each proxy
+constraint function will be penalized, with the magnitude of the penalty being
+chosen to satisfy the corresponding *original* (non-proxy) constraint.
+
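+As a hypothetical illustration (the worked example later in this document uses
+a hinge approximation instead), an exact recall and a sigmoid-based proxy
+recall could be written as:
+
+```python
+import tensorflow as tf
+
+def exact_recall(labels, predictions):
+  # Counts true positives with a (non-differentiable) indicator function.
+  true_positives = labels * tf.to_float(predictions > 0)
+  return tf.reduce_sum(true_positives) / tf.reduce_sum(labels)
+
+def proxy_recall(labels, predictions):
+  # Replaces the indicator with a sigmoid, giving a differentiable surrogate.
+  true_positives = labels * tf.sigmoid(predictions)
+  return tf.reduce_sum(true_positives) / tf.reduce_sum(labels)
+```
+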
+On a problem including proxy constraints&mdash;even a convex problem&mdash;the
+Lagrangian approach discussed above isn't guaranteed to work. However, a
+different algorithm, based on minimizing *swap regret*, does work. Aside from
+this difference, the recommended procedure for optimizing a proxy-constrained
+problem remains the same: periodically snapshot the model during optimization,
+and then either find the best $$m+1$$-sized distribution, or heuristically
+choose the single best snapshot.
+
+## Components
+
+* [constrained_minimization_problem](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py):
+ contains the `ConstrainedMinimizationProblem` interface. Your own
+ constrained optimization problems should be represented using
+ implementations of this interface.
+
+* [constrained_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py):
+ contains the `ConstrainedOptimizer` interface, which is similar to (but
+ different from) `tf.train.Optimizer`, with the main difference being that
+ `ConstrainedOptimizer`s are given `ConstrainedMinimizationProblem`s to
+ optimize, and perform constrained optimization.
+
+ * [external_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py):
+ contains the `AdditiveExternalRegretOptimizer` implementation, which is
+ a `ConstrainedOptimizer` implementing the Lagrangian approach discussed
+ above (with additive updates to the Lagrange multipliers). You should
+ use this optimizer for problems *without* proxy constraints. It may also
+ work for problems with proxy constraints, but we recommend using a swap
+ regret optimizer, instead.
+
+ This optimizer is most similar to Algorithm 3 in Appendix C.3 of
+ [CoJiSr18], and is discussed in Section 3. The two differences are that
+ it uses proxy constraints (if they're provided) in the update of the
+ model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates.
+
+ * [swap_regret_optimizer](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py):
+ contains the `AdditiveSwapRegretOptimizer` and
+ `MultiplicativeSwapRegretOptimizer` implementations, which are
+ `ConstrainedOptimizer`s implementing the swap-regret minimization
+ approach mentioned above (with additive or multiplicative updates,
+ respectively, to the parameters associated with the
+ constraints&mdash;these parameters are not Lagrange multipliers, but
+ play a similar role). You should use one of these optimizers (we suggest
+ `MultiplicativeSwapRegretOptimizer`) for problems *with* proxy
+ constraints.
+
+ The `MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2
+ in Section 4 of [CoJiSr18], with the difference being that it uses
+ `tf.train.Optimizer`s, instead of SGD, for the "inner" updates. The
+ `AdditiveSwapRegretOptimizer` differs further in that it performs
+ additive (instead of multiplicative) updates of the stochastic matrix.
+
+* [candidates](https://www.tensorflow.org/code/tensorflow/contrib/constrained_optimization/python/candidates.py):
+ contains two functions, `find_best_candidate_distribution` and
+ `find_best_candidate_index`. Both of these functions are given a set of
+ candidate solutions to a constrained optimization problem, from which the
+ former finds the best distribution over at most $$m+1$$ candidates, and the
+ latter heuristically finds the single best candidate. As discussed above,
+ the set of candidates will typically be model snapshots saved periodically
+    during optimization. Both of these functions require that scipy be
+    installed (a short usage sketch follows this list).
+
+ The `find_best_candidate_distribution` function implements the approach
+ described in Lemma 3 of [CoJiSr18], while `find_best_candidate_index`
+ implements the heuristic used for hyperparameter search in the experiments
+ of Section 5.2.
+
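+As a rough sketch, these functions could be applied to snapshot evaluations as
+follows; the objective values and constraint violations below are made-up
+numpy arrays standing in for evaluations of four saved snapshots under two
+constraints:
+
+```python
+import numpy as np
+import tensorflow as tf
+
+tfco = tf.contrib.constrained_optimization
+
+# Hypothetical evaluations of n=4 snapshots: one objective value per snapshot,
+# and one row per constraint (m=2) containing the violations g_i(w_j).
+objective_vector = np.array([0.7, 0.5, 0.6, 0.4])
+constraints_matrix = np.array([[-0.1, 0.2, -0.2, 0.3],
+                               [0.0, -0.1, 0.1, -0.2]])
+
+# A distribution over at most m+1 = 3 snapshots (requires scipy).
+distribution = tfco.find_best_candidate_distribution(
+    objective_vector, constraints_matrix)
+# The index of the single best snapshot, chosen heuristically.
+index = tfco.find_best_candidate_index(objective_vector, constraints_matrix)
+```
+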
+## Convex Example with Proxy Constraints
+
+This is a simple example of recall-constrained optimization on simulated data:
+we will try to find a classifier that minimizes the average hinge loss while
+constraining recall to be at least 90%.
+
+We'll start with the required imports&mdash;notice the definition of `tfco`:
+
+```python
+import math
+import numpy as np
+import tensorflow as tf
+
+tfco = tf.contrib.constrained_optimization
+```
+
+We'll now create an implementation of the `ConstrainedMinimizationProblem` class
+for this problem. The constructor takes three parameters: a Tensor containing
+the classification labels (0 or 1) for every training example, another Tensor
+containing the model's predictions on every training example (sometimes called
+the "logits"), and the lower bound on recall that will be enforced using a
+constraint.
+
+This implementation will contain both constraints *and* proxy constraints: the
+former represents the constraint that the true recall (defined in terms of the
+*number* of true positives) be at least `recall_lower_bound`, while the latter
+represents the same constraint, but on a hinge approximation of the recall.
+
+```python
+class ExampleProblem(tfco.ConstrainedMinimizationProblem):
+
+ def __init__(self, labels, predictions, recall_lower_bound):
+ self._labels = labels
+ self._predictions = predictions
+ self._recall_lower_bound = recall_lower_bound
+ # The number of positively-labeled examples.
+ self._positive_count = tf.reduce_sum(self._labels)
+
+ @property
+ def objective(self):
+ return tf.losses.hinge_loss(labels=self._labels, logits=self._predictions)
+
+ @property
+ def constraints(self):
+ true_positives = self._labels * tf.to_float(self._predictions > 0)
+ true_positive_count = tf.reduce_sum(true_positives)
+ recall = true_positive_count / self._positive_count
+ # The constraint is (recall >= self._recall_lower_bound), which we convert
+ # to (self._recall_lower_bound - recall <= 0) because
+ # ConstrainedMinimizationProblems must always provide their constraints in
+ # the form (tensor <= 0).
+ #
+ # The result of this function should be a tensor, with each element being
+ # a quantity that is constrained to be nonpositive. We only have one
+ # constraint, so we return a one-element tensor.
+ return self._recall_lower_bound - recall
+
+ @property
+ def proxy_constraints(self):
+ # Use 1 - hinge since we're SUBTRACTING recall in the constraint function,
+ # and we want the proxy constraint function to be convex.
+ true_positives = self._labels * tf.minimum(1.0, self._predictions)
+ true_positive_count = tf.reduce_sum(true_positives)
+ recall = true_positive_count / self._positive_count
+ # Please see the corresponding comment in the constraints property.
+ return self._recall_lower_bound - recall
+```
+
+We'll now create a simple simulated dataset by sampling 1000 random
+10-dimensional feature vectors from a Gaussian, finding their labels using a
+random "ground truth" linear model, and then adding noise by randomly flipping
+200 labels.
+
+```python
+# Create a simulated 10-dimensional training dataset consisting of 1000 labeled
+# examples, of which 800 are labeled correctly and 200 are mislabeled.
+num_examples = 1000
+num_mislabeled_examples = 200
+dimension = 10
+# We will constrain the recall to be at least 90%.
+recall_lower_bound = 0.9
+
+# Create random "ground truth" parameters to a linear model.
+ground_truth_weights = np.random.normal(size=dimension) / math.sqrt(dimension)
+ground_truth_threshold = 0
+
+# Generate a random set of features for each example.
+features = np.random.normal(size=(num_examples, dimension)).astype(
+ np.float32) / math.sqrt(dimension)
+# Compute the labels from these features given the ground truth linear model.
+labels = (np.matmul(features, ground_truth_weights) >
+ ground_truth_threshold).astype(np.float32)
+# Add noise by randomly flipping num_mislabeled_examples labels.
+mislabeled_indices = np.random.choice(
+ num_examples, num_mislabeled_examples, replace=False)
+labels[mislabeled_indices] = 1 - labels[mislabeled_indices]
+```
+
+We're now ready to construct our model, and the corresponding optimization
+problem. We'll use a linear model of the form $$f(x) = w^T x - t$$, where $$w$$
+is the `weights`, and $$t$$ is the `threshold`. The `problem` variable will hold
+an instance of the `ExampleProblem` class we created earlier.
+
+```python
+# Create variables containing the model parameters.
+weights = tf.Variable(tf.zeros(dimension), dtype=tf.float32, name="weights")
+threshold = tf.Variable(0.0, dtype=tf.float32, name="threshold")
+
+# Create the optimization problem.
+constant_labels = tf.constant(labels, dtype=tf.float32)
+constant_features = tf.constant(features, dtype=tf.float32)
+predictions = tf.tensordot(constant_features, weights, axes=(1, 0)) - threshold
+problem = ExampleProblem(
+ labels=constant_labels,
+ predictions=predictions,
+ recall_lower_bound=recall_lower_bound,
+)
+```
+
+We're almost ready to train our model, but first we'll create a couple of
+functions to measure its performance. We're interested in two quantities: the
+average hinge loss (which we seek to minimize), and the recall (which we
+constrain).
+
+```python
+def average_hinge_loss(labels, predictions):
+ num_examples, = np.shape(labels)
+ signed_labels = (labels * 2) - 1
+ total_hinge_loss = np.sum(np.maximum(0.0, 1.0 - signed_labels * predictions))
+ return total_hinge_loss / num_examples
+
+def recall(labels, predictions):
+ positive_count = np.sum(labels)
+ true_positives = labels * (predictions > 0)
+ true_positive_count = np.sum(true_positives)
+ return true_positive_count / positive_count
+```
+
+As was mentioned earlier, external regret optimizers suffice for problems
+without proxy constraints, but swap regret optimizers are recommended for
+problems *with* proxy constraints. Since this problem contains proxy
+constraints, we use the `MultiplicativeSwapRegretOptimizer`.
+
+For this problem, the constraint is fairly easy to satisfy, so we can use the
+same "inner" optimizer (an `AdagradOptimizer` with a learning rate of 1) for
+optimization of both the model parameters (`weights` and `threshold`), and the
+internal parameters associated with the constraints (these are the analogues of
+the Lagrange multipliers used by the `MultiplicativeSwapRegretOptimizer`). For
+more difficult problems, it will often be necessary to use different optimizers,
+with different learning rates (presumably found via a hyperparameter search): to
+accomplish this, pass *both* the `optimizer` and `constraint_optimizer`
+parameters to `MultiplicativeSwapRegretOptimizer`'s constructor.
+
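+Such a constructor call might look like the following sketch (the learning
+rates here are placeholders, not recommendations):
+
+```python
+optimizer = tfco.MultiplicativeSwapRegretOptimizer(
+    optimizer=tf.train.AdagradOptimizer(learning_rate=1.0),
+    constraint_optimizer=tf.train.AdagradOptimizer(learning_rate=0.1))
+```
+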
+Since this is a convex problem (both the objective and proxy constraint
+functions are convex), we can just take the last iterate. Periodic snapshotting,
+and the use of the `find_best_candidate_distribution` or
+`find_best_candidate_index` functions, is generally only necessary for
+non-convex problems (and even then, it isn't *always* necessary).
+
+```python
+with tf.Session() as session:
+ optimizer = tfco.MultiplicativeSwapRegretOptimizer(
+ optimizer=tf.train.AdagradOptimizer(learning_rate=1.0))
+ train_op = optimizer.minimize(problem)
+
+ session.run(tf.global_variables_initializer())
+  for ii in range(1000):
+ session.run(train_op)
+
+ trained_weights, trained_threshold = session.run((weights, threshold))
+
+trained_predictions = np.matmul(features, trained_weights) - trained_threshold
+print("Constrained average hinge loss = %f" % average_hinge_loss(
+ labels, trained_predictions))
+print("Constrained recall = %f" % recall(labels, trained_predictions))
+```
+
+Running the above code gives the following output (due to the randomness of the
+dataset, you'll get a different result when you run it):
+
+```none
+Constrained average hinge loss = 0.710019
+Constrained recall = 0.899811
+```
+
+As we hoped, the recall is extremely close to 90%&mdash;and, thanks to the use
+of proxy constraints, this is the *true* recall, not a hinge approximation.
+
+For comparison, let's try optimizing the same problem *without* the recall
+constraint:
+
+```python
+with tf.Session() as session:
+ optimizer = tf.train.AdagradOptimizer(learning_rate=1.0)
+ # For optimizing the unconstrained problem, we just minimize the "objective"
+ # portion of the minimization problem.
+ train_op = optimizer.minimize(problem.objective)
+
+ session.run(tf.global_variables_initializer())
+  for ii in range(1000):
+ session.run(train_op)
+
+ trained_weights, trained_threshold = session.run((weights, threshold))
+
+trained_predictions = np.matmul(features, trained_weights) - trained_threshold
+print("Unconstrained average hinge loss = %f" % average_hinge_loss(
+ labels, trained_predictions))
+print("Unconstrained recall = %f" % recall(labels, trained_predictions))
+```
+
+This code gives the following output (again, you'll get a different answer,
+since the dataset is random):
+
+```none
+Unconstrained average hinge loss = 0.627271
+Unconstrained recall = 0.793951
+```
+
+Because there is no constraint, the unconstrained problem does a better job of
+minimizing the average hinge loss, but naturally doesn't approach 90% recall.
diff --git a/tensorflow/contrib/constrained_optimization/__init__.py b/tensorflow/contrib/constrained_optimization/__init__.py
new file mode 100644
index 0000000000..1e49ba9f17
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/__init__.py
@@ -0,0 +1,41 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""A library for performing constrained optimization in TensorFlow."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+# pylint: disable=wildcard-import
+from tensorflow.contrib.constrained_optimization.python.candidates import *
+from tensorflow.contrib.constrained_optimization.python.constrained_minimization_problem import *
+from tensorflow.contrib.constrained_optimization.python.constrained_optimizer import *
+from tensorflow.contrib.constrained_optimization.python.external_regret_optimizer import *
+from tensorflow.contrib.constrained_optimization.python.swap_regret_optimizer import *
+# pylint: enable=wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+ "AdditiveExternalRegretOptimizer",
+ "AdditiveSwapRegretOptimizer",
+ "ConstrainedMinimizationProblem",
+ "ConstrainedOptimizer",
+ "find_best_candidate_distribution",
+ "find_best_candidate_index",
+ "MultiplicativeSwapRegretOptimizer",
+]
+
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates.py b/tensorflow/contrib/constrained_optimization/python/candidates.py
new file mode 100644
index 0000000000..ac86a6741b
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/candidates.py
@@ -0,0 +1,319 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Code for optimizing over a set of candidate solutions.
+
+The functions in this file deal with the constrained problem:
+
+> minimize f(w)
+> s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+Here, f(w) is the "objective function", and g_i(w) is the ith (of m) "constraint
+function". Given the values of the objective and constraint functions for a set
+of n "candidate solutions" {w_0,w_1,...,w_{n-1}} (for a total of n objective
+function values, and n*m constraint function values), the
+`find_best_candidate_distribution` function finds the best DISTRIBUTION over
+these candidates, while `find_best_candidate_index` heuristically finds the
+single best candidate.
+
+Both of these functions have dependencies on `scipy`, so if you want to call
+them, then you must make sure that `scipy` is available. The imports are
+performed inside the functions themselves, so if they're not actually called,
+then `scipy` is not needed.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The `find_best_candidate_distribution` function implements the approach
+described in Lemma 3, while `find_best_candidate_index` implements the heuristic
+used for hyperparameter search in the experiments of Section 5.2.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+from six.moves import xrange # pylint: disable=redefined-builtin
+
+
+def _find_best_candidate_distribution_helper(objective_vector,
+ constraints_matrix,
+ maximum_violation=0.0):
+ """Finds a distribution minimizing an objective subject to constraints.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n
+ candidates that, in expectation, minimizes the objective while violating
+ the constraints by no more than `maximum_violation`. If no such distribution
+ exists, it returns an error (using Go-style error reporting).
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ This function will return a distribution for which at most m+1 probabilities,
+ and often fewer, are nonzero.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ maximum_violation: nonnegative float, the maximum amount by which any
+ constraint may be violated, in expectation.
+
+ Returns:
+ A pair (`result`, `message`), exactly one of which is None. If `message` is
+ None, then the `result` contains the optimal distribution as a numpy array
+ of shape (n,). If `result` is None, then `message` contains an error
+ message.
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes, or if `maximum_violation` is negative.
+ ImportError: If we're unable to import `scipy.optimize`.
+ """
+ if maximum_violation < 0.0:
+ raise ValueError("maximum_violation must be nonnegative")
+
+ mm, nn = np.shape(constraints_matrix)
+ if (nn,) != np.shape(objective_vector):
+ raise ValueError(
+ "objective_vector must have shape (n,), and constraints_matrix (m, n),"
+ " where n is the number of candidates, and m is the number of "
+ "constraints")
+
+ # We import scipy inline, instead of at the top of the file, so that a scipy
+ # dependency is only introduced if either find_best_candidate_distribution()
+ # or find_best_candidate_index() are actually called.
+ import scipy.optimize # pylint: disable=g-import-not-at-top
+
+ # Feasibility (within maximum_violation) constraints.
+ a_ub = constraints_matrix
+ b_ub = np.full((mm, 1), maximum_violation)
+ # Sum-to-one constraint.
+ a_eq = np.ones((1, nn))
+ b_eq = np.ones((1, 1))
+ # Nonnegativity constraints.
+ bounds = (0, None)
+
+ result = scipy.optimize.linprog(
+ objective_vector,
+ A_ub=a_ub,
+ b_ub=b_ub,
+ A_eq=a_eq,
+ b_eq=b_eq,
+ bounds=bounds)
+ # Go-style error reporting. We don't raise on error, since
+ # find_best_candidate_distribution() needs to handle the failure case, and we
+ # shouldn't use exceptions as flow-control.
+ if not result.success:
+ return (None, result.message)
+ else:
+ return (result.x, None)
+
+
+def find_best_candidate_distribution(objective_vector,
+ constraints_matrix,
+ epsilon=0.0):
+ """Finds a distribution minimizing an objective subject to constraints.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds a distribution over these n
+ candidates that, in expectation, minimizes the objective while violating
+ the constraints by the smallest possible amount (with the amount being found
+ via bisection search).
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ This function will return a distribution for which at most m+1 probabilities,
+ and often fewer, are nonzero.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ This function implements the approach described in Lemma 3.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ epsilon: nonnegative float, the threshold at which to terminate the binary
+ search while searching for the minimal expected constraint violation
+ magnitude.
+
+ Returns:
+ The optimal distribution, as a numpy array of shape (n,).
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes, or if `epsilon` is negative.
+ ImportError: If we're unable to import `scipy.optimize`.
+ """
+ if epsilon < 0.0:
+ raise ValueError("epsilon must be nonnegative")
+
+ # If there is a feasible solution (i.e. with maximum_violation=0), then that's
+ # what we'll return.
+ pp, _ = _find_best_candidate_distribution_helper(objective_vector,
+ constraints_matrix)
+ if pp is not None:
+ return pp
+
+ # The bound is the minimum over all candidates, of the maximum per-candidate
+ # constraint violation.
+ lower = 0.0
+ upper = np.min(np.amax(constraints_matrix, axis=0))
+ best_pp, _ = _find_best_candidate_distribution_helper(
+ objective_vector, constraints_matrix, maximum_violation=upper)
+ assert best_pp is not None
+
+ # Throughout this loop, a maximum_violation of "lower" is not achievable,
+  # but a maximum_violation of "upper" is achievable.
+ while True:
+ middle = 0.5 * (lower + upper)
+ if (middle - lower <= epsilon) or (upper - middle <= epsilon):
+ break
+ else:
+ pp, _ = _find_best_candidate_distribution_helper(
+ objective_vector, constraints_matrix, maximum_violation=middle)
+ if pp is None:
+ lower = middle
+ else:
+ best_pp = pp
+ upper = middle
+
+ return best_pp
+
+
+def find_best_candidate_index(objective_vector,
+ constraints_matrix,
+ rank_objectives=False):
+ """Heuristically finds the best candidate solution to a constrained problem.
+
+ This function deals with the constrained problem:
+
+ > minimize f(w)
+ > s.t. g_i(w) <= 0 for all i in {0,1,...,m-1}
+
+ Here, f(w) is the "objective function", and g_i(w) is the ith (of m)
+ "constraint function". Given a set of n "candidate solutions"
+ {w_0,w_1,...,w_{n-1}}, this function finds the "best" solution according
+ to the following heuristic:
+
+  1. Across all models, the ith constraint violations (i.e. max{0, g_i(w_j)})
+     are ranked, as are the objectives (if rank_objectives=True).
+  2. Each model is then associated with its MAXIMUM rank across all m
+     constraints (and the objective, if rank_objectives=True).
+ 3. The model with the minimal maximum rank is then identified. Ties are
+ broken using the objective function value.
+ 4. The index of this "best" model is returned.
+
+ The `objective_vector` parameter should be a numpy array with shape (n,), for
+ which objective_vector[i] = f(w_i). Likewise, `constraints_matrix` should be a
+ numpy array with shape (m,n), for which constraints_matrix[i,j] = g_i(w_j).
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ This function implements the heuristic used for hyperparameter search in the
+ experiments of Section 5.2.
+
+ Args:
+ objective_vector: numpy array of shape (n,), where n is the number of
+ "candidate solutions". Contains the objective function values.
+ constraints_matrix: numpy array of shape (m,n), where m is the number of
+ constraints and n is the number of "candidate solutions". Contains the
+ constraint violation magnitudes.
+ rank_objectives: bool, whether the objective function values should be
+ included in the initial ranking step. If True, both the objective and
+ constraints will be ranked. If False, only the constraints will be ranked.
+ In either case, the objective function values will be used for
+ tiebreaking.
+
+ Returns:
+ The index (in {0,1,...,n-1}) of the "best" model according to the above
+ heuristic.
+
+ Raises:
+ ValueError: If `objective_vector` and `constraints_matrix` have inconsistent
+ shapes.
+ ImportError: If we're unable to import `scipy.stats`.
+ """
+ mm, nn = np.shape(constraints_matrix)
+ if (nn,) != np.shape(objective_vector):
+ raise ValueError(
+ "objective_vector must have shape (n,), and constraints_matrix (m, n),"
+ " where n is the number of candidates, and m is the number of "
+ "constraints")
+
+ # We import scipy inline, instead of at the top of the file, so that a scipy
+ # dependency is only introduced if either find_best_candidate_distribution()
+ # or find_best_candidate_index() are actually called.
+ import scipy.stats # pylint: disable=g-import-not-at-top
+
+ if rank_objectives:
+ maximum_ranks = scipy.stats.rankdata(objective_vector, method="min")
+ else:
+ maximum_ranks = np.zeros(nn, dtype=np.int64)
+ for ii in xrange(mm):
+ # Take the maximum of the constraint functions with zero, since we want to
+ # rank the magnitude of constraint *violations*. If the constraint is
+ # satisfied, then we don't care how much it's satisfied by (as a result, we
+    # expect all models satisfying a constraint to be tied at rank 1).
+ ranks = scipy.stats.rankdata(
+ np.maximum(0.0, constraints_matrix[ii, :]), method="min")
+ maximum_ranks = np.maximum(maximum_ranks, ranks)
+
+ best_index = None
+ best_rank = float("Inf")
+ best_objective = float("Inf")
+ for ii in xrange(nn):
+ if maximum_ranks[ii] < best_rank:
+ best_index = ii
+ best_rank = maximum_ranks[ii]
+ best_objective = objective_vector[ii]
+ elif (maximum_ranks[ii] == best_rank) and (objective_vector[ii] <=
+ best_objective):
+ best_index = ii
+ best_objective = objective_vector[ii]
+
+ return best_index
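A minimal usage sketch for the two helpers above (not part of the diff): the import path follows the test file below, the input values are illustrative only, and scipy must be installed for either call.

import numpy as np
from tensorflow.contrib.constrained_optimization.python import candidates

# Four candidate solutions, two constraints (g_i(w_j) <= 0 means satisfied).
objective_vector = np.array([1.0, 2.0, 3.0, 4.0])
constraints_matrix = np.array([[-1.0, 0.5, -0.5, -1.0],
                               [1.0, -1.0, -1.0, -1.0]])

# Distribution over the candidates minimizing the expected objective while
# violating the constraints, in expectation, as little as possible.
distribution = candidates.find_best_candidate_distribution(
    objective_vector, constraints_matrix)

# Single candidate index chosen by the ranking heuristic of Section 5.2.
index = candidates.find_best_candidate_index(
    objective_vector, constraints_matrix, rank_objectives=True)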
diff --git a/tensorflow/contrib/constrained_optimization/python/candidates_test.py b/tensorflow/contrib/constrained_optimization/python/candidates_test.py
new file mode 100644
index 0000000000..a4c49d48bc
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/candidates_test.py
@@ -0,0 +1,95 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.candidates."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import candidates
+from tensorflow.python.platform import test
+
+
+class CandidatesTest(test.TestCase):
+
+ def test_inconsistent_shapes_for_best_distribution(self):
+ """An error is raised when parameters have inconsistent shapes."""
+ objective_vector = np.array([1, 2, 3])
+ constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+ with self.assertRaises(ValueError):
+ _ = candidates.find_best_candidate_distribution(objective_vector,
+ constraints_matrix)
+
+ def test_inconsistent_shapes_for_best_index(self):
+ """An error is raised when parameters have inconsistent shapes."""
+ objective_vector = np.array([1, 2, 3])
+ constraints_matrix = np.array([[1, 2, 3, 4], [5, 6, 7, 8]])
+ with self.assertRaises(ValueError):
+ _ = candidates.find_best_candidate_index(objective_vector,
+ constraints_matrix)
+
+ def test_best_distribution(self):
+ """Distribution should match known solution."""
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ distribution = candidates.find_best_candidate_distribution(
+ objective_vector, constraints_matrix)
+ # Verify that the solution is a probability distribution.
+ self.assertTrue(np.all(distribution >= 0))
+ self.assertAlmostEqual(np.sum(distribution), 1.0)
+ # Verify that the solution satisfies the constraints.
+ maximum_constraint_violation = np.amax(
+ np.dot(constraints_matrix, distribution))
+ self.assertLessEqual(maximum_constraint_violation, 0)
+ # Verify that the solution matches that which we expect.
+ expected_distribution = np.array([0.37872711, 0.62127289, 0, 0])
+ self.assertAllClose(expected_distribution, distribution, rtol=0, atol=1e-6)
+
+ def test_best_index_rank_objectives_true(self):
+ """Index should match known solution."""
+ # Objective ranks = [2, 1, 4, 3].
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]].
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ # Maximum ranks = [4, 3, 4, 3].
+ index = candidates.find_best_candidate_index(
+ objective_vector, constraints_matrix, rank_objectives=True)
+ self.assertEqual(1, index)
+
+ def test_best_index_rank_objectives_false(self):
+ """Index should match known solution."""
+ # Objective ranks = [2, 1, 4, 3].
+ objective_vector = np.array(
+ [0.03053309, -0.06667082, 0.88355145, 0.46529806])
+ # Constraint ranks = [[1, 3, 4, 1], [4, 1, 1, 1]].
+ constraints_matrix = np.array(
+ [[-0.60164551, 0.36676229, 0.7856454, -0.8441711],
+ [0.00371592, -0.16392108, -0.59778071, -0.56908492]])
+ # Maximum ranks = [4, 3, 4, 1].
+ index = candidates.find_best_candidate_index(
+ objective_vector, constraints_matrix, rank_objectives=False)
+ self.assertEqual(3, index)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
new file mode 100644
index 0000000000..70813fb217
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_minimization_problem.py
@@ -0,0 +1,123 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines abstract class for `ConstrainedMinimizationProblem`s.
+
+A ConstrainedMinimizationProblem consists of an objective function to minimize,
+and a set of constraint functions that are constrained to be nonpositive.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ConstrainedMinimizationProblem(object):
+ """Abstract class representing a `ConstrainedMinimizationProblem`.
+
+ A ConstrainedMinimizationProblem consists of an objective function to
+ minimize, and a set of constraint functions that are constrained to be
+ nonpositive.
+
+ In addition to the constraint functions, there may (optionally) be proxy
+ constraint functions: a ConstrainedOptimizer will attempt to penalize these
+ proxy constraint functions so as to satisfy the (non-proxy) constraints. Proxy
+ constraints could be used if the constraints functions are difficult or
+ impossible to optimize (e.g. if they're piecewise constant), in which case the
+ proxy constraints should be some approximation of the original constraints
+ that is well-enough behaved to permit successful optimization.
+ """
+
+ @abc.abstractproperty
+ def objective(self):
+ """Returns the objective function.
+
+ Returns:
+ A 0d tensor that should be minimized.
+ """
+ pass
+
+ @property
+ def num_constraints(self):
+ """Returns the number of constraints.
+
+ Returns:
+ An int containing the number of constraints.
+
+ Raises:
+ ValueError: If the constraints (or proxy_constraints, if present) do not
+ have fully-known shapes, OR if proxy_constraints are present, and the
+ shapes of constraints and proxy_constraints are fully-known, but they're
+ different.
+ """
+ constraints_shape = self.constraints.get_shape()
+ if self.proxy_constraints is None:
+ proxy_constraints_shape = constraints_shape
+ else:
+ proxy_constraints_shape = self.proxy_constraints.get_shape()
+
+ if (constraints_shape is None or proxy_constraints_shape is None or
+ any([ii is None for ii in constraints_shape.as_list()]) or
+ any([ii is None for ii in proxy_constraints_shape.as_list()])):
+ raise ValueError(
+ "constraints and proxy_constraints must have fully-known shapes")
+ if constraints_shape != proxy_constraints_shape:
+ raise ValueError(
+ "constraints and proxy_constraints must have the same shape")
+
+ size = 1
+ for ii in constraints_shape.as_list():
+ size *= ii
+ return int(size)
+
+ @abc.abstractproperty
+ def constraints(self):
+ """Returns the vector of constraint functions.
+
+ Letting g_i be the ith element of the constraints vector, the ith constraint
+ will be g_i <= 0.
+
+ Returns:
+ A tensor of constraint functions.
+ """
+ pass
+
+ # This is a property, instead of an abstract property, since it doesn't need
+ # to be overridden: if proxy_constraints returns None, then there are no
+ # proxy constraints.
+ @property
+ def proxy_constraints(self):
+ """Returns the optional vector of proxy constraint functions.
+
+ The difference between `constraints` and `proxy_constraints` is that, when
+ proxy constraints are present, the `constraints` are merely EVALUATED during
+ optimization, whereas the `proxy_constraints` are DIFFERENTIATED. If there
+ are no proxy constraints, then the `constraints` are both evaluated and
+ differentiated.
+
+ For example, if we want to impose constraints on step functions, then we
+ could use these functions for `constraints`. However, because a step
+ function has zero gradient almost everywhere, we can't differentiate these
+ functions, so we would take `proxy_constraints` to be some differentiable
+ approximation of `constraints`.
+
+ Returns:
+ A tensor of proxy constraint functions.
+ """
+ return None
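For concreteness, a subclass might look like the sketch below (not part of the diff). The problem itself, minimizing x^2 subject to x >= 1 written as the constraint 1 - x <= 0, and the class name are invented for illustration; the import path mirrors the tests added in this change.

import tensorflow as tf

from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem


class ToyProblem(
    constrained_minimization_problem.ConstrainedMinimizationProblem):
  """Minimizes x^2 subject to x >= 1, i.e. the constraint 1 - x <= 0."""

  def __init__(self):
    self._x = tf.Variable(0.0, name="x")

  @property
  def objective(self):
    # 0d tensor to minimize.
    return tf.square(self._x)

  @property
  def constraints(self):
    # Must have a fully-known shape; a constraint is satisfied when <= 0.
    return tf.reshape(1.0 - self._x, (1,))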
diff --git a/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
new file mode 100644
index 0000000000..8055545366
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/constrained_optimizer.py
@@ -0,0 +1,208 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines base class for `ConstrainedOptimizer`s."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+@six.add_metaclass(abc.ABCMeta)
+class ConstrainedOptimizer(object):
+ """Base class representing a constrained optimizer.
+
+ A ConstrainedOptimizer wraps a tf.train.Optimizer (or more than one), and
+ applies it to a ConstrainedMinimizationProblem. Unlike a tf.train.Optimizer,
+ which takes a tensor to minimize as a parameter to its minimize() method, a
+ constrained optimizer instead takes a ConstrainedMinimizationProblem.
+ """
+
+ def __init__(self, optimizer):
+ """Constructs a new `ConstrainedOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the
+        ConstrainedMinimizationProblem.
+
+ Returns:
+ A new `ConstrainedOptimizer`.
+ """
+ self._optimizer = optimizer
+
+ @property
+ def optimizer(self):
+ """Returns the `tf.train.Optimizer` used for optimization."""
+ return self._optimizer
+
+ def minimize_unconstrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the unconstrained problem.
+
+ Unlike `minimize_constrained`, this function ignores the `constraints` (and
+ `proxy_constraints`) portion of the minimization problem entirely, and only
+ minimizes `objective`.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ return self.optimizer.minimize(
+ minimization_problem.objective,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ @abc.abstractmethod
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ Unlike `minimize_unconstrained`, this function attempts to find a solution
+ that minimizes the `objective` portion of the minimization problem while
+ satisfying the `constraints` portion.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ pass
+
+ def minimize(self,
+ minimization_problem,
+ unconstrained_steps=None,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ This method combines the functionality of `minimize_unconstrained` and
+ `minimize_constrained`. If global_step < unconstrained_steps, it will
+ perform an unconstrained update, and if global_step >= unconstrained_steps,
+ it will perform a constrained update.
+
+ The reason for this functionality is that it may be best to initialize the
+ constrained optimizer with an approximate optimum of the unconstrained
+ problem.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ unconstrained_steps: int, number of steps for which we should perform
+ unconstrained updates, before transitioning to constrained updates.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+
+ Raises:
+ ValueError: If unconstrained_steps is provided, but global_step is not.
+ """
+
+ def unconstrained_fn():
+ """Returns an `Op` for minimizing the unconstrained problem."""
+ return self.minimize_unconstrained(
+ minimization_problem=minimization_problem,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ def constrained_fn():
+ """Returns an `Op` for minimizing the constrained problem."""
+ return self.minimize_constrained(
+ minimization_problem=minimization_problem,
+ global_step=global_step,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ name=name,
+ grad_loss=grad_loss)
+
+ if unconstrained_steps is not None:
+ if global_step is None:
+ raise ValueError(
+ "global_step cannot be None if unconstrained_steps is provided")
+ unconstrained_steps_tensor = ops.convert_to_tensor(unconstrained_steps)
+ dtype = unconstrained_steps_tensor.dtype
+ return control_flow_ops.cond(
+ standard_ops.cast(global_step, dtype) < unconstrained_steps_tensor,
+ true_fn=unconstrained_fn,
+ false_fn=constrained_fn)
+ else:
+ return constrained_fn()
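As a usage sketch (not part of the diff): assuming `optimizer` is some concrete ConstrainedOptimizer subclass (such as those added later in this change) and `problem` is a ConstrainedMinimizationProblem, the warm-start behaviour of `minimize` might be driven as follows.

import tensorflow as tf

# `optimizer` and `problem` are assumed to already exist (see above).
global_step = tf.train.get_or_create_global_step()
# The first 100 steps ignore the constraints; later steps are constrained.
train_op = optimizer.minimize(
    problem, unconstrained_steps=100, global_step=global_step)

with tf.Session() as session:
  session.run(tf.global_variables_initializer())
  for _ in range(1000):
    session.run(train_op)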
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
new file mode 100644
index 0000000000..01c6e4f08a
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer.py
@@ -0,0 +1,375 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines `AdditiveExternalRegretOptimizer`.
+
+This optimizer minimizes a `ConstrainedMinimizationProblem` by introducing
+Lagrange multipliers, and using `tf.train.Optimizer`s to jointly optimize over
+the model parameters and Lagrange multipliers.
+
+For the purposes of constrained optimization, at least in theory,
+external-regret minimization suffices if the `ConstrainedMinimizationProblem`
+we're optimizing doesn't have any `proxy_constraints`, while swap-regret
+minimization should be used if `proxy_constraints` are present.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The formulation used by the AdditiveExternalRegretOptimizer--which is simply the
+usual Lagrangian formulation--can be found in Definition 1, and is discussed in
+Section 3. This optimizer is most similar to Algorithm 3 in Appendix C.3, with
+the two differences being that it uses proxy constraints (if they're provided)
+in the update of the model parameters, and uses `tf.train.Optimizer`s, instead
+of SGD, for the "inner" updates.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+
+import six
+
+from tensorflow.contrib.constrained_optimization.python import constrained_optimizer
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+def _project_multipliers_wrt_euclidean_norm(multipliers, radius):
+ """Projects its argument onto the feasible region.
+
+ The feasible region is the set of all vectors with nonnegative elements that
+ sum to at most `radius`.
+
+ Args:
+ multipliers: 1d tensor, the Lagrange multipliers to project.
+ radius: float, the radius of the feasible region.
+
+ Returns:
+ The 1d tensor that results from projecting `multipliers` onto the feasible
+ region w.r.t. the Euclidean norm.
+
+ Raises:
+ ValueError: if the `multipliers` tensor does not have a fully-known shape,
+ or is not one-dimensional.
+ """
+ multipliers_shape = multipliers.get_shape()
+ if multipliers_shape is None:
+ raise ValueError("multipliers must have known shape")
+ if multipliers_shape.ndims != 1:
+ raise ValueError(
+ "multipliers must be one dimensional (instead is %d-dimensional)" %
+ multipliers_shape.ndims)
+ dimension = multipliers_shape[0].value
+ if dimension is None:
+ raise ValueError("multipliers must have fully-known shape")
+
+ def while_loop_condition(iteration, multipliers, inactive, old_inactive):
+ """Returns false if the while loop should terminate."""
+ del multipliers # Needed by the body, but not the condition.
+ not_done = (iteration < dimension)
+ not_converged = standard_ops.reduce_any(
+ standard_ops.not_equal(inactive, old_inactive))
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, multipliers, inactive, old_inactive):
+ """Performs one iteration of the projection."""
+ del old_inactive # Needed by the condition, but not the body.
+ iteration += 1
+ scale = standard_ops.minimum(
+ 0.0,
+ (radius - standard_ops.reduce_sum(multipliers)) / standard_ops.maximum(
+ 1.0, standard_ops.reduce_sum(inactive)))
+ multipliers += scale * inactive
+ new_inactive = standard_ops.to_float(multipliers > 0)
+ multipliers *= new_inactive
+ return (iteration, multipliers, new_inactive, inactive)
+
+ iteration = standard_ops.constant(0)
+ inactive = standard_ops.ones_like(multipliers)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, multipliers, inactive, old_inactive = while_loop_body(
+ iteration, multipliers, inactive, inactive)
+ iteration, multipliers, inactive, old_inactive = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, multipliers, inactive, old_inactive),
+ name="euclidean_projection")
+
+ return multipliers
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _ExternalRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
+ """Base class representing an `_ExternalRegretOptimizer`.
+
+ This class contains most of the logic for performing constrained
+ optimization, minimizing external regret for the constraints player. What it
+ *doesn't* do is keep track of the internal state (the Lagrange multipliers).
+ Instead, the state is accessed via the _initial_state(),
+ _lagrange_multipliers(), _constraint_grad_and_var() and _projection_op()
+ methods.
+
+ The reason for this is that we want to make it easy to implement different
+ representations of the internal state.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by `_ExternalRegretOptimizer`s--which is simply the usual
+ Lagrangian formulation--can be found in Definition 1, and is discussed in
+ Section 3. Such optimizers are most similar to Algorithm 3 in Appendix C.3.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `_ExternalRegretOptimizer`.
+
+ The difference between `optimizer` and `constraint_optimizer` (if the latter
+ is provided) is that the former is used for learning the model parameters,
+    while the latter is used for the Lagrange multipliers. If no
+ `constraint_optimizer` is provided, then `optimizer` is used for both.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of the ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multipliers.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multipliers.
+
+ Returns:
+ A new `_ExternalRegretOptimizer`.
+ """
+ super(_ExternalRegretOptimizer, self).__init__(optimizer=optimizer)
+ self._constraint_optimizer = constraint_optimizer
+
+ @property
+ def constraint_optimizer(self):
+ """Returns the `tf.train.Optimizer` used for the Lagrange multipliers."""
+ return self._constraint_optimizer
+
+ @abc.abstractmethod
+ def _initial_state(self, num_constraints):
+ pass
+
+ @abc.abstractmethod
+ def _lagrange_multipliers(self, state):
+ pass
+
+ @abc.abstractmethod
+ def _constraint_grad_and_var(self, state, gradient):
+ pass
+
+ @abc.abstractmethod
+ def _projection_op(self, state, name=None):
+ pass
+
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ The `optimizer` constructor parameter will be used to update the model
+ parameters, while the Lagrange multipliers will be updated using
+    `constraint_optimizer` (if provided) or `optimizer` (if not).
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ objective = minimization_problem.objective
+
+ constraints = minimization_problem.constraints
+ proxy_constraints = minimization_problem.proxy_constraints
+ if proxy_constraints is None:
+ proxy_constraints = constraints
+ # Flatten both constraints tensors to 1d.
+ num_constraints = minimization_problem.num_constraints
+ constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
+ proxy_constraints = standard_ops.reshape(
+ proxy_constraints, shape=(num_constraints,))
+
+ # We use a lambda to initialize the state so that, if this function call is
+ # inside the scope of a tf.control_dependencies() block, the dependencies
+ # will not be applied to the initializer.
+ state = standard_ops.Variable(
+ lambda: self._initial_state(num_constraints),
+ trainable=False,
+ name="external_regret_optimizer_state")
+
+ multipliers = self._lagrange_multipliers(state)
+ loss = (
+ objective + standard_ops.tensordot(multipliers, proxy_constraints, 1))
+ multipliers_gradient = constraints
+
+ update_ops = []
+ if self.constraint_optimizer is None:
+ # If we don't have a separate constraint_optimizer, then we use
+ # self._optimizer for both the update of the model parameters, and that of
+ # the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ grads_and_vars.append(
+ self._constraint_grad_and_var(state, multipliers_gradient))
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ else:
+ # If we have a separate constraint_optimizer, then we use self._optimizer
+ # for the update of the model parameters, and self._constraint_optimizer
+ # for that of the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ multiplier_grads_and_vars = [
+ self._constraint_grad_and_var(state, multipliers_gradient)
+ ]
+
+ gradients = [
+ gradient for gradient, _ in grads_and_vars + multiplier_grads_and_vars
+ if gradient is not None
+ ]
+ with ops.control_dependencies(gradients):
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ update_ops.append(
+ self.constraint_optimizer.apply_gradients(
+ multiplier_grads_and_vars, name="optimizer_state_update"))
+
+ with ops.control_dependencies(update_ops):
+ if global_step is None:
+ # If we don't have a global step, just project, and we're done.
+ return self._projection_op(state, name=name)
+ else:
+ # If we have a global step, then we need to increment it in addition to
+ # projecting.
+ projection_op = self._projection_op(state, name="project")
+ with ops.colocate_with(global_step):
+ global_step_op = state_ops.assign_add(
+ global_step, 1, name="global_step_increment")
+ return control_flow_ops.group(projection_op, global_step_op, name=name)
+
+
+class AdditiveExternalRegretOptimizer(_ExternalRegretOptimizer):
+ """A `ConstrainedOptimizer` based on external-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over Lagrange multipliers,
+ with the latter maximization using additive updates and an algorithm that
+ minimizes external regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer--which is simply the usual Lagrangian
+ formulation--can be found in Definition 1, and is discussed in Section 3. It
+ is most similar to Algorithm 3 in Appendix C.3, with the two differences being
+ that it uses proxy constraints (if they're provided) in the update of the
+ model parameters, and uses `tf.train.Optimizer`s, instead of SGD, for the
+ "inner" updates.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ maximum_multiplier_radius=None):
+ """Constructs a new `AdditiveExternalRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multipliers.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multipliers.
+ maximum_multiplier_radius: float, an optional upper bound to impose on the
+ sum of the Lagrange multipliers.
+
+ Returns:
+ A new `AdditiveExternalRegretOptimizer`.
+
+ Raises:
+ ValueError: If the maximum_multiplier_radius parameter is nonpositive.
+ """
+ super(AdditiveExternalRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ if maximum_multiplier_radius and (maximum_multiplier_radius <= 0.0):
+ raise ValueError("maximum_multiplier_radius must be strictly positive")
+
+ self._maximum_multiplier_radius = maximum_multiplier_radius
+
+ def _initial_state(self, num_constraints):
+ # For an AdditiveExternalRegretOptimizer, the internal state is simply a
+ # tensor of Lagrange multipliers with shape (m,), where m is the number of
+ # constraints.
+ return standard_ops.zeros((num_constraints,), dtype=dtypes.float32)
+
+ def _lagrange_multipliers(self, state):
+ return state
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+ if self._maximum_multiplier_radius:
+ projected_multipliers = _project_multipliers_wrt_euclidean_norm(
+ state, self._maximum_multiplier_radius)
+ else:
+ projected_multipliers = standard_ops.maximum(state, 0.0)
+ return state_ops.assign(state, projected_multipliers, name=name)
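A construction sketch for the optimizer above (not part of the diff), where `problem` is assumed to be a ConstrainedMinimizationProblem such as the one sketched earlier; the choice of wrapped optimizers and learning rates is illustrative only.

import tensorflow as tf

from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer

# The model parameters and the Lagrange multipliers may use different
# tf.train.Optimizers; the multipliers are projected back onto
# {x >= 0, sum(x) <= maximum_multiplier_radius} after each update.
optimizer = external_regret_optimizer.AdditiveExternalRegretOptimizer(
    optimizer=tf.train.AdamOptimizer(learning_rate=0.01),
    constraint_optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.1),
    maximum_multiplier_radius=1.0)
train_op = optimizer.minimize_constrained(problem)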
diff --git a/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
new file mode 100644
index 0000000000..9b4bf62710
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/external_regret_optimizer_test.py
@@ -0,0 +1,136 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.external_regret_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import external_regret_optimizer
+from tensorflow.contrib.constrained_optimization.python import test_util
+
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+class AdditiveExternalRegretOptimizerWrapper(
+ external_regret_optimizer.AdditiveExternalRegretOptimizer):
+ """Testing wrapper class around AdditiveExternalRegretOptimizer.
+
+ This class is identical to AdditiveExternalRegretOptimizer, except that it
+ caches the internal optimization state when _lagrange_multipliers() is called,
+ so that we can test that the Lagrange multipliers take on their expected
+ values.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ maximum_multiplier_radius=None):
+ """Same as AdditiveExternalRegretOptimizer.__init__."""
+ super(AdditiveExternalRegretOptimizerWrapper, self).__init__(
+ optimizer=optimizer,
+ constraint_optimizer=constraint_optimizer,
+ maximum_multiplier_radius=maximum_multiplier_radius)
+ self._cached_lagrange_multipliers = None
+
+ @property
+ def lagrange_multipliers(self):
+ """Returns the cached Lagrange multipliers."""
+ return self._cached_lagrange_multipliers
+
+ def _lagrange_multipliers(self, state):
+ """Caches the internal state for testing."""
+ self._cached_lagrange_multipliers = super(
+ AdditiveExternalRegretOptimizerWrapper,
+ self)._lagrange_multipliers(state)
+ return self._cached_lagrange_multipliers
+
+
+class ExternalRegretOptimizerTest(test.TestCase):
+
+ def test_project_multipliers_wrt_euclidean_norm(self):
+ """Tests Euclidean projection routine on some known values."""
+ multipliers1 = standard_ops.constant([-0.1, -0.6, -0.3])
+ expected_projected_multipliers1 = np.array([0.0, 0.0, 0.0])
+
+ multipliers2 = standard_ops.constant([-0.1, 0.6, 0.3])
+ expected_projected_multipliers2 = np.array([0.0, 0.6, 0.3])
+
+ multipliers3 = standard_ops.constant([0.4, 0.7, -0.2, 0.5, 0.1])
+ expected_projected_multipliers3 = np.array([0.2, 0.5, 0.0, 0.3, 0.0])
+
+ with self.test_session() as session:
+ projected_multipliers1 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers1, 1.0))
+ projected_multipliers2 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers2, 1.0))
+ projected_multipliers3 = session.run(
+ external_regret_optimizer._project_multipliers_wrt_euclidean_norm(
+ multipliers3, 1.0))
+
+ self.assertAllClose(
+ expected_projected_multipliers1,
+ projected_multipliers1,
+ rtol=0,
+ atol=1e-6)
+ self.assertAllClose(
+ expected_projected_multipliers2,
+ projected_multipliers2,
+ rtol=0,
+ atol=1e-6)
+ self.assertAllClose(
+ expected_projected_multipliers3,
+ projected_multipliers3,
+ rtol=0,
+ atol=1e-6)
+
+ def test_additive_external_regret_optimizer(self):
+ """Tests that the Lagrange multipliers update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = AdditiveExternalRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0),
+ maximum_multiplier_radius=1.0)
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ expected_multipliers = [
+ np.array([0.0, 0.0, 0.0]),
+ np.array([0.6, 0.0, 0.4]),
+ np.array([0.7, 0.0, 0.3]),
+ np.array([0.8, 0.0, 0.2]),
+ np.array([0.9, 0.0, 0.1]),
+ np.array([1.0, 0.0, 0.0]),
+ np.array([1.0, 0.0, 0.0]),
+ ]
+
+ multipliers = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(multipliers) < len(expected_multipliers):
+ multipliers.append(session.run(optimizer.lagrange_multipliers))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_multipliers, multipliers):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+
+if __name__ == "__main__":
+ test.main()
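The expected multiplier trajectory in test_additive_external_regret_optimizer above can be reproduced with a plain NumPy transcription (a sketch, not part of the diff): with a learning rate of 1.0, each constrained step adds the constant constraint vector [0.6, -0.1, 0.4] to the multipliers and then projects back onto {x >= 0, sum(x) <= maximum_multiplier_radius}.

import numpy as np

def project_multipliers(multipliers, radius=1.0):
  # NumPy version of _project_multipliers_wrt_euclidean_norm: repeatedly shift
  # the still-positive ("inactive") coordinates down uniformly until the sum
  # fits within `radius`, zeroing out any coordinate that becomes nonpositive.
  multipliers = np.asarray(multipliers, dtype=np.float64)
  inactive = np.ones_like(multipliers)
  for _ in range(multipliers.size + 1):
    scale = min(0.0, (radius - multipliers.sum()) / max(1.0, inactive.sum()))
    multipliers = multipliers + scale * inactive
    new_inactive = (multipliers > 0).astype(np.float64)
    multipliers = multipliers * new_inactive
    if np.array_equal(new_inactive, inactive):
      break
    inactive = new_inactive
  return multipliers

multipliers = np.zeros(3)
for _ in range(7):
  print(multipliers)  # [0 0 0], [0.6 0 0.4], [0.7 0 0.3], [0.8 0 0.2], ...
  multipliers = project_multipliers(multipliers + np.array([0.6, -0.1, 0.4]))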
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
new file mode 100644
index 0000000000..04014ab4ae
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer.py
@@ -0,0 +1,595 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Defines `{Additive,Multiplicative}SwapRegretOptimizer`s.
+
+These optimizers minimize a `ConstrainedMinimizationProblem` by using a
+swap-regret minimizing algorithm (either SGD or multiplicative weights) to learn
+what weights should be associated with the objective function and constraints.
+These algorithms do *not* use Lagrange multipliers, but the idea is similar.
+The main differences between the formulation used here, and the standard
+Lagrangian formulation, are that (i) the objective function is weighted, in
+addition to the constraints, and (ii) we learn a matrix of weights, instead of a
+vector.
+
+For the purposes of constrained optimization, at least in theory,
+external-regret minimization suffices if the `ConstrainedMinimizationProblem`
+we're optimizing doesn't have any `proxy_constraints`, while swap-regret
+minimization should be used if `proxy_constraints` are present.
+
+For more specifics, please refer to:
+
+> Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+> Constrained Optimization".
+> [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+The formulation used by both of the SwapRegretOptimizers can be found in
+Definition 2, and is discussed in Section 4. The
+`MultiplicativeSwapRegretOptimizer` is most similar to Algorithm 2 in Section 4,
+with the difference being that it uses `tf.train.Optimizer`s, instead of SGD,
+for the "inner" updates. The `AdditiveSwapRegretOptimizer` differs further in
+that it performs additive (instead of multiplicative) updates of the stochastic
+matrix.
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import abc
+import math
+
+import six
+
+from tensorflow.contrib.constrained_optimization.python import constrained_optimizer
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.ops import state_ops
+from tensorflow.python.training import optimizer as train_optimizer
+
+
+def _maximal_eigenvector_power_method(matrix,
+ epsilon=1e-6,
+ maximum_iterations=100):
+ """Returns the maximal right-eigenvector of `matrix` using the power method.
+
+ Args:
+ matrix: 2D Tensor, the matrix of which we will find the maximal
+ right-eigenvector.
+    epsilon: positive float, if two iterations of the power method differ (in
+      L2 norm) by no more than epsilon, we will terminate.
+    maximum_iterations: positive int, if we perform this many iterations, we
+      will terminate.
+
+  Returns:
+ The maximal right-eigenvector of `matrix`.
+
+ Raises:
+ ValueError: If the epsilon or maximum_iterations parameters violate their
+ bounds.
+ """
+ if epsilon <= 0.0:
+ raise ValueError("epsilon must be strictly positive")
+ if maximum_iterations <= 0:
+ raise ValueError("maximum_iterations must be strictly positive")
+
+ def while_loop_condition(iteration, eigenvector, old_eigenvector):
+ """Returns false if the while loop should terminate."""
+ not_done = (iteration < maximum_iterations)
+ not_converged = (standard_ops.norm(eigenvector - old_eigenvector) > epsilon)
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, eigenvector, old_eigenvector):
+ """Performs one iteration of the power method."""
+ del old_eigenvector # Needed by the condition, but not the body.
+ iteration += 1
+ # We need to use tf.matmul() and tf.expand_dims(), instead of
+ # tf.tensordot(), since the former will infer the shape of the result, while
+ # the latter will not (tf.while_loop() needs the shapes).
+ new_eigenvector = standard_ops.matmul(
+ matrix, standard_ops.expand_dims(eigenvector, 1))[:, 0]
+ new_eigenvector /= standard_ops.norm(new_eigenvector)
+ return (iteration, new_eigenvector, eigenvector)
+
+ iteration = standard_ops.constant(0)
+ eigenvector = standard_ops.ones_like(matrix[:, 0])
+ eigenvector /= standard_ops.norm(eigenvector)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, eigenvector, old_eigenvector = while_loop_body(
+ iteration, eigenvector, eigenvector)
+ iteration, eigenvector, old_eigenvector = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, eigenvector, old_eigenvector),
+ name="power_method")
+
+ return eigenvector
+
+
+def _project_stochastic_matrix_wrt_euclidean_norm(matrix):
+ """Projects its argument onto the set of left-stochastic matrices.
+
+ This algorithm is O(n^3) at worst, where `matrix` is n*n. It can be done in
+ O(n^2 * log(n)) time by sorting each column (and maybe better with a different
+ algorithm), but the algorithm implemented here is easier to implement in
+ TensorFlow.
+
+ Args:
+ matrix: 2d square tensor, the matrix to project.
+
+ Returns:
+ The 2d square tensor that results from projecting `matrix` onto the set of
+ left-stochastic matrices w.r.t. the Euclidean norm applied column-wise
+ (i.e. the Frobenius norm).
+
+ Raises:
+ ValueError: if the `matrix` tensor does not have a fully-known shape, or is
+ not two-dimensional and square.
+ """
+ matrix_shape = matrix.get_shape()
+ if matrix_shape is None:
+ raise ValueError("matrix must have known shape")
+ if matrix_shape.ndims != 2:
+ raise ValueError(
+ "matrix must be two dimensional (instead is %d-dimensional)" %
+ matrix_shape.ndims)
+ if matrix_shape[0] != matrix_shape[1]:
+ raise ValueError("matrix must be be square (instead has shape (%d,%d))" %
+ (matrix_shape[0], matrix_shape[1]))
+ dimension = matrix_shape[0].value
+ if dimension is None:
+ raise ValueError("matrix must have fully-known shape")
+
+ def while_loop_condition(iteration, matrix, inactive, old_inactive):
+ """Returns false if the while loop should terminate."""
+ del matrix # Needed by the body, but not the condition.
+ not_done = (iteration < dimension)
+ not_converged = standard_ops.reduce_any(
+ standard_ops.not_equal(inactive, old_inactive))
+ return standard_ops.logical_and(not_done, not_converged)
+
+ def while_loop_body(iteration, matrix, inactive, old_inactive):
+ """Performs one iteration of the projection."""
+ del old_inactive # Needed by the condition, but not the body.
+ iteration += 1
+ scale = (1.0 - standard_ops.reduce_sum(
+ matrix, axis=0, keep_dims=True)) / standard_ops.maximum(
+ 1.0, standard_ops.reduce_sum(inactive, axis=0, keep_dims=True))
+ matrix += scale * inactive
+ new_inactive = standard_ops.to_float(matrix > 0)
+ matrix *= new_inactive
+ return (iteration, matrix, new_inactive, inactive)
+
+ iteration = standard_ops.constant(0)
+ inactive = standard_ops.ones_like(matrix)
+
+ # We actually want a do-while loop, so we explicitly call while_loop_body()
+ # once before tf.while_loop().
+ iteration, matrix, inactive, old_inactive = while_loop_body(
+ iteration, matrix, inactive, inactive)
+ iteration, matrix, inactive, old_inactive = control_flow_ops.while_loop(
+ while_loop_condition,
+ while_loop_body,
+ loop_vars=(iteration, matrix, inactive, old_inactive),
+ name="euclidean_projection")
+
+ return matrix
+
+
+def _project_log_stochastic_matrix_wrt_kl_divergence(log_matrix):
+ """Projects its argument onto the set of log-left-stochastic matrices.
+
+ Args:
+ log_matrix: 2d square tensor, the element-wise logarithm of the matrix to
+ project.
+
+ Returns:
+    The 2d square tensor that results from projecting exp(`log_matrix`) onto
+    the set of left-stochastic matrices w.r.t. the KL-divergence applied
+    column-wise.
+ """
+
+ # For numerical reasons, make sure that the largest matrix element is zero
+ # before exponentiating.
+ log_matrix -= standard_ops.reduce_max(log_matrix, axis=0, keep_dims=True)
+ log_matrix -= standard_ops.log(
+ standard_ops.reduce_sum(
+ standard_ops.exp(log_matrix), axis=0, keep_dims=True))
+ return log_matrix
+
+
+@six.add_metaclass(abc.ABCMeta)
+class _SwapRegretOptimizer(constrained_optimizer.ConstrainedOptimizer):
+ """Base class representing a `_SwapRegretOptimizer`.
+
+ This class contains most of the logic for performing constrained optimization,
+ minimizing external regret for the constraints player. What it *doesn't* do is
+ keep track of the internal state (the stochastic matrix). Instead, the state
+ is accessed via the _initial_state(), _stochastic_matrix(),
+ _constraint_grad_and_var() and _projection_op() methods.
+
+ The reason for this is that we want to make it easy to implement different
+ representations of the internal state. For example, for additive updates, it's
+ most natural to store the stochastic matrix directly, whereas for
+ multiplicative updates, it's most natural to store its element-wise logarithm.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by `_SwapRegretOptimizer`s can be found in Definition 2,
+ and is discussed in Section 4. Such optimizers are most similar to Algorithm
+ 2 in Section 4. Most notably, the internal state is a left-stochastic matrix
+ of shape (m+1,m+1), where m is the number of constraints.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `_SwapRegretOptimizer`.
+
+ The difference between `optimizer` and `constraint_optimizer` (if the latter
+ is provided) is that the former is used for learning the model parameters,
+    while the latter is used for the update to the constraint/objective weight
+ matrix (the analogue of Lagrange multipliers). If no `constraint_optimizer`
+ is provided, then `optimizer` is used for both.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+
+ Returns:
+ A new `_SwapRegretOptimizer`.
+ """
+ super(_SwapRegretOptimizer, self).__init__(optimizer=optimizer)
+ self._constraint_optimizer = constraint_optimizer
+
+ @property
+ def constraint_optimizer(self):
+ """Returns the `tf.train.Optimizer` used for the matrix."""
+ return self._constraint_optimizer
+
+ @abc.abstractmethod
+ def _initial_state(self, num_constraints):
+ pass
+
+ @abc.abstractmethod
+ def _stochastic_matrix(self, state):
+ pass
+
+ def _distribution(self, state):
+ distribution = _maximal_eigenvector_power_method(
+ self._stochastic_matrix(state))
+ distribution = standard_ops.abs(distribution)
+ distribution /= standard_ops.reduce_sum(distribution)
+ return distribution
+
+ @abc.abstractmethod
+ def _constraint_grad_and_var(self, state, gradient):
+ pass
+
+ @abc.abstractmethod
+ def _projection_op(self, state, name=None):
+ pass
+
+ def minimize_constrained(self,
+ minimization_problem,
+ global_step=None,
+ var_list=None,
+ gate_gradients=train_optimizer.Optimizer.GATE_OP,
+ aggregation_method=None,
+ colocate_gradients_with_ops=False,
+ name=None,
+ grad_loss=None):
+ """Returns an `Op` for minimizing the constrained problem.
+
+ The `optimizer` constructor parameter will be used to update the model
+ parameters, while the constraint/objective weight matrix (the analogue of
+    Lagrange multipliers) will be updated using `constraint_optimizer` (if
+ provided) or `optimizer` (if not). Whether the matrix updates are additive
+ or multiplicative depends on the derived class.
+
+ Args:
+ minimization_problem: ConstrainedMinimizationProblem, the problem to
+ optimize.
+ global_step: as in `tf.train.Optimizer`'s `minimize` method.
+ var_list: as in `tf.train.Optimizer`'s `minimize` method.
+ gate_gradients: as in `tf.train.Optimizer`'s `minimize` method.
+ aggregation_method: as in `tf.train.Optimizer`'s `minimize` method.
+ colocate_gradients_with_ops: as in `tf.train.Optimizer`'s `minimize`
+ method.
+ name: as in `tf.train.Optimizer`'s `minimize` method.
+ grad_loss: as in `tf.train.Optimizer`'s `minimize` method.
+
+ Returns:
+ TensorFlow Op.
+ """
+ objective = minimization_problem.objective
+
+ constraints = minimization_problem.constraints
+ proxy_constraints = minimization_problem.proxy_constraints
+ if proxy_constraints is None:
+ proxy_constraints = constraints
+ # Flatten both constraints tensors to 1d.
+ num_constraints = minimization_problem.num_constraints
+ constraints = standard_ops.reshape(constraints, shape=(num_constraints,))
+ proxy_constraints = standard_ops.reshape(
+ proxy_constraints, shape=(num_constraints,))
+
+ # We use a lambda to initialize the state so that, if this function call is
+ # inside the scope of a tf.control_dependencies() block, the dependencies
+ # will not be applied to the initializer.
+ state = standard_ops.Variable(
+ lambda: self._initial_state(num_constraints),
+ trainable=False,
+ name="swap_regret_optimizer_state")
+
+ zero_and_constraints = standard_ops.concat(
+ (standard_ops.zeros((1,)), constraints), axis=0)
+ objective_and_proxy_constraints = standard_ops.concat(
+ (standard_ops.expand_dims(objective, 0), proxy_constraints), axis=0)
+
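+    # The loss is the distribution-weighted combination of the objective and
+    # the proxy constraints, while matrix_gradient is the outer product of the
+    # zero-padded constraint violations with the distribution; its negation is
+    # used as the "gradient" for the stochastic-matrix update below.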
+ distribution = self._distribution(state)
+ loss = standard_ops.tensordot(distribution, objective_and_proxy_constraints,
+ 1)
+ matrix_gradient = standard_ops.matmul(
+ standard_ops.expand_dims(zero_and_constraints, 1),
+ standard_ops.expand_dims(distribution, 0))
+
+ update_ops = []
+ if self.constraint_optimizer is None:
+ # If we don't have a separate constraint_optimizer, then we use
+ # self._optimizer for both the update of the model parameters, and that of
+ # the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ grads_and_vars.append(
+ self._constraint_grad_and_var(state, matrix_gradient))
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ else:
+ # If we have a separate constraint_optimizer, then we use self._optimizer
+ # for the update of the model parameters, and self._constraint_optimizer
+ # for that of the internal state.
+ grads_and_vars = self.optimizer.compute_gradients(
+ loss,
+ var_list=var_list,
+ gate_gradients=gate_gradients,
+ aggregation_method=aggregation_method,
+ colocate_gradients_with_ops=colocate_gradients_with_ops,
+ grad_loss=grad_loss)
+ matrix_grads_and_vars = [
+ self._constraint_grad_and_var(state, matrix_gradient)
+ ]
+
+ gradients = [
+ gradient for gradient, _ in grads_and_vars + matrix_grads_and_vars
+ if gradient is not None
+ ]
+ with ops.control_dependencies(gradients):
+ update_ops.append(
+ self.optimizer.apply_gradients(grads_and_vars, name="update"))
+ update_ops.append(
+ self.constraint_optimizer.apply_gradients(
+ matrix_grads_and_vars, name="optimizer_state_update"))
+
+ with ops.control_dependencies(update_ops):
+ if global_step is None:
+ # If we don't have a global step, just project, and we're done.
+ return self._projection_op(state, name=name)
+ else:
+ # If we have a global step, then we need to increment it in addition to
+ # projecting.
+ projection_op = self._projection_op(state, name="project")
+ with ops.colocate_with(global_step):
+ global_step_op = state_ops.assign_add(
+ global_step, 1, name="global_step_increment")
+ return control_flow_ops.group(projection_op, global_step_op, name=name)
+
+
+class AdditiveSwapRegretOptimizer(_SwapRegretOptimizer):
+ """A `ConstrainedOptimizer` based on swap-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over constraint/objective
+ weight matrix (the analogue of Lagrange multipliers), with the latter
+ maximization using additive updates and an algorithm that minimizes swap
+ regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer can be found in Definition 2, and is
+ discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
+ the differences being that it uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates, and performs additive (instead of multiplicative) updates
+ of the stochastic matrix.
+ """
+
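+  # A minimal usage sketch (here, `problem` is assumed to be a user-provided
+  # `ConstrainedMinimizationProblem`, and the learning rate is illustrative):
+  #
+  #   optimizer = AdditiveSwapRegretOptimizer(
+  #       optimizer=tf.train.GradientDescentOptimizer(learning_rate=0.1))
+  #   train_op = optimizer.minimize_constrained(problem)
+  #
+  # Running train_op updates the model parameters and the stochastic matrix in
+  # a single step.
+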
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Constructs a new `AdditiveSwapRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+
+ Returns:
+ A new `AdditiveSwapRegretOptimizer`.
+ """
+ # TODO(acotter): add a parameter determining the initial values of the
+ # matrix elements (like initial_multiplier_radius in
+ # MultiplicativeSwapRegretOptimizer).
+ super(AdditiveSwapRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ def _initial_state(self, num_constraints):
+ # For an AdditiveSwapRegretOptimizer, the internal state is a tensor of
+ # shape (m+1,m+1), where m is the number of constraints, representing a
+ # left-stochastic matrix.
+ dimension = num_constraints + 1
+ # Initialize by putting all weight on the objective, and none on the
+ # constraints.
+ return standard_ops.concat(
+ (standard_ops.ones(
+ (1, dimension)), standard_ops.zeros((dimension - 1, dimension))),
+ axis=0)
+
+ def _stochastic_matrix(self, state):
+ return state
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+ return state_ops.assign(
+ state,
+ _project_stochastic_matrix_wrt_euclidean_norm(state),
+ name=name)
+
+
+class MultiplicativeSwapRegretOptimizer(_SwapRegretOptimizer):
+ """A `ConstrainedOptimizer` based on swap-regret minimization.
+
+ This `ConstrainedOptimizer` uses the given `tf.train.Optimizer`s to jointly
+ minimize over the model parameters, and maximize over constraint/objective
+ weight matrix (the analogue of Lagrange multipliers), with the latter
+ maximization using multiplicative updates and an algorithm that minimizes swap
+ regret.
+
+ For more specifics, please refer to:
+
+ > Cotter, Jiang and Sridharan. "Two-Player Games for Efficient Non-Convex
+ > Constrained Optimization".
+ > [https://arxiv.org/abs/1804.06500](https://arxiv.org/abs/1804.06500)
+
+ The formulation used by this optimizer can be found in Definition 2, and is
+ discussed in Section 4. It is most similar to Algorithm 2 in Section 4, with
+ the difference being that it uses `tf.train.Optimizer`s, instead of SGD, for
+ the "inner" updates.
+ """
+
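+  # Because the internal state stores the element-wise logarithm of the
+  # stochastic matrix, a multiplicative update of the matrix corresponds to an
+  # additive update of the state, followed by a KL-divergence projection back
+  # onto (the logarithm of) the set of left-stochastic matrices.
+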
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+ minimum_multiplier_radius=1e-3,
+ initial_multiplier_radius=None):
+ """Constructs a new `MultiplicativeSwapRegretOptimizer`.
+
+ Args:
+ optimizer: tf.train.Optimizer, used to optimize the objective and
+ proxy_constraints portion of ConstrainedMinimizationProblem. If
+ constraint_optimizer is not provided, this will also be used to optimize
+ the Lagrange multiplier analogues.
+ constraint_optimizer: optional tf.train.Optimizer, used to optimize the
+ Lagrange multiplier analogues.
+ minimum_multiplier_radius: float, each element of the matrix will be lower
+ bounded by `minimum_multiplier_radius` divided by one plus the number of
+ constraints.
+ initial_multiplier_radius: float, the initial value of each element of the
+ matrix associated with a constraint (i.e. excluding those elements
+ associated with the objective) will be `initial_multiplier_radius`
+ divided by one plus the number of constraints. Defaults to the value of
+ `minimum_multiplier_radius`.
+
+ Returns:
+ A new `MultiplicativeSwapRegretOptimizer`.
+
+ Raises:
+ ValueError: If the two radius parameters are inconsistent.
+ """
+ super(MultiplicativeSwapRegretOptimizer, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+
+ if (minimum_multiplier_radius <= 0.0) or (minimum_multiplier_radius >= 1.0):
+ raise ValueError("minimum_multiplier_radius must be in the range (0,1)")
+ if initial_multiplier_radius is None:
+ initial_multiplier_radius = minimum_multiplier_radius
+    elif (initial_multiplier_radius <
+          minimum_multiplier_radius) or (initial_multiplier_radius > 1.0):
+ raise ValueError("initial_multiplier_radius must be in the range "
+ "[minimum_multiplier_radius,1]")
+
+ self._minimum_multiplier_radius = minimum_multiplier_radius
+ self._initial_multiplier_radius = initial_multiplier_radius
+
+ def _initial_state(self, num_constraints):
+ # For a MultiplicativeSwapRegretOptimizer, the internal state is a tensor of
+ # shape (m+1,m+1), where m is the number of constraints, representing the
+ # element-wise logarithm of a left-stochastic matrix.
+ dimension = num_constraints + 1
+ # Initialize by putting as much weight as possible on the objective, and as
+ # little as possible on the constraints.
+ log_initial_one = math.log(1.0 - (self._initial_multiplier_radius *
+ (dimension - 1) / (dimension)))
+ log_initial_zero = math.log(self._initial_multiplier_radius / dimension)
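+    # With r = initial_multiplier_radius and d = dimension, each column of the
+    # resulting matrix sums to (1 - r * (d - 1) / d) + (d - 1) * (r / d) = 1,
+    # as required of a left-stochastic matrix.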
+ return standard_ops.concat(
+ (standard_ops.constant(
+ log_initial_one, dtype=dtypes.float32, shape=(1, dimension)),
+ standard_ops.constant(
+ log_initial_zero,
+ dtype=dtypes.float32,
+ shape=(dimension - 1, dimension))),
+ axis=0)
+
+ def _stochastic_matrix(self, state):
+ return standard_ops.exp(state)
+
+ def _constraint_grad_and_var(self, state, gradient):
+ # TODO(acotter): tf.colocate_with(), if colocate_gradients_with_ops is True?
+ return (-gradient, state)
+
+ def _projection_op(self, state, name=None):
+ with ops.colocate_with(state):
+      # Gets the dimension of the state (num_constraints + 1). None of these
+      # assertions should ever fail, since the state passed into this method
+      # will have the same shape as that returned by _initial_state().
+ state_shape = state.get_shape()
+ assert state_shape is not None
+ assert state_shape.ndims == 2
+ assert state_shape[0] == state_shape[1]
+ dimension = state_shape[0].value
+ assert dimension is not None
+
+ minimum_log_multiplier = standard_ops.log(
+ self._minimum_multiplier_radius / standard_ops.to_float(dimension))
+
+ return state_ops.assign(
+ state,
+ standard_ops.maximum(
+ _project_log_stochastic_matrix_wrt_kl_divergence(state),
+ minimum_log_multiplier),
+ name=name)
diff --git a/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
new file mode 100644
index 0000000000..34c4543dca
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/swap_regret_optimizer_test.py
@@ -0,0 +1,212 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for constrained_optimization.python.swap_regret_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.constrained_optimization.python import swap_regret_optimizer
+from tensorflow.contrib.constrained_optimization.python import test_util
+
+from tensorflow.python.ops import standard_ops
+from tensorflow.python.platform import test
+from tensorflow.python.training import gradient_descent
+
+
+class AdditiveSwapRegretOptimizerWrapper(
+ swap_regret_optimizer.AdditiveSwapRegretOptimizer):
+ """Testing wrapper class around AdditiveSwapRegretOptimizer.
+
+ This class is identical to AdditiveSwapRegretOptimizer, except that it caches
+ the internal optimization state when _stochastic_matrix() is called, so that
+ we can test that the stochastic matrices take on their expected values.
+ """
+
+ def __init__(self, optimizer, constraint_optimizer=None):
+ """Same as AdditiveSwapRegretOptimizer.__init__()."""
+ super(AdditiveSwapRegretOptimizerWrapper, self).__init__(
+ optimizer=optimizer, constraint_optimizer=constraint_optimizer)
+ self._cached_stochastic_matrix = None
+
+ @property
+ def stochastic_matrix(self):
+ """Returns the cached stochastic matrix."""
+ return self._cached_stochastic_matrix
+
+ def _stochastic_matrix(self, state):
+ """Caches the internal state for testing."""
+ self._cached_stochastic_matrix = super(AdditiveSwapRegretOptimizerWrapper,
+ self)._stochastic_matrix(state)
+ return self._cached_stochastic_matrix
+
+
+class MultiplicativeSwapRegretOptimizerWrapper(
+ swap_regret_optimizer.MultiplicativeSwapRegretOptimizer):
+ """Testing wrapper class around MultiplicativeSwapRegretOptimizer.
+
+ This class is identical to MultiplicativeSwapRegretOptimizer, except that it
+ caches the internal optimization state when _stochastic_matrix() is called, so
+ that we can test that the stochastic matrices take on their expected values.
+ """
+
+ def __init__(self,
+ optimizer,
+ constraint_optimizer=None,
+               minimum_multiplier_radius=1e-3,
+               initial_multiplier_radius=None):
+    """Same as MultiplicativeSwapRegretOptimizer.__init__()."""
+    super(MultiplicativeSwapRegretOptimizerWrapper, self).__init__(
+        optimizer=optimizer,
+        constraint_optimizer=constraint_optimizer,
+        minimum_multiplier_radius=minimum_multiplier_radius,
+ initial_multiplier_radius=initial_multiplier_radius)
+ self._cached_stochastic_matrix = None
+
+ @property
+ def stochastic_matrix(self):
+ """Returns the cached stochastic matrix."""
+ return self._cached_stochastic_matrix
+
+ def _stochastic_matrix(self, state):
+ """Caches the internal state for testing."""
+ self._cached_stochastic_matrix = super(
+ MultiplicativeSwapRegretOptimizerWrapper,
+ self)._stochastic_matrix(state)
+ return self._cached_stochastic_matrix
+
+
+class SwapRegretOptimizerTest(test.TestCase):
+
+ def test_maximum_eigenvector_power_method(self):
+ """Tests power method routine on some known left-stochastic matrices."""
+ matrix1 = np.matrix([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9], [0.4, 0.3, 0.0]])
+ matrix2 = np.matrix([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5], [0.4, 0.5, 0.3]])
+
+ with self.test_session() as session:
+ eigenvector1 = session.run(
+ swap_regret_optimizer._maximal_eigenvector_power_method(
+ standard_ops.constant(matrix1)))
+ eigenvector2 = session.run(
+ swap_regret_optimizer._maximal_eigenvector_power_method(
+ standard_ops.constant(matrix2)))
+
+ # Check that eigenvector1 and eigenvector2 are eigenvectors of matrix1 and
+ # matrix2 (respectively) with associated eigenvalue 1.
+ matrix_eigenvector1 = np.tensordot(matrix1, eigenvector1, axes=1)
+ matrix_eigenvector2 = np.tensordot(matrix2, eigenvector2, axes=1)
+ self.assertAllClose(eigenvector1, matrix_eigenvector1, rtol=0, atol=1e-6)
+ self.assertAllClose(eigenvector2, matrix_eigenvector2, rtol=0, atol=1e-6)
+
+ def test_project_stochastic_matrix_wrt_euclidean_norm(self):
+ """Tests Euclidean projection routine on some known values."""
+ matrix = standard_ops.constant([[-0.1, -0.1, 0.4], [-0.8, 0.4, 1.2],
+ [-0.3, 0.1, 0.2]])
+ expected_projected_matrix = np.array([[0.6, 0.1, 0.1], [0.0, 0.6, 0.9],
+ [0.4, 0.3, 0.0]])
+
+ with self.test_session() as session:
+ projected_matrix = session.run(
+ swap_regret_optimizer._project_stochastic_matrix_wrt_euclidean_norm(
+ matrix))
+
+ self.assertAllClose(
+ expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6)
+
+ def test_project_log_stochastic_matrix_wrt_kl_divergence(self):
+ """Tests KL-divergence projection routine on some known values."""
+ matrix = standard_ops.constant([[0.2, 0.8, 0.6], [0.1, 0.2, 1.5],
+ [0.2, 1.0, 0.9]])
+ expected_projected_matrix = np.array([[0.4, 0.4, 0.2], [0.2, 0.1, 0.5],
+ [0.4, 0.5, 0.3]])
+
+ with self.test_session() as session:
+ projected_matrix = session.run(
+ standard_ops.exp(
+ swap_regret_optimizer.
+ _project_log_stochastic_matrix_wrt_kl_divergence(
+ standard_ops.log(matrix))))
+
+ self.assertAllClose(
+ expected_projected_matrix, projected_matrix, rtol=0, atol=1e-6)
+
+ def test_additive_swap_regret_optimizer(self):
+ """Tests that the stochastic matrices update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = AdditiveSwapRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0))
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ # Calculated using a numpy+python implementation of the algorithm.
+ expected_matrices = [
+ np.array([[1.0, 1.0, 1.0, 1.0], [0.0, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]),
+ np.array([[0.66666667, 1.0, 1.0, 1.0], [0.26666667, 0.0, 0.0, 0.0],
+ [0.0, 0.0, 0.0, 0.0], [0.06666667, 0.0, 0.0, 0.0]]),
+        np.array([[0.41666667, 0.93333333, 1.0, 0.98333333],
+                  [0.46666667, 0.05333333, 0.0, 0.01333333],
+                  [0.0, 0.0, 0.0, 0.0],
+                  [0.11666667, 0.01333333, 0.0, 0.00333333]]),
+ ]
+
+ matrices = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(matrices) < len(expected_matrices):
+ matrices.append(session.run(optimizer.stochastic_matrix))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_matrices, matrices):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+ def test_multiplicative_swap_regret_optimizer(self):
+ """Tests that the stochastic matrices update as expected."""
+ minimization_problem = test_util.ConstantMinimizationProblem(
+ np.array([0.6, -0.1, 0.4]))
+ optimizer = MultiplicativeSwapRegretOptimizerWrapper(
+ gradient_descent.GradientDescentOptimizer(1.0),
+ initial_multiplier_radius=0.8)
+ train_op = optimizer.minimize_constrained(minimization_problem)
+
+ # Calculated using a numpy+python implementation of the algorithm.
+ expected_matrices = [
+ np.array([[0.4, 0.4, 0.4, 0.4], [0.2, 0.2, 0.2, 0.2],
+ [0.2, 0.2, 0.2, 0.2], [0.2, 0.2, 0.2, 0.2]]),
+        np.array([[0.36999014, 0.38528351, 0.38528351, 0.38528351],
+                  [0.23517483, 0.21720297, 0.21720297, 0.21720297],
+                  [0.17774131, 0.18882719, 0.18882719, 0.18882719],
+                  [0.21709373, 0.20868632, 0.20868632, 0.20868632]]),
+        np.array([[0.33972109, 0.36811863, 0.37118462, 0.36906575],
+                  [0.27114826, 0.23738228, 0.23376693, 0.23626491],
+                  [0.15712313, 0.17641793, 0.17858959, 0.17708679],
+                  [0.23200752, 0.21808115, 0.21645886, 0.21758255]]),
+ ]
+
+ matrices = []
+ with self.test_session() as session:
+ session.run(standard_ops.global_variables_initializer())
+ while len(matrices) < len(expected_matrices):
+ matrices.append(session.run(optimizer.stochastic_matrix))
+ session.run(train_op)
+
+ for expected, actual in zip(expected_matrices, matrices):
+ self.assertAllClose(expected, actual, rtol=0, atol=1e-6)
+
+
+if __name__ == '__main__':
+ test.main()
diff --git a/tensorflow/contrib/constrained_optimization/python/test_util.py b/tensorflow/contrib/constrained_optimization/python/test_util.py
new file mode 100644
index 0000000000..704b36ca4c
--- /dev/null
+++ b/tensorflow/contrib/constrained_optimization/python/test_util.py
@@ -0,0 +1,58 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Contains helpers used by tests."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.constrained_optimization.python import constrained_minimization_problem
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import standard_ops
+
+
+class ConstantMinimizationProblem(
+ constrained_minimization_problem.ConstrainedMinimizationProblem):
+ """A `ConstrainedMinimizationProblem` with constant constraint violations.
+
+ This minimization problem is intended for use in performing simple tests of
+ the Lagrange multiplier (or equivalent) update in the optimizers. There is a
+ one-element "dummy" model parameter, but it should be ignored.
+ """
+
+ def __init__(self, constraints):
+ """Constructs a new `ConstantMinimizationProblem'.
+
+ Args:
+ constraints: 1d numpy array, the constant constraint violations.
+
+ Returns:
+      A new `ConstantMinimizationProblem`.
+ """
+    # We make a fake 1-parameter linear objective so that we don't get a "no
+ # variables to optimize" error.
+ self._objective = standard_ops.Variable(0.0, dtype=dtypes.float32)
+ self._constraints = standard_ops.constant(constraints, dtype=dtypes.float32)
+
+ @property
+ def objective(self):
+ """Returns the objective function."""
+ return self._objective
+
+ @property
+ def constraints(self):
+ """Returns the constant constraint violations."""
+ return self._constraints
diff --git a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
index 6fb56b0858..012b17cee8 100644
--- a/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
+++ b/tensorflow/contrib/cudnn_rnn/python/kernel_tests/cudnn_rnn_test.py
@@ -1072,6 +1072,17 @@ class CudnnRNNTestParamsSize(test_util.TensorFlowTestCase):
class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
+ def setUp(self):
+ super(CudnnRNNTestTraining, self).setUp()
+ self._reset_rnd_gen_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE",
+ str(False))
+ self._rnn_use_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+
+ def tearDown(self):
+ super(CudnnRNNTestTraining, self).tearDown()
+ os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = self._reset_rnd_gen_state
+ os.environ["TF_CUDNN_RNN_USE_V2"] = self._rnn_use_v2
+
def _ComputeNumericGrad(self, sess, y, x, delta=1e-4, step=1):
"""Compute the numeric gradient of y wrt to x.
@@ -1184,11 +1195,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
def _TestOneSimpleTraining(self, rnn_mode, num_layers, num_units, input_size,
batch_size, seq_length, dir_count, dropout, dtype,
- delta, tolerance):
+ use_v2, delta, tolerance):
# Gradient checking runs two forward ops with almost the same input. Need to
# make sure the drop patterns across the two runs are the same.
logging.info("Training test with config: %s", locals())
- old_env_state = os.environ.get("TF_CUDNN_RESET_RND_GEN_STATE", str(False))
os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = str(True)
np.random.seed(1234)
@@ -1196,6 +1206,10 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
has_input_c = (rnn_mode == CUDNN_LSTM)
direction = (CUDNN_RNN_UNIDIRECTION
if dir_count == 1 else CUDNN_RNN_BIDIRECTION)
+ if use_v2:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "1"
+ else:
+ os.environ["TF_CUDNN_RNN_USE_V2"] = "0"
model = CudnnTestModel(
rnn_mode,
num_layers,
@@ -1245,22 +1259,22 @@ class CudnnRNNTestTraining(test_util.TensorFlowTestCase):
self._GradientCheck(
sess, total_sum, all_inputs,
tolerance=tolerance, delta=delta)
- os.environ["TF_CUDNN_RESET_RND_GEN_STATE"] = old_env_state
def _TestSimpleTrainingHelper(self, rnn_mode, test_configs):
dropouts = [0, 0.5, 1.]
- for config, dropout in itertools.product(test_configs, dropouts):
+ v2_options = [str(False), str(True)]
+ for config, dropout, use_v2 in itertools.product(test_configs, dropouts,
+ v2_options):
dtype = config.get("dtype", dtypes.float32)
delta = config.get("delta", 1e-4)
tolerance = config.get("tolerance", 1e-6)
dir_count = config.get("dir_count", 1)
shape = config["shape"]
with ops.Graph().as_default():
- self._TestOneSimpleTraining(rnn_mode, shape["num_layers"],
- shape["num_units"], shape["input_size"],
- shape["batch_size"], shape["seq_length"],
- dir_count, dropout, dtype, delta,
- tolerance)
+ self._TestOneSimpleTraining(
+ rnn_mode, shape["num_layers"], shape["num_units"],
+ shape["input_size"], shape["batch_size"], shape["seq_length"],
+ dir_count, dropout, dtype, use_v2, delta, tolerance)
@unittest.skipUnless(test.is_built_with_cuda(),
"Test only applicable when running on GPUs")
diff --git a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
index a1ede4471e..73a961992e 100644
--- a/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
+++ b/tensorflow/contrib/cudnn_rnn/python/ops/cudnn_rnn_ops.py
@@ -17,6 +17,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import os
from tensorflow.contrib.checkpoint.python import split_dependency
from tensorflow.contrib.rnn.python.ops import lstm_ops
from tensorflow.python.framework import common_shapes
@@ -901,19 +902,27 @@ def _cudnn_rnn(inputs,
check_direction(direction)
check_input_mode(input_mode)
seed, seed2 = random_seed.get_seed(seed)
- outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(
- input=inputs,
- input_h=input_h,
- input_c=input_c,
- params=params,
- is_training=is_training,
- rnn_mode=rnn_mode,
- input_mode=input_mode,
- direction=direction,
- dropout=dropout,
- seed=seed,
- seed2=seed2,
- name=name)
+ # TODO(jamesqin): switch default value to "1" on May 25th 2018, and get rid
+ # of V1 ops.
+ use_cudnn_v2 = os.environ.get("TF_CUDNN_RNN_USE_V2", "0")
+ args = {
+ "input": inputs,
+ "input_h": input_h,
+ "input_c": input_c,
+ "params": params,
+ "is_training": is_training,
+ "rnn_mode": rnn_mode,
+ "input_mode": input_mode,
+ "direction": direction,
+ "dropout": dropout,
+ "seed": seed,
+ "seed2": seed2,
+ "name": name
+ }
+  if use_cudnn_v2 != "1":
+ outputs, output_h, output_c, _ = gen_cudnn_rnn_ops.cudnn_rnn(**args)
+ else:
+ outputs, output_h, output_c, _, _ = gen_cudnn_rnn_ops.cudnn_rnnv2(**args)
return (outputs, output_h, output_c)
diff --git a/tensorflow/contrib/distribute/python/estimator_integration_test.py b/tensorflow/contrib/distribute/python/estimator_integration_test.py
index c5a520ab5a..34410a6470 100644
--- a/tensorflow/contrib/distribute/python/estimator_integration_test.py
+++ b/tensorflow/contrib/distribute/python/estimator_integration_test.py
@@ -61,7 +61,8 @@ class DNNLinearCombinedClassifierIntegrationTest(test.TestCase,
mode=['graph'],
distribution=[
combinations.one_device_strategy,
- combinations.mirrored_strategy_with_gpu_and_cpu
+ combinations.mirrored_strategy_with_gpu_and_cpu,
+ combinations.mirrored_strategy_with_two_gpus
]))
def test_complete_flow_with_mode(self, distribution):
label_dimension = 2
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/BUILD b/tensorflow/contrib/eager/python/examples/resnet50/BUILD
index 536cad998d..0c0e28dd95 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/BUILD
+++ b/tensorflow/contrib/eager/python/examples/resnet50/BUILD
@@ -14,6 +14,17 @@ py_library(
],
)
+py_library(
+ name = "resnet50_test_lib",
+ srcs = ["resnet50_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":resnet50",
+ "//tensorflow:tensorflow_py",
+ "//tensorflow/contrib/eager/python:tfe",
+ ],
+)
+
cuda_py_test(
name = "resnet50_test",
size = "large",
diff --git a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
index d6923293a3..8517a3bf7b 100644
--- a/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
+++ b/tensorflow/contrib/eager/python/examples/resnet50/resnet50_test.py
@@ -36,8 +36,8 @@ def device_and_data_format():
'channels_last')
-def random_batch(batch_size):
- _, data_format = device_and_data_format()
+def random_batch(batch_size, device_and_format=None):
+ _, data_format = device_and_format or device_and_data_format()
shape = (3, 224, 224) if data_format == 'channels_first' else (224, 224, 3)
shape = (batch_size,) + shape
@@ -169,7 +169,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
def _train_batch_sizes(self):
"""Choose batch sizes based on GPU capability."""
for device in device_lib.list_local_devices():
- if 'GPU:0' in device.name:
+ if tf.DeviceSpec.from_string(device.name).device_type == 'GPU':
# Avoid OOM errors with larger batch sizes, which seem to cause errors
# later on even if caught.
#
@@ -180,26 +180,32 @@ class ResNet50Benchmarks(tf.test.Benchmark):
return (16,)
if 'P100' in device.physical_device_desc:
return (16, 32, 64)
+
+ if tf.DeviceSpec.from_string(device.name).device_type == 'TPU':
+ # TODO(iga): Training fails with batch size of 16, probably because of
+ # no layout optimizations with op-by-op mode. Investigate more.
+ return (8,)
return (16, 32)
def _report(self, label, start, num_iters, device, batch_size, data_format):
avg_time = (time.time() - start) / num_iters
- dev = 'cpu' if 'cpu' in device else 'gpu'
+ dev = tf.DeviceSpec.from_string(device).device_type.lower()
name = '%s_%s_batch_%d_%s' % (label, dev, batch_size, data_format)
extras = {'examples_per_sec': batch_size / avg_time}
self.report_benchmark(
iters=num_iters, wall_time=avg_time, name=name, extras=extras)
- def _force_gpu_sync(self):
- # If this function is called in the context of a GPU device
+ def _force_device_sync(self):
+ # If this function is called in the context of a non-CPU device
# (e.g., inside a 'with tf.device("/gpu:0")' block)
- # then this will force a copy from CPU->GPU->CPU, which forces
- # a sync. This is a roundabout way, yes.
+ # then this will force a copy from CPU->NON_CPU_DEVICE->CPU,
+ # which forces a sync. This is a roundabout way, yes.
tf.constant(1.).cpu()
- def _benchmark_eager_apply(self, label, defun=False, execution_mode=None):
+ def _benchmark_eager_apply(self, label, defun=False, execution_mode=None,
+ device_and_format=None):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_data_format()
+ device, data_format = device_and_format or device_and_data_format()
model = resnet50.ResNet50(data_format)
if defun:
model.call = tfe.defun(model.call)
@@ -207,7 +213,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
num_burn = 5
num_iters = 30
with tf.device(device):
- images, _ = random_batch(batch_size)
+ images, _ = random_batch(batch_size, device_and_format)
for _ in xrange(num_burn):
model(images, training=False).cpu()
if execution_mode:
@@ -220,7 +226,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
tfe.async_wait()
self._report(label, start, num_iters, device, batch_size, data_format)
- def benchmark_eager_apply(self):
+ def benchmark_eager_apply_sync(self):
self._benchmark_eager_apply('eager_apply', defun=False)
def benchmark_eager_apply_async(self):
@@ -234,11 +240,12 @@ class ResNet50Benchmarks(tf.test.Benchmark):
label,
make_iterator,
defun=False,
- execution_mode=None):
+ execution_mode=None,
+ device_and_format=None):
with tfe.execution_mode(execution_mode):
- device, data_format = device_and_data_format()
+ device, data_format = device_and_format or device_and_data_format()
for batch_size in self._train_batch_sizes():
- (images, labels) = random_batch(batch_size)
+ (images, labels) = random_batch(batch_size, device_and_format)
num_burn = 3
num_iters = 10
model = resnet50.ResNet50(data_format)
@@ -253,7 +260,7 @@ class ResNet50Benchmarks(tf.test.Benchmark):
train_one_step(model, images, labels, optimizer)
if execution_mode:
tfe.async_wait()
- self._force_gpu_sync()
+ self._force_device_sync()
gc.collect()
start = time.time()
@@ -262,10 +269,10 @@ class ResNet50Benchmarks(tf.test.Benchmark):
train_one_step(model, images, labels, optimizer)
if execution_mode:
tfe.async_wait()
- self._force_gpu_sync()
+ self._force_device_sync()
self._report(label, start, num_iters, device, batch_size, data_format)
- def benchmark_eager_train(self):
+ def benchmark_eager_train_sync(self):
self._benchmark_eager_train('eager_train', MockIterator, defun=False)
def benchmark_eager_train_async(self):
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 62ddb3d290..b473de86ee 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -367,6 +367,7 @@ py_library(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:state_ops",
"//tensorflow/python:training",
+ "//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:export_output",
"//tensorflow/python/estimator:model_fn",
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index a8774d6dab..f8564446e5 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -47,8 +47,12 @@ from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import tf_logging
from tensorflow.python.training import device_setter as device_setter_lib
from tensorflow.python.training import optimizer as optimizer_lib
+from tensorflow.python.util import deprecation
+@deprecation.deprecated(
+ '2018-05-31',
+ 'Please use `tf.contrib.distribute.MirroredStrategy` instead.')
def replicate_model_fn(model_fn,
loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS,
devices=None):
@@ -255,6 +259,9 @@ class TowerOptimizer(optimizer_lib.Optimizer):
COLLECTION_FOR_GRAPH_STATES = 'replicate_model_fn_graph_states'
+ @deprecation.deprecated(
+ '2018-05-31',
+ 'Please use `tf.contrib.distribute.MirroredStrategy` instead.')
def __init__(self, optimizer_or_optimizer_fn):
"""Wrap an existing optimizer for gathering gradients across towers.
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index 144b45982c..dd8a3a95f1 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -540,59 +540,6 @@ class ReplicateAcrossASingleDeviceWithoutTowerOptimizer(
self.assertEqual(7.0, session.run(c))
-class UseTowerEstimatorWithoutReplication(test_util.TensorFlowTestCase):
-
- def model_fn(self, mode, features, labels, params):
- c = variable_scope.get_variable(
- 'c',
- initializer=constant_op.constant(10, dtype=dtypes.float64),
- dtype=dtypes.float64)
-
- features = features['features']
- predictions = math_ops.multiply(features, c)
-
- loss = losses.absolute_difference(
- labels=labels, predictions=predictions, reduction=losses.Reduction.SUM)
- loss = math_ops.reduce_sum(loss)
-
- metrics = {
- 'accuracy': metrics_lib.accuracy(labels, predictions),
- 'auc': metrics_lib.auc(labels, predictions)
- }
-
- optimizer = replicate_model_fn.TowerOptimizer(
- gradient_descent.GradientDescentOptimizer(params['learning_rate']))
-
- return model_fn_lib.EstimatorSpec(
- mode=mode,
- loss=loss,
- eval_metric_ops=metrics,
- predictions={'probabilities': predictions},
- train_op=optimizer.minimize(loss))
-
- @property
- def params(self):
- params = {}
- params['learning_rate'] = 1.0
- return params
-
- def test_train_single_tower(self):
- features = np.array([[1.0], [2.0]])
- labels = np.array([[1.0], [2.0]])
-
- train_input_fn = numpy_io.numpy_input_fn(
- x={'features': features}, y=labels, batch_size=2, shuffle=False)
-
- with self.test_session():
- estimator = estimator_lib.Estimator(
- model_fn=self.model_fn,
- model_dir=tempfile.mkdtemp(),
- params=self.params)
- estimator.train(train_input_fn, steps=1)
-
- self.assertEqual(7.0, estimator.get_variable_value('c'))
-
-
class MakeSureSyncReplicasOptimizerWorks(test_util.TensorFlowTestCase):
def model_fn(self, mode, features, labels, params):
diff --git a/tensorflow/contrib/factorization/BUILD b/tensorflow/contrib/factorization/BUILD
index 0a648d5d40..f28d95401c 100644
--- a/tensorflow/contrib/factorization/BUILD
+++ b/tensorflow/contrib/factorization/BUILD
@@ -215,6 +215,7 @@ tf_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:sparse_tensor",
],
+ tags = ["noasan"], # times out b/78588193
)
# Estimators tests
diff --git a/tensorflow/contrib/factorization/kernels/clustering_ops.cc b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
index 2a6c97e8b9..025534d540 100644
--- a/tensorflow/contrib/factorization/kernels/clustering_ops.cc
+++ b/tensorflow/contrib/factorization/kernels/clustering_ops.cc
@@ -32,6 +32,7 @@
#include "tensorflow/core/lib/gtl/top_n.h"
#include "tensorflow/core/lib/random/philox_random.h"
#include "tensorflow/core/lib/random/simple_philox.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 35341406a0..cca1a05419 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -28,7 +28,7 @@
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
using tensorflow::strings::StrCat;
diff --git a/tensorflow/contrib/framework/BUILD b/tensorflow/contrib/framework/BUILD
index b1c8ad49ea..249debbdf6 100644
--- a/tensorflow/contrib/framework/BUILD
+++ b/tensorflow/contrib/framework/BUILD
@@ -93,6 +93,7 @@ tf_kernel_library(
],
deps = [
"//tensorflow/core:framework",
+ "//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
alwayslink = 1,
@@ -177,6 +178,8 @@ cuda_py_test(
"//tensorflow/python:platform_test",
"//tensorflow/python:resource_variable_ops",
"//tensorflow/python:tensor_array_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/eager:context",
],
)
diff --git a/tensorflow/contrib/framework/__init__.py b/tensorflow/contrib/framework/__init__.py
index 11397e86bd..10d1ecc738 100644
--- a/tensorflow/contrib/framework/__init__.py
+++ b/tensorflow/contrib/framework/__init__.py
@@ -108,6 +108,7 @@ from __future__ import print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.framework.python.framework import *
+from tensorflow.contrib.framework.python.framework import nest
from tensorflow.contrib.framework.python.ops import *
# pylint: enable=unused-import,wildcard-import
@@ -126,5 +127,20 @@ from tensorflow.python.ops.init_ops import convolutional_orthogonal_3d
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['nest', 'broadcast_to']
-
+_nest_allowed_symbols = [
+ 'assert_same_structure',
+ 'is_sequence',
+ 'flatten',
+ 'flatten_dict_items',
+ 'pack_sequence_as',
+ 'map_structure',
+ 'assert_shallow_structure',
+ 'flatten_up_to',
+ 'map_structure_up_to',
+ 'get_traverse_shallow_structure',
+ 'yield_flat_paths',
+ 'flatten_with_joined_string_paths',
+]
+
+remove_undocumented(nest.__name__, allowed_exception_list=_nest_allowed_symbols)
remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
index 5bf6b67529..6ab3f460b3 100644
--- a/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
+++ b/tensorflow/contrib/framework/kernels/zero_initializer_op.cc
@@ -23,6 +23,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/resource_var.h"
namespace tensorflow {
@@ -85,4 +86,74 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
#undef REGISTER_KERNELS
+template <typename Device, typename T>
+class ZeroVarInitializer : public OpKernel {
+ public:
+ explicit ZeroVarInitializer(OpKernelConstruction* ctx) : OpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dtype", &dtype_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("shape", &shape_));
+ }
+
+ void Compute(OpKernelContext* ctx) override {
+ Var* variable = nullptr;
+ OP_REQUIRES_OK(ctx, LookupOrCreateResource<Var>(
+ ctx, HandleFromInput(ctx, 0), &variable,
+ [this, ctx](Var** var_ptr) {
+ *var_ptr = new Var(dtype_);
+ PersistentTensor unused;
+ Tensor* var_tensor = nullptr;
+ AllocatorAttributes attr;
+ attr.set_gpu_compatible(true);
+ attr.set_nic_compatible(true);
+ TF_RETURN_IF_ERROR(ctx->allocate_persistent(
+ dtype_, shape_, &unused, &var_tensor, attr));
+
+ functor::TensorSetZero<Device, T>()(
+ ctx->eigen_device<Device>(),
+ var_tensor->flat<T>());
+
+ *(*var_ptr)->tensor() = *var_tensor;
+
+ return Status::OK();
+ }));
+
+ core::ScopedUnref scoped(variable);
+ mutex_lock ml(*variable->mu());
+
+ OP_REQUIRES(ctx, !variable->is_initialized,
+ errors::InvalidArgument("input is already initialized"));
+
+ variable->is_initialized = true;
+
+ Tensor* output = nullptr;
+ OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &output));
+ output->scalar<ResourceHandle>()() = HandleFromInput(ctx, 0);
+ }
+
+ private:
+ DataType dtype_;
+ TensorShape shape_;
+};
+
+#define REGISTER_CPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<type>("dtype"), \
+ ZeroVarInitializer<Eigen::ThreadPoolDevice, type>);
+
+TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
+#undef REGISTER_CPU_KERNELS
+
+#if GOOGLE_CUDA
+#define REGISTER_GPU_KERNELS(type) \
+ REGISTER_KERNEL_BUILDER(Name("ZeroVarInitializer") \
+ .Device(DEVICE_GPU) \
+ .TypeConstraint<type>("dtype") \
+ .HostMemory("var"), \
+ ZeroVarInitializer<GPUDevice, type>);
+
+TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_KERNELS);
+#undef REGISTER_GPU_KERNELS
+#endif // GOOGLE_CUDA
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/framework/ops/variable_ops.cc b/tensorflow/contrib/framework/ops/variable_ops.cc
index 706134ba9a..f6ee6cdb57 100644
--- a/tensorflow/contrib/framework/ops/variable_ops.cc
+++ b/tensorflow/contrib/framework/ops/variable_ops.cc
@@ -39,4 +39,33 @@ ref: Should be from a `Variable` node.
output_ref:= Same as "ref".
)doc");
+REGISTER_OP("ZeroVarInitializer")
+ .Input("var: resource")
+ .Output("output_var: resource")
+ .Attr("dtype: type")
+ .Attr("shape: shape")
+ .SetAllowsUninitializedInput()
+ .SetShapeFn([](InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ DataType t;
+ TF_RETURN_IF_ERROR(c->GetAttr("dtype", &t));
+ PartialTensorShape p;
+ TF_RETURN_IF_ERROR(c->GetAttr("shape", &p));
+ shape_inference::ShapeHandle s;
+ TF_RETURN_IF_ERROR(c->MakeShapeFromPartialTensorShape(p, &s));
+ c->set_output_handle_shapes_and_types(
+ 0, std::vector<shape_inference::ShapeAndType>{{s, t}});
+
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Initialize 'var' with all zeros. This op requires that the resource var is not
+initialized. The var will first be allocated memory, then be filled with all
+zeros. This op is intended to save memory during initialization; if you use
+this op, you should not run the initializer of the var.
+
+var: Should be a ResourceVariable.
+output_var:= Same as "var".
+)doc");
+
} // namespace tensorflow
diff --git a/tensorflow/contrib/framework/python/ops/critical_section_test.py b/tensorflow/contrib/framework/python/ops/critical_section_test.py
index ba660295cb..df7d7e9dae 100644
--- a/tensorflow/contrib/framework/python/ops/critical_section_test.py
+++ b/tensorflow/contrib/framework/python/ops/critical_section_test.py
@@ -19,6 +19,8 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import critical_section_ops
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
@@ -330,6 +332,25 @@ class CriticalSectionTest(test.TestCase):
self.evaluate(v.initializer)
self.assertEqual(10, self.evaluate(out))
+ @test_util.run_in_graph_and_eager_modes()
+ def testInsideFunction(self):
+ cs = critical_section_ops.CriticalSection()
+ v = resource_variable_ops.ResourceVariable(1)
+ def fn():
+ return v.read_value()
+
+ # map() creates a TensorFlow function.
+ ds = dataset_ops.Dataset.range(1).map(lambda _: cs.execute(fn))
+
+ def get_first():
+ if context.executing_eagerly():
+ return self.evaluate(ds.make_one_shot_iterator().get_next())
+ itr = ds.make_initializable_iterator()
+ self.evaluate([v.initializer, itr.initializer])
+ return self.evaluate(itr.get_next())
+
+ self.assertEqual(1, get_first())
+
# TODO(ebrevdo): Re-enable once CriticalSection is in core.
#
# def testCriticalSectionAndExecuteOpSaverRoundTrip(self):
diff --git a/tensorflow/contrib/framework/python/ops/variables.py b/tensorflow/contrib/framework/python/ops/variables.py
index 0754c3e0e3..40ae01bfcc 100644
--- a/tensorflow/contrib/framework/python/ops/variables.py
+++ b/tensorflow/contrib/framework/python/ops/variables.py
@@ -32,6 +32,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import tf_logging as logging
@@ -82,7 +83,12 @@ def zero_initializer(ref, use_locking=True, name="zero_initializer"):
"""
loader.load_op_library(
resource_loader.get_path_to_datafile("_variable_ops.so"))
- return gen_variable_ops.zero_initializer(ref, name=name)
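+  # Resource variables are zero-initialized via the ZeroVarInitializer op,
+  # which allocates the variable's buffer and fills it with zeros (so the
+  # variable's own initializer should not also be run); legacy reference
+  # variables continue to use the original ZeroInitializer op.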
+ if resource_variable_ops.is_resource_variable(ref):
+ return gen_variable_ops.zero_var_initializer(
+ ref.handle, shape=ref.shape, dtype=ref.dtype, name=name)
+ else:
+ return gen_variable_ops.zero_initializer(ref, name=name)
+
@deprecated(None, "Please switch to tf.train.assert_global_step")
def assert_global_step(global_step_tensor):
diff --git a/tensorflow/contrib/framework/python/ops/variables_test.py b/tensorflow/contrib/framework/python/ops/variables_test.py
index 2f06df93ac..37ea6eb12a 100644
--- a/tensorflow/contrib/framework/python/ops/variables_test.py
+++ b/tensorflow/contrib/framework/python/ops/variables_test.py
@@ -1284,6 +1284,32 @@ class ZeroInitializerOpTest(test.TestCase):
[10, 20], dtype=dtype), use_init)
+class ZeroVarInitializerOpTest(test.TestCase):
+
+ def _testZeroVarInitializer(self, shape, initializer, use_init):
+ var = resource_variable_ops.ResourceVariable(initializer)
+ var_zero = variables_lib2.zero_initializer(var)
+
+ with self.test_session() as sess:
+ with self.assertRaisesOpError('Error while reading resource variable'):
+ var.eval()
+ if use_init:
+ sess.run(var.initializer)
+ with self.assertRaisesOpError('input is already initialized'):
+ var_zero.eval()
+ self.assertAllClose(np.ones(shape), var.eval())
+ else:
+ var_zero.eval()
+ self.assertAllClose(np.zeros(shape), var.eval())
+
+ def testZeroVarInitializer(self):
+ for dtype in (dtypes.int32, dtypes.int64, dtypes.float32, dtypes.float64):
+ for use_init in (False, True):
+ self._testZeroVarInitializer([10, 20],
+ array_ops.ones([10, 20], dtype=dtype),
+ use_init)
+
+
class FilterVariablesTest(test.TestCase):
def setUp(self):
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 1e8f011b5d..2458f7554a 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -247,7 +247,7 @@ class FusedConv2DBiasActivationOp : public OpKernel {
};
#if GOOGLE_CUDA
-namespace dnn = ::perftools::gputools::dnn;
+namespace dnn = se::dnn;
// A dummy type to group forward convolution autotune results together.
struct ConvBiasActivationAutoTuneGroup {
diff --git a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
index 47e51415fd..d914f54945 100644
--- a/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
+++ b/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py
@@ -488,25 +488,25 @@ def frechet_classifier_distance(real_images,
The Frechet Inception distance. A floating-point scalar of the same type
as the output of `classifier_fn`.
"""
-
real_images_list = array_ops.split(
real_images, num_or_size_splits=num_batches)
generated_images_list = array_ops.split(
generated_images, num_or_size_splits=num_batches)
- imgs = array_ops.stack(real_images_list + generated_images_list)
+ real_imgs = array_ops.stack(real_images_list)
+ generated_imgs = array_ops.stack(generated_images_list)
# Compute the activations using the memory-efficient `map_fn`.
- activations = functional_ops.map_fn(
- fn=classifier_fn,
- elems=imgs,
- parallel_iterations=1,
- back_prop=False,
- swap_memory=True,
- name='RunClassifier')
+ def compute_activations(elems):
+ return functional_ops.map_fn(fn=classifier_fn,
+ elems=elems,
+ parallel_iterations=1,
+ back_prop=False,
+ swap_memory=True,
+ name='RunClassifier')
- # Split the activations by the real and generated images.
- real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
+ real_a = compute_activations(real_imgs)
+ gen_a = compute_activations(generated_imgs)
# Ensure the activations have the right shapes.
real_a = array_ops.concat(array_ops.unstack(real_a), 0)
@@ -697,18 +697,20 @@ def frechet_classifier_distance_from_activations(real_activations,
# Compute mean and covariance matrices of activations.
m = math_ops.reduce_mean(real_activations, 0)
m_w = math_ops.reduce_mean(generated_activations, 0)
- num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
+ num_examples_generated = math_ops.to_double(
+ array_ops.shape(generated_activations)[0])
# sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
real_centered = real_activations - m
sigma = math_ops.matmul(
real_centered, real_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_real - 1)
gen_centered = generated_activations - m_w
sigma_w = math_ops.matmul(
gen_centered, gen_centered, transpose_a=True) / (
- num_examples - 1)
+ num_examples_generated - 1)
# Find the Tr(sqrt(sigma sigma_w)) component of FID
sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
diff --git a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc
index 645abbf0b0..bbb3a3b18f 100644
--- a/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc
+++ b/tensorflow/contrib/image/kernels/adjust_hsv_in_yiq_op_gpu.cu.cc
@@ -59,7 +59,7 @@ void AdjustHsvInYiqGPU::operator()(OpKernelContext* ctx, int channel_count,
delta_h, scale_s, scale_v, tranformation_matrix.flat<float>().data(),
tranformation_matrix.flat<float>().size());
// Call cuBlas C = A * B directly.
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
auto a_ptr =
AsDeviceMemory(input->flat<float>().data(), input->flat<float>().size());
auto b_ptr = AsDeviceMemory(tranformation_matrix.flat<float>().data(),
diff --git a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
index 2a3592c53f..432b67e569 100644
--- a/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
+++ b/tensorflow/contrib/kfac/python/kernel_tests/fisher_factors_test.py
@@ -814,6 +814,21 @@ class ConvInputKroneckerFactorTest(ConvFactorTestCase):
new_cov = sess.run(factor.make_covariance_update_op(0.))
self.assertAllClose([[(1. + 4.) / 2.]], new_cov)
+ def testSubSample(self):
+ with tf_ops.Graph().as_default():
+ patches_1 = array_ops.constant(1, shape=(10, 2))
+ patches_2 = array_ops.constant(1, shape=(10, 8))
+ patches_3 = array_ops.constant(1, shape=(3, 3))
+ patches_1_sub = ff._subsample_for_cov_computation(patches_1)
+ patches_2_sub = ff._subsample_for_cov_computation(patches_2)
+ patches_3_sub = ff._subsample_for_cov_computation(patches_3)
+ patches_1_sub_batch_size = patches_1_sub.shape.as_list()[0]
+ patches_2_sub_batch_size = patches_2_sub.shape.as_list()[0]
+ patches_3_sub_batch_size = patches_3_sub.shape.as_list()[0]
+ self.assertEqual(2, patches_1_sub_batch_size)
+ self.assertEqual(8, patches_2_sub_batch_size)
+ self.assertEqual(3, patches_3_sub_batch_size)
+
class ConvOutputKroneckerFactorTest(ConvFactorTestCase):
diff --git a/tensorflow/contrib/kfac/python/ops/BUILD b/tensorflow/contrib/kfac/python/ops/BUILD
index b897fd68a0..cb0917bb85 100644
--- a/tensorflow/contrib/kfac/python/ops/BUILD
+++ b/tensorflow/contrib/kfac/python/ops/BUILD
@@ -37,10 +37,13 @@ py_library(
deps = [
":utils",
"//tensorflow/python:array_ops",
+ "//tensorflow/python:control_flow_ops",
+ "//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:linalg_ops",
"//tensorflow/python:math_ops",
+ "//tensorflow/python:random_ops",
"//tensorflow/python:special_math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index 0d40d265a1..7988a3b92b 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -32,6 +32,7 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@@ -55,6 +56,22 @@ EIGENVALUE_DECOMPOSITION_THRESHOLD = 2
# matrix powers. Must be nonnegative.
EIGENVALUE_CLIPPING_THRESHOLD = 0.0
+# Used to subsample the flattened extracted image patches. The number of
+# outer products per row of the covariance matrix should not exceed this
+# value. This parameter is used only if `_SUB_SAMPLE_OUTER_PRODUCTS` is True.
+_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1
+
+# Used to subsample the inputs used to extract image patches. The batch size
+# of the inputs used to extract image patches is multiplied by this factor.
+# This parameter is used only if `_SUB_SAMPLE_INPUTS` is True.
+_INPUTS_TO_EXTRACT_PATCHES_FACTOR = 0.5
+
+# If True, subsamples the flattened extracted image patches used to compute
+# the covariance matrix.
+_SUB_SAMPLE_OUTER_PRODUCTS = False
+
+# If True, subsamples the inputs used to extract image patches.
+_SUB_SAMPLE_INPUTS = False
+
# TOWER_STRATEGY can be one of "concat" or "separate". If "concat", the data
# passed to the factors from the blocks will be concatenated across towers
# (lazilly via PartitionedTensor objects). Otherwise a tuple of tensors over
@@ -67,12 +84,20 @@ def set_global_constants(init_covariances_at_zero=None,
zero_debias=None,
eigenvalue_decomposition_threshold=None,
eigenvalue_clipping_threshold=None,
+ max_num_outer_products_per_cov_row=None,
+ sub_sample_outer_products=None,
+ inputs_to_extract_patches_factor=None,
+ sub_sample_inputs=None,
tower_strategy=None):
"""Sets various global constants used by the classes in this module."""
global INIT_COVARIANCES_AT_ZERO
global ZERO_DEBIAS
global EIGENVALUE_DECOMPOSITION_THRESHOLD
global EIGENVALUE_CLIPPING_THRESHOLD
+ global _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW
+ global _SUB_SAMPLE_OUTER_PRODUCTS
+ global _INPUTS_TO_EXTRACT_PATCHES_FACTOR
+ global _SUB_SAMPLE_INPUTS
global TOWER_STRATEGY
if init_covariances_at_zero is not None:
@@ -83,6 +108,14 @@ def set_global_constants(init_covariances_at_zero=None,
EIGENVALUE_DECOMPOSITION_THRESHOLD = eigenvalue_decomposition_threshold
if eigenvalue_clipping_threshold is not None:
EIGENVALUE_CLIPPING_THRESHOLD = eigenvalue_clipping_threshold
+ if max_num_outer_products_per_cov_row is not None:
+ _MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = max_num_outer_products_per_cov_row
+ if sub_sample_outer_products is not None:
+ _SUB_SAMPLE_OUTER_PRODUCTS = sub_sample_outer_products
+ if inputs_to_extract_patches_factor is not None:
+ _INPUTS_TO_EXTRACT_PATCHES_FACTOR = inputs_to_extract_patches_factor
+ if sub_sample_inputs is not None:
+ _SUB_SAMPLE_INPUTS = sub_sample_inputs
if tower_strategy is not None:
TOWER_STRATEGY = tower_strategy
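As a quick orientation for reviewers, here is a minimal sketch of how the new keyword arguments added to set_global_constants above could be exercised; the parameter names come from this hunk, while the chosen values and the alias `ff` are illustrative only:

    from tensorflow.contrib.kfac.python.ops import fisher_factors as ff

    # Turn on both kinds of subsampling and allow up to 2 outer products per
    # row of the covariance matrix.
    ff.set_global_constants(
        sub_sample_outer_products=True,
        max_num_outer_products_per_cov_row=2,
        sub_sample_inputs=True,
        inputs_to_extract_patches_factor=0.5)
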
@@ -227,6 +260,58 @@ def graph_func_to_string(func):
return list_to_string(func.func_id)
+def _subsample_for_cov_computation(array, name=None):
+ """Subsamples the first dimension of the array.
+
+ `array`(A) is a tensor of shape `[batch_size, dim_2]`. Then the covariance
+ matrix(A^TA) is of shape `dim_2 ** 2`. Subsample only if the number of outer
+ products per row of the covariance matrix is greater than
+ `_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW`.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+ name: `string`, Default(None)
+
+ Returns:
+ A tensor of shape `[max_samples, dim_2]`.
+
+ Raises:
+ ValueError: If array's is not matrix-shaped.
+ ValueError: If array's batch_size cannot be inferred.
+
+ """
+ with tf_ops.name_scope(name, "subsample", [array]):
+ array = tf_ops.convert_to_tensor(array)
+ if len(array.shape) != 2:
+ raise ValueError("Input param array must be a matrix.")
+
+ batch_size = array.shape.as_list()[0]
+ if batch_size is None:
+ raise ValueError("Unable to get batch_size from input param array.")
+
+ num_cov_rows = array.shape.as_list()[-1]
+ max_batch_size = int(_MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
+ if batch_size <= max_batch_size:
+ return array
+
+ return _random_tensor_gather(array, max_batch_size)
+
+
+def _random_tensor_gather(array, max_size):
+ """Generates a random set of indices and gathers the value at the indcices.
+
+ Args:
+ array: Tensor, of shape `[batch_size, dim_2]`.
+ max_size: int, Number of indices to sample.
+
+ Returns:
+ A tensor of shape `[max_size, ...]`.
+ """
+ batch_size = array.shape.as_list()[0]
+ indices = random_ops.random_shuffle(math_ops.range(0, batch_size))[:max_size]
+ return array_ops.gather(array, indices)
+
+
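For clarity, the decision rule implemented by the two helpers above can be restated in plain NumPy; this is only an illustration of the arithmetic (the names mirror the TF code, not the implementation itself), and it reproduces the shapes checked by testSubSample earlier: with the default limit of 1, a (10, 2) input keeps 2 rows, a (10, 8) input keeps 8 rows, and a (3, 3) input is returned unchanged.

    import numpy as np

    MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW = 1  # default from the constants above

    def subsample_for_cov(array):
      batch_size, num_cov_rows = array.shape
      max_batch_size = int(MAX_NUM_OUTER_PRODUCTS_PER_COV_ROW * num_cov_rows)
      if batch_size <= max_batch_size:
        return array
      indices = np.random.permutation(batch_size)[:max_batch_size]
      return array[indices]

    for shape in [(10, 2), (10, 8), (3, 3)]:
      print(shape, "->", subsample_for_cov(np.ones(shape)).shape[0], "rows kept")
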
@six.add_metaclass(abc.ABCMeta)
class FisherFactor(object):
"""Base class for objects modeling factors of approximate Fisher blocks.
@@ -1153,7 +1238,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
dilation_rate=None,
data_format=None,
extract_patches_fn=None,
- has_bias=False):
+ has_bias=False,
+ sub_sample_inputs=None,
+ sub_sample_patches=None):
"""Initializes ConvInputKroneckerFactor.
Args:
@@ -1173,6 +1260,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
patches. One of "extract_convolution_patches", "extract_image_patches",
"extract_pointwise_conv2d_patches".
has_bias: bool. If True, append 1 to in_channel.
+ sub_sample_inputs: `bool`. If True, then subsample the inputs from which
+ the image patches are extracted. (Default: None)
+ sub_sample_patches: `bool`. If True, then subsample the extracted
+ patches. (Default: None)
"""
self._inputs = inputs
self._filter_shape = filter_shape
@@ -1182,7 +1273,15 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
self._data_format = data_format
self._extract_patches_fn = extract_patches_fn
self._has_bias = has_bias
+ if sub_sample_inputs is None:
+ self._sub_sample_inputs = _SUB_SAMPLE_INPUTS
+ else:
+ self._sub_sample_inputs = sub_sample_inputs
+ if sub_sample_patches is None:
+ self._sub_sample_patches = _SUB_SAMPLE_OUTER_PRODUCTS
+ else:
+ self._sub_sample_patches = sub_sample_patches
super(ConvInputKroneckerFactor, self).__init__()
@property
@@ -1215,6 +1314,10 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
assert source == 0
inputs = self._inputs[tower]
+ if self._sub_sample_inputs:
+ batch_size = inputs.shape.as_list()[0]
+ max_size = int(batch_size * _INPUTS_TO_EXTRACT_PATCHES_FACTOR)
+ inputs = _random_tensor_gather(inputs, max_size)
# TODO(b/64144716): there is potential here for a big savings in terms of
# memory use.
@@ -1260,8 +1363,12 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# |Delta| = number of spatial offsets, and J = number of input maps
# for convolutional layer l.
patches_flat = array_ops.reshape(patches, [-1, flatten_size])
+
# We append a homogenous coordinate to patches_flat if the layer has
# bias parameters. This gives us [[A_l]]_H from the paper.
+ if self._sub_sample_patches:
+ patches_flat = _subsample_for_cov_computation(patches_flat)
+
if self._has_bias:
patches_flat = append_homog(patches_flat)
# We call compute_cov without passing in a normalizer. compute_cov uses
diff --git a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
index d81a534b79..9e5aaf3118 100644
--- a/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
+++ b/tensorflow/contrib/learn/python/learn/estimators/estimator_test.py
@@ -715,7 +715,9 @@ class EstimatorTest(test.TestCase):
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
- self.assertAllEqual(['model.ckpt-1', 'model.ckpt-5'],
+ # TODO(b/78461127): Please modify tests to not directly rely on names of
+ # checkpoints.
+ self.assertAllEqual(['model.ckpt-0', 'model.ckpt-5'],
ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
diff --git a/tensorflow/contrib/linalg/__init__.py b/tensorflow/contrib/linalg/__init__.py
index 38bd66b13f..554854da84 100644
--- a/tensorflow/contrib/linalg/__init__.py
+++ b/tensorflow/contrib/linalg/__init__.py
@@ -18,6 +18,9 @@ See the @{$python/contrib.linalg} guide.
@@LinearOperator
@@LinearOperatorBlockDiag
+@@LinearOperatorCirculant
+@@LinearOperatorCirculant2D
+@@LinearOperatorCirculant3D
@@LinearOperatorDiag
@@LinearOperatorIdentity
@@LinearOperatorScaledIdentity
@@ -39,6 +42,7 @@ from tensorflow.contrib.linalg.python.ops.linear_operator_addition import *
from tensorflow.contrib.linalg.python.ops.linear_operator_block_diag import *
from tensorflow.contrib.linalg.python.ops.linear_operator_kronecker import *
from tensorflow.python.ops.linalg.linear_operator import *
+from tensorflow.python.ops.linalg.linear_operator_circulant import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
diff --git a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
index 6e6c812adc..b5741967ab 100644
--- a/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
+++ b/tensorflow/contrib/linear_optimizer/python/kernel_tests/sdca_ops_test.py
@@ -39,8 +39,8 @@ from tensorflow.python.ops import variables as variables_lib
from tensorflow.python.platform import googletest
_MAX_ITERATIONS = 100
-_SHARD_NUMBERS = [None, 1, 3, 10]
-_NUM_LOSS_PARTITIONS = [2, 4]
+_SHARD_NUMBERS = [None, 1, 3]
+_NUM_LOSS_PARTITIONS = [4]
def make_example_proto(feature_dict, target, value=1.0):
diff --git a/tensorflow/contrib/lite/context.h b/tensorflow/contrib/lite/context.h
index 0b38f43cd3..12841d233c 100644
--- a/tensorflow/contrib/lite/context.h
+++ b/tensorflow/contrib/lite/context.h
@@ -275,7 +275,7 @@ typedef struct {
typedef struct TfLiteContext {
// Number of tensors in the context.
- int tensors_size;
+ size_t tensors_size;
// The execution plan contains a list of the node indices in execution
// order. execution_plan->size is the current number of nodes. And,
@@ -397,13 +397,13 @@ typedef struct _TfLiteDelegate {
// This can be null if the delegate doesn't use its own buffer.
TfLiteStatus (*CopyFromBufferHandle)(TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
- void* data, int size);
+ void* data, size_t size);
// Copy the data from raw memory to delegate buffer handle.
// This can be null if the delegate doesn't use its own buffer.
TfLiteStatus (*CopyToBufferHandle)(TfLiteDelegate* delegate,
TfLiteBufferHandle buffer_handle,
- void* data, int size);
+ void* data, size_t size);
// Free the Delegate Buffer Handle. Note: This only frees the handle, but
// this doesn't release the underlying resource (e.g. textures). The
diff --git a/tensorflow/contrib/lite/interpreter.cc b/tensorflow/contrib/lite/interpreter.cc
index 91b6c414bf..9d8ea55fd1 100644
--- a/tensorflow/contrib/lite/interpreter.cc
+++ b/tensorflow/contrib/lite/interpreter.cc
@@ -308,7 +308,12 @@ TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
for (int i = 0; i < length; i++) {
int index = indices[i];
- if (index < kOptionalTensor || index >= context_.tensors_size) {
+ // Continue if index == kOptionalTensor before the comparisons below, since
+ // size_t(-1) is always >= context_.tensors_size.
+ if (index == kOptionalTensor) {
+ continue;
+ }
+ if (index < 0 || static_cast<size_t>(index) >= context_.tensors_size) {
ReportError(&context_, "Invalid tensor index %d in %s\n", index, label);
consistent_ = false;
return kTfLiteError;
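The reordering above matters because tensors_size is now a size_t: in the old combined check, an index of kOptionalTensor (-1) would be converted to the unsigned type before the >= comparison and therefore look out of range. A small Python illustration of that wraparound, assuming a 64-bit size_t:

    tensors_size = 10
    k_optional_tensor = -1

    # In C++, comparing an int against a size_t converts the int to the
    # unsigned type first; -1 wraps to the largest representable value.
    promoted = k_optional_tensor % 2**64
    print(promoted >= tensors_size)  # True, so the optional tensor would be
                                     # rejected unless it is handled first
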
@@ -318,7 +323,7 @@ TfLiteStatus Interpreter::CheckTensorIndices(const char* label,
}
TfLiteStatus Interpreter::BytesRequired(TfLiteType type, const int* dims,
- int dims_size, size_t* bytes) {
+ size_t dims_size, size_t* bytes) {
// TODO(aselle): Check for overflow here using overflow.h in TensorFlow
// MultiplyWithoutOverflow.
TF_LITE_ENSURE(&context_, bytes != nullptr);
@@ -645,7 +650,7 @@ TfLiteStatus Interpreter::GetNodeAndRegistration(
}
TfLiteStatus Interpreter::SetTensorParametersReadOnly(
- int tensor_index, TfLiteType type, const char* name, const int rank,
+ int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization, const char* buffer,
size_t bytes, const Allocation* allocation) {
if (state_ == kStateInvokableAndImmutable) {
@@ -691,7 +696,7 @@ TfLiteStatus Interpreter::SetTensorParametersReadOnly(
// bytes. The lifetime of buffer must be ensured to be greater or equal
// to Interpreter.
TfLiteStatus Interpreter::SetTensorParametersReadWrite(
- int tensor_index, TfLiteType type, const char* name, const int rank,
+ int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization) {
if (state_ == kStateInvokableAndImmutable) {
ReportError(
diff --git a/tensorflow/contrib/lite/interpreter.h b/tensorflow/contrib/lite/interpreter.h
index a49134b95e..6f3433abcf 100644
--- a/tensorflow/contrib/lite/interpreter.h
+++ b/tensorflow/contrib/lite/interpreter.h
@@ -150,7 +150,7 @@ class Interpreter {
};
TfLiteStatus SetTensorParametersReadOnly(
- int tensor_index, TfLiteType type, const char* name, const int rank,
+ int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization,
const char* buffer, size_t bytes, const Allocation* allocation = nullptr);
@@ -165,7 +165,7 @@ class Interpreter {
dims.data(), quantization);
}
TfLiteStatus SetTensorParametersReadWrite(
- int tensor_index, TfLiteType type, const char* name, const int rank,
+ int tensor_index, TfLiteType type, const char* name, const size_t rank,
const int* dims, TfLiteQuantizationParams quantization);
// Functions to access tensor data
@@ -189,10 +189,10 @@ class Interpreter {
}
// Return the number of tensors in the model.
- int tensors_size() const { return context_.tensors_size; }
+ size_t tensors_size() const { return context_.tensors_size; }
// Return the number of ops in the model.
- int nodes_size() const { return nodes_and_registration_.size(); }
+ size_t nodes_size() const { return nodes_and_registration_.size(); }
// WARNING: Experimental interface, subject to change
const std::vector<int>& execution_plan() const { return execution_plan_; }
@@ -406,7 +406,7 @@ class Interpreter {
// Compute the number of bytes required to represent a tensor with dimensions
// specified by the array dims (of length dims_size). Returns the status code
// and bytes.
- TfLiteStatus BytesRequired(TfLiteType type, const int* dims, int dims_size,
+ TfLiteStatus BytesRequired(TfLiteType type, const int* dims, size_t dims_size,
size_t* bytes);
// Request an tensor be resized implementation. If the given tensor is of
@@ -467,7 +467,7 @@ class Interpreter {
// tensors. After calling this function, adding `kTensorsCapacityHeadroom`
// more tensors won't invalidate the pointer to existing tensors.
void EnsureTensorsVectorCapacity() {
- const int required_capacity = tensors_size() + kTensorsCapacityHeadroom;
+ const size_t required_capacity = tensors_size() + kTensorsCapacityHeadroom;
if (required_capacity > tensors_.capacity()) {
tensors_.reserve(required_capacity);
context_.tensors = tensors_.data();
diff --git a/tensorflow/contrib/lite/interpreter_test.cc b/tensorflow/contrib/lite/interpreter_test.cc
index 131e088079..453c1ada1c 100644
--- a/tensorflow/contrib/lite/interpreter_test.cc
+++ b/tensorflow/contrib/lite/interpreter_test.cc
@@ -887,15 +887,15 @@ class TestDelegate : public ::testing::Test {
TfLiteIntArrayFree(nodes_to_separate);
return kTfLiteOk;
};
- delegate_.CopyToBufferHandle = [](TfLiteDelegate* delegate,
- TfLiteBufferHandle buffer_handle,
- void* data, int size) -> TfLiteStatus {
+ delegate_.CopyToBufferHandle =
+ [](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle,
+ void* data, size_t size) -> TfLiteStatus {
// TODO(ycling): Implement tests to test buffer copying logic.
return kTfLiteOk;
};
delegate_.CopyFromBufferHandle =
[](TfLiteDelegate* delegate, TfLiteBufferHandle buffer_handle,
- void* data, int size) -> TfLiteStatus {
+ void* data, size_t size) -> TfLiteStatus {
// TODO(ycling): Implement tests to test buffer copying logic.
return kTfLiteOk;
};
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
index 18f6465188..4f5662bc2d 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/Camera2BasicFragment.java
@@ -47,6 +47,8 @@ import android.os.HandlerThread;
import android.support.annotation.NonNull;
import android.support.v13.app.FragmentCompat;
import android.support.v4.content.ContextCompat;
+import android.text.SpannableString;
+import android.text.SpannableStringBuilder;
import android.util.Log;
import android.util.Size;
import android.view.LayoutInflater;
@@ -207,14 +209,21 @@ public class Camera2BasicFragment extends Fragment
*
* @param text The message to show
*/
- private void showToast(final String text) {
+ private void showToast(String s) {
+ SpannableStringBuilder builder = new SpannableStringBuilder();
+ SpannableString str1 = new SpannableString(s);
+ builder.append(str1);
+ showToast(builder);
+ }
+
+ private void showToast(SpannableStringBuilder builder) {
final Activity activity = getActivity();
if (activity != null) {
activity.runOnUiThread(
new Runnable() {
@Override
public void run() {
- textView.setText(text);
+ textView.setText(builder, TextView.BufferType.SPANNABLE);
}
});
}
@@ -682,8 +691,9 @@ public class Camera2BasicFragment extends Fragment
showToast("Uninitialized Classifier or invalid context.");
return;
}
+ SpannableStringBuilder textToShow = new SpannableStringBuilder();
Bitmap bitmap = textureView.getBitmap(classifier.getImageSizeX(), classifier.getImageSizeY());
- String textToShow = classifier.classifyFrame(bitmap);
+ classifier.classifyFrame(bitmap, textToShow);
bitmap.recycle();
showToast(textToShow);
}
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
index d32c077910..7bb6afd9d8 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/java/com/example/android/tflitecamerademo/ImageClassifier.java
@@ -19,10 +19,11 @@ import android.app.Activity;
import android.content.res.AssetFileDescriptor;
import android.graphics.Bitmap;
import android.os.SystemClock;
+import android.text.SpannableString;
+import android.text.SpannableStringBuilder;
+import android.text.style.ForegroundColorSpan;
+import android.text.style.RelativeSizeSpan;
import android.util.Log;
-
-import org.tensorflow.lite.Interpreter;
-
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
@@ -37,11 +38,15 @@ import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
+import org.tensorflow.lite.Interpreter;
/**
* Classifies images with Tensorflow Lite.
*/
public abstract class ImageClassifier {
+ // Display preferences
+ private static final float GOOD_PROB_THRESHOLD = 0.3f;
+ private static final int SMALL_COLOR = 0xffddaa88;
/** Tag for the {@link Log}. */
private static final String TAG = "TfLiteCameraDemo";
@@ -99,10 +104,12 @@ public abstract class ImageClassifier {
}
/** Classifies a frame from the preview stream. */
- String classifyFrame(Bitmap bitmap) {
+ void classifyFrame(Bitmap bitmap, SpannableStringBuilder builder) {
+ printTopKLabels(builder);
+
if (tflite == null) {
Log.e(TAG, "Image classifier has not been initialized; Skipped.");
- return "Uninitialized Classifier.";
+ builder.append(new SpannableString("Uninitialized Classifier."));
}
convertBitmapToByteBuffer(bitmap);
// Here's where the magic happens!!!
@@ -115,9 +122,10 @@ public abstract class ImageClassifier {
applyFilter();
// Print the results.
- String textToShow = printTopKLabels();
- textToShow = Long.toString(endTime - startTime) + "ms" + textToShow;
- return textToShow;
+ long duration = endTime - startTime;
+ SpannableString span = new SpannableString(duration + " ms");
+ span.setSpan(new ForegroundColorSpan(android.graphics.Color.LTGRAY), 0, span.length(), 0);
+ builder.append(span);
}
void applyFilter() {
@@ -202,7 +210,7 @@ public abstract class ImageClassifier {
}
/** Prints top-K labels, to be shown in UI as the results. */
- private String printTopKLabels() {
+ private void printTopKLabels(SpannableStringBuilder builder) {
for (int i = 0; i < getNumLabels(); ++i) {
sortedLabels.add(
new AbstractMap.SimpleEntry<>(labelList.get(i), getNormalizedProbability(i)));
@@ -210,13 +218,27 @@ public abstract class ImageClassifier {
sortedLabels.poll();
}
}
- String textToShow = "";
+
final int size = sortedLabels.size();
- for (int i = 0; i < size; ++i) {
+ for (int i = 0; i < size; i++) {
Map.Entry<String, Float> label = sortedLabels.poll();
- textToShow = String.format("\n%s: %4.2f", label.getKey(), label.getValue()) + textToShow;
+ SpannableString span =
+ new SpannableString(String.format("%s: %4.2f\n", label.getKey(), label.getValue()));
+ int color;
+ // Make it white when the probability is larger than the threshold.
+ if (label.getValue() > GOOD_PROB_THRESHOLD) {
+ color = android.graphics.Color.WHITE;
+ } else {
+ color = SMALL_COLOR;
+ }
+ // Make the top label (inserted last, shown first) bigger.
+ if (i == size - 1) {
+ span.setSpan(new RelativeSizeSpan(1.25f), 0, span.length(), 0);
+ }
+ span.setSpan(new ForegroundColorSpan(color), 0, span.length(), 0);
+ builder.insert(0, span);
}
- return textToShow;
}
/**
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
index c22509d8df..52cf2ab952 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-hdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
index d68af39186..b75f892c46 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-mdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
index 15e419b7cc..36e14c48d1 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
index 342ce34e16..06dd2a740e 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/ic_launcher.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png
new file mode 100644
index 0000000000..b94bcfc081
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/drawable-xxhdpi/logo.png
Binary files differ
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
index a84f1bbfa0..20f520814d 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-land/fragment_camera2_basic.xml
@@ -14,37 +14,50 @@
limitations under the License.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:app="http://schemas.android.com/apk/res-auto"
android:layout_width="match_parent"
android:layout_height="match_parent">
- <com.example.android.tflitecamerademo.AutoFitTextureView
- android:id="@+id/texture"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
- android:layout_alignParentBottom="true"
- android:layout_alignParentStart="true"
- android:layout_alignParentTop="true" />
-
- <FrameLayout
- android:id="@+id/control"
+ <LinearLayout
android:layout_width="match_parent"
- android:layout_height="wrap_content"
- android:layout_alignParentBottom="true"
- android:layout_alignParentEnd="true"
- android:layout_alignParentTop="true"
- android:layout_toRightOf="@id/texture"
- android:background="@color/control_background"
- android:orientation="horizontal">
-
- <TextView android:id="@+id/text"
- android:layout_width="wrap_content"
- android:layout_height="wrap_content"
+ android:layout_height="match_parent"
+ android:background="#bb7700"
+ android:orientation="horizontal"
+ android:weightSum="100">
+
+ <LinearLayout
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_weight="30"
+ android:orientation="vertical">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_weight="100" />
+
+ <ImageView
+ android:id="@+id/logoview"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:layout_weight="100"
+ android:scaleType="centerCrop"
+ android:src="@drawable/logo" />
+
+ </LinearLayout>
+
+ <TextView
+ android:id="@+id/text"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_weight="70"
+ android:paddingLeft="5dp"
android:paddingTop="20dp"
android:textColor="#FFF"
android:textSize="20sp"
android:textStyle="bold" />
-
- </FrameLayout>
+ </LinearLayout>
</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml
new file mode 100644
index 0000000000..72a229ecdb
--- /dev/null
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout-v26/fragment_camera2_basic.xml
@@ -0,0 +1,88 @@
+<?xml version="1.0" encoding="utf-8"?><!--
+ Copyright 2014 The Android Open Source Project
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:app="http://schemas.android.com/apk/res-auto"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:background="#bb7700">
+
+ <com.example.android.tflitecamerademo.AutoFitTextureView
+ android:id="@+id/texture"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_weight="1" />
+
+ <LinearLayout
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="false"
+ android:layout_alignParentStart="true"
+ android:layout_alignParentTop="false"
+ android:background="#bb7700"
+ android:orientation="vertical"
+ android:weightSum="100">
+
+ <ImageView
+ android:id="@+id/logoview2"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_weight="30"
+ android:scaleType="fitStart"
+ android:src="@drawable/logo" />
+
+ <TextView
+ android:id="@+id/text"
+ android:layout_width="match_parent"
+ android:layout_height="wrap_content"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentEnd="true"
+ android:layout_alignParentRight="true"
+ android:layout_weight="30"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+
+ </LinearLayout>
+
+ <RelativeLayout
+ android:id="@+id/control2"
+ android:layout_width="match_parent"
+ android:layout_height="135dp"
+ android:layout_alignParentLeft="true"
+ android:layout_alignParentStart="true"
+ android:layout_alignTop="@+id/control"
+ android:layout_marginLeft="300dp"
+ android:layout_marginStart="300dp"
+ android:background="#bb7700">
+
+ <ToggleButton
+ android:id="@+id/button"
+ android:textOff="@string/tflite"
+ android:textOn="@string/nnapi"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_alignParentLeft="true"
+ android:layout_alignParentStart="true" />
+
+ <NumberPicker
+ android:id="@+id/np"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_below="@+id/button"
+ android:visibility="visible" />
+ </RelativeLayout>
+</RelativeLayout>
diff --git a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
index db557ad62f..d12435d5ab 100644
--- a/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
+++ b/tensorflow/contrib/lite/java/demo/app/src/main/res/layout/fragment_camera2_basic.xml
@@ -14,9 +14,30 @@
limitations under the License.
-->
<RelativeLayout xmlns:android="http://schemas.android.com/apk/res/android"
+ xmlns:app="http://schemas.android.com/apk/res-auto"
+ xmlns:tools="http://schemas.android.com/tools"
android:layout_width="match_parent"
android:layout_height="match_parent">
+ <LinearLayout
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:orientation="vertical"
+ android:weightSum="60">
+
+ <FrameLayout
+ android:id="@+id/control"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_alignParentBottom="true"
+ android:layout_alignParentStart="true"
+ android:layout_weight="60"
+ android:background="#cc7700"
+ android:paddingLeft="20dp"
+ android:paddingStart="20dp">
+
+ </FrameLayout>
+
<com.example.android.tflitecamerademo.AutoFitTextureView
android:id="@+id/texture"
android:layout_width="wrap_content"
@@ -25,29 +46,43 @@
android:layout_alignParentLeft="true"
android:layout_alignParentTop="true" />
- <FrameLayout
- android:id="@+id/control"
+ <TextView
+ android:id="@+id/text"
+ android:layout_width="match_parent"
+ android:layout_height="match_parent"
+ android:layout_weight="20"
+ android:textColor="#FFF"
+ android:textSize="20sp"
+ android:textStyle="bold" />
+ </LinearLayout>
+
+ <RelativeLayout
+ android:id="@+id/control2"
android:layout_width="match_parent"
android:layout_height="135dp"
- android:layout_alignParentBottom="true"
- android:layout_alignParentStart="true"
android:layout_alignParentLeft="true"
- android:layout_alignParentEnd="true"
- android:layout_alignParentRight="true"
- android:layout_marginEnd="150dp"
- android:layout_marginRight="150dp"
- android:background="@color/control_background">
+ android:layout_alignParentStart="true"
+ android:layout_alignTop="@+id/control"
+ android:layout_marginLeft="300dp"
+ android:layout_marginStart="300dp"
+ android:background="#bb7700">
- <TextView
- android:id="@+id/text"
+ <ToggleButton
+ android:id="@+id/button"
+ android:textOff="@string/tflite"
+ android:textOn="@string/nnapi"
android:layout_width="wrap_content"
android:layout_height="wrap_content"
- android:paddingLeft="20dp"
- android:textColor="#FFF"
- android:textSize="20sp"
- android:textStyle="bold" />
+ android:layout_alignParentLeft="true"
+ android:layout_alignParentStart="true" />
- </FrameLayout>
+ <NumberPicker
+ android:id="@+id/np"
+ android:layout_width="wrap_content"
+ android:layout_height="wrap_content"
+ android:layout_below="@+id/button"
+ android:visibility="visible" />
+ </RelativeLayout>
<RelativeLayout
android:id="@+id/control2"
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
index d27c6ccf3d..9e9aba0169 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/optimized_ops.h
@@ -1203,13 +1203,330 @@ void FullyConnected(const uint8* input_data, const Dims<4>& input_dims,
output_activation_max, output_data, output_dims, gemm_context);
}
+// Internal function doing the actual arithmetic work for
+// ExperimentalShuffledFullyConnected.
+// May be called either directly by it (single-threaded case) or may be used
+// as the 'task' for worker threads to run (multi-threaded case, see
+// ExperimentalShuffledFullyConnectedWorkerTask below).
+inline void ExperimentalShuffledFullyConnectedWorkerImpl(
+ const uint8* shuffled_input_workspace_data,
+ const int8* shuffled_weights_data, int batches, int output_depth,
+ int output_stride, int accum_depth, const int32* bias_data,
+ int32 output_multiplier, int output_shift, int16* output_data) {
+#if defined USE_NEON
+ const int8* shuffled_weights_ptr = shuffled_weights_data;
+ if (batches == 1) {
+ const int right_shift = output_shift > 0 ? output_shift : 0;
+ const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ for (int c = 0; c < output_depth; c += 4) {
+ // Accumulation loop.
+ int32x4_t row_accum0 = vdupq_n_s32(0);
+ int32x4_t row_accum1 = vdupq_n_s32(0);
+ int32x4_t row_accum2 = vdupq_n_s32(0);
+ int32x4_t row_accum3 = vdupq_n_s32(0);
+ for (int d = 0; d < accum_depth; d += 16) {
+ int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
+ int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
+ int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
+ int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
+ shuffled_weights_ptr += 64;
+ int8x16_t input =
+ vreinterpretq_s8_u8(vld1q_u8(shuffled_input_workspace_data + d));
+ int16x8_t local_accum0 =
+ vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
+ int16x8_t local_accum1 =
+ vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
+ int16x8_t local_accum2 =
+ vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
+ int16x8_t local_accum3 =
+ vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
+ local_accum0 =
+ vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
+ local_accum1 =
+ vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
+ local_accum2 =
+ vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
+ local_accum3 =
+ vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
+ row_accum0 = vpadalq_s16(row_accum0, local_accum0);
+ row_accum1 = vpadalq_s16(row_accum1, local_accum1);
+ row_accum2 = vpadalq_s16(row_accum2, local_accum2);
+ row_accum3 = vpadalq_s16(row_accum3, local_accum3);
+ }
+ // Horizontally reduce accumulators
+ int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
+ pairwise_reduced_acc_2, pairwise_reduced_acc_3;
+ pairwise_reduced_acc_0 =
+ vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
+ pairwise_reduced_acc_1 =
+ vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
+ pairwise_reduced_acc_2 =
+ vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
+ pairwise_reduced_acc_3 =
+ vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
+ const int32x2_t reduced_lo =
+ vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
+ const int32x2_t reduced_hi =
+ vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
+ int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
+ // Add bias values.
+ int32x4_t bias_vec = vld1q_s32(bias_data + c);
+ reduced = vaddq_s32(reduced, bias_vec);
+ reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
+ // Multiply by the fixed-point multiplier.
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
+ // Rounding-shift-right.
+ using gemmlowp::RoundingDivideByPOT;
+ reduced = RoundingDivideByPOT(reduced, right_shift);
+ // Narrow values down to 16 bit signed.
+ const int16x4_t res16 = vqmovn_s32(reduced);
+ vst1_s16(output_data + c, res16);
+ }
+ } else if (batches == 4) {
+ const int right_shift = output_shift > 0 ? output_shift : 0;
+ const int left_shift = output_shift > 0 ? 0 : -output_shift;
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ // Accumulation loop.
+ int32x4_t row_accum00 = vdupq_n_s32(0);
+ int32x4_t row_accum10 = vdupq_n_s32(0);
+ int32x4_t row_accum20 = vdupq_n_s32(0);
+ int32x4_t row_accum30 = vdupq_n_s32(0);
+ int32x4_t row_accum01 = vdupq_n_s32(0);
+ int32x4_t row_accum11 = vdupq_n_s32(0);
+ int32x4_t row_accum21 = vdupq_n_s32(0);
+ int32x4_t row_accum31 = vdupq_n_s32(0);
+ int32x4_t row_accum02 = vdupq_n_s32(0);
+ int32x4_t row_accum12 = vdupq_n_s32(0);
+ int32x4_t row_accum22 = vdupq_n_s32(0);
+ int32x4_t row_accum32 = vdupq_n_s32(0);
+ int32x4_t row_accum03 = vdupq_n_s32(0);
+ int32x4_t row_accum13 = vdupq_n_s32(0);
+ int32x4_t row_accum23 = vdupq_n_s32(0);
+ int32x4_t row_accum33 = vdupq_n_s32(0);
+ for (int d = 0; d < accum_depth; d += 16) {
+ int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
+ int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
+ int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
+ int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
+ shuffled_weights_ptr += 64;
+ int8x16_t input0 = vld1q_s8(shuffled_input_ptr + 0);
+ int8x16_t input1 = vld1q_s8(shuffled_input_ptr + 16);
+ int8x16_t input2 = vld1q_s8(shuffled_input_ptr + 32);
+ int8x16_t input3 = vld1q_s8(shuffled_input_ptr + 48);
+ shuffled_input_ptr += 64;
+ int16x8_t local_accum0, local_accum1, local_accum2, local_accum3;
+#define TFLITE_SHUFFLED_FC_ACCUM(B) \
+ local_accum0 = vmull_s8(vget_low_s8(weights0), vget_low_s8(input##B)); \
+ local_accum1 = vmull_s8(vget_low_s8(weights1), vget_low_s8(input##B)); \
+ local_accum2 = vmull_s8(vget_low_s8(weights2), vget_low_s8(input##B)); \
+ local_accum3 = vmull_s8(vget_low_s8(weights3), vget_low_s8(input##B)); \
+ local_accum0 = \
+ vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input##B)); \
+ local_accum1 = \
+ vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input##B)); \
+ local_accum2 = \
+ vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input##B)); \
+ local_accum3 = \
+ vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input##B)); \
+ row_accum0##B = vpadalq_s16(row_accum0##B, local_accum0); \
+ row_accum1##B = vpadalq_s16(row_accum1##B, local_accum1); \
+ row_accum2##B = vpadalq_s16(row_accum2##B, local_accum2); \
+ row_accum3##B = vpadalq_s16(row_accum3##B, local_accum3);
+
+ TFLITE_SHUFFLED_FC_ACCUM(0)
+ TFLITE_SHUFFLED_FC_ACCUM(1)
+ TFLITE_SHUFFLED_FC_ACCUM(2)
+ TFLITE_SHUFFLED_FC_ACCUM(3)
+
+#undef TFLITE_SHUFFLED_FC_ACCUM
+ }
+ // Horizontally reduce accumulators
+
+#define TFLITE_SHUFFLED_FC_STORE(B) \
+ { \
+ int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1, \
+ pairwise_reduced_acc_2, pairwise_reduced_acc_3; \
+ pairwise_reduced_acc_0 = \
+ vpadd_s32(vget_low_s32(row_accum0##B), vget_high_s32(row_accum0##B)); \
+ pairwise_reduced_acc_1 = \
+ vpadd_s32(vget_low_s32(row_accum1##B), vget_high_s32(row_accum1##B)); \
+ pairwise_reduced_acc_2 = \
+ vpadd_s32(vget_low_s32(row_accum2##B), vget_high_s32(row_accum2##B)); \
+ pairwise_reduced_acc_3 = \
+ vpadd_s32(vget_low_s32(row_accum3##B), vget_high_s32(row_accum3##B)); \
+ const int32x2_t reduced_lo = \
+ vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1); \
+ const int32x2_t reduced_hi = \
+ vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3); \
+ int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi); \
+ int32x4_t bias_vec = vld1q_s32(bias_data + c); \
+ reduced = vaddq_s32(reduced, bias_vec); \
+ reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift)); \
+ reduced = vqrdmulhq_n_s32(reduced, output_multiplier); \
+ using gemmlowp::RoundingDivideByPOT; \
+ reduced = RoundingDivideByPOT(reduced, right_shift); \
+ const int16x4_t res16 = vqmovn_s32(reduced); \
+ vst1_s16(output_data + c + B * output_stride, res16); \
+ }
+
+ TFLITE_SHUFFLED_FC_STORE(0);
+ TFLITE_SHUFFLED_FC_STORE(1);
+ TFLITE_SHUFFLED_FC_STORE(2);
+ TFLITE_SHUFFLED_FC_STORE(3);
+
+#undef TFLITE_SHUFFLED_FC_STORE
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+#else
+ if (batches == 1) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum[4] = {0};
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_data[d + j];
+ int8 weights_val = *shuffled_weights_ptr++;
+ accum[i] += weights_val * input_val;
+ }
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ // Add bias value
+ int acc = accum[i] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, -32768);
+ acc = std::min(acc, 32767);
+ output_ptr[c + i] = acc;
+ }
+ }
+ } else if (batches == 4) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr = shuffled_input_data;
+ // Accumulation loop.
+ // Internal accumulation.
+ // Initialize accumulator with the bias-value.
+ int32 accum[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ accum[i][b] = 0;
+ }
+ }
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_ptr[16 * b + j];
+ int8 weights_val = shuffled_weights_ptr[16 * i + j];
+ accum[i][b] += weights_val * input_val;
+ }
+ }
+ }
+ shuffled_input_ptr += 64;
+ shuffled_weights_ptr += 64;
+ }
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ // Add bias value
+ int acc = accum[i][b] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The
+ // quantized multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, -32768);
+ acc = std::min(acc, 32767);
+ output_ptr[b * output_stride + c + i] = acc;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+#endif
+}
+
+// Wraps ExperimentalShuffledFullyConnectedWorkerImpl into a Task class
+// to allow using gemmlowp's threadpool.
+struct ExperimentalShuffledFullyConnectedWorkerTask : gemmlowp::Task {
+ ExperimentalShuffledFullyConnectedWorkerTask(
+ const uint8* input_data, const int8* shuffled_weights_data, int batches,
+ int output_depth, int output_stride, int accum_depth,
+ const int32* bias_data, int32 output_multiplier, int output_shift,
+ int16* output_data)
+ : input_data_(input_data),
+ shuffled_weights_data_(shuffled_weights_data),
+ batches_(batches),
+ output_depth_(output_depth),
+ output_stride_(output_stride),
+ accum_depth_(accum_depth),
+ bias_data_(bias_data),
+ output_multiplier_(output_multiplier),
+ output_shift_(output_shift),
+ output_data_(output_data) {}
+
+ void Run() override {
+ ExperimentalShuffledFullyConnectedWorkerImpl(
+ input_data_, shuffled_weights_data_, batches_, output_depth_,
+ output_stride_, accum_depth_, bias_data_, output_multiplier_,
+ output_shift_, output_data_);
+ }
+
+ const uint8* input_data_;
+ const int8* shuffled_weights_data_;
+ int batches_;
+ int output_depth_;
+ int output_stride_;
+ int accum_depth_;
+ const int32* bias_data_;
+ int32 output_multiplier_;
+ int output_shift_;
+ int16* output_data_;
+};
+
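The kernels above and below rely on the identity spelled out in their comments: xor-ing 0x80 into a uint8 value and then reinterpreting the byte as int8 is the same as subtracting the zero-point 128. A short NumPy check of that identity, included only as an illustration:

    import numpy as np

    x = np.arange(256, dtype=np.uint8)       # every possible uint8 activation
    flipped = (x ^ 0x80).view(np.int8)       # flip the sign bit, reinterpret
    subtracted = x.astype(np.int32) - 128    # explicit zero-point subtraction
    assert np.array_equal(flipped.astype(np.int32), subtracted)
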
inline void ExperimentalShuffledFullyConnected(
const uint8* input_data, const Dims<4>& input_dims,
const uint8* shuffled_weights_data, const Dims<4>& weights_dims,
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
int output_shift, int32 output_activation_min, int32 output_activation_max,
int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
gemmlowp::ScopedProfilingLabel label(
"ExperimentalShuffledFullyConnected/8bit");
(void)gemm_context; // only used in optimized code.
@@ -1226,117 +1543,100 @@ inline void ExperimentalShuffledFullyConnected(
const int accum_depth = ArraySize(weights_dims, 0);
TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- // The experimental shuffling is an optimization for matrix*vector product.
- // We aren't interested in supporting non-matrix*vector-product cases, i.e.
- // batches>1.
- TFLITE_DCHECK_EQ(batches, 1);
+ TFLITE_DCHECK((accum_depth % 16) == 0);
+ TFLITE_DCHECK((output_depth % 4) == 0);
// Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
// so that just reinterpreting them as int8 values is equivalent to
// subtracting 128 from them, thus implementing for free the subtraction of
// the zero_point value 128.
- const int8* shuffled_weights_ptr =
+ const int8* int8_shuffled_weights_data =
reinterpret_cast<const int8*>(shuffled_weights_data);
-#if defined USE_NEON
- // We'll only need to xor signbit to the input activation values, as
- // that xor-ing is pre-built into the shuffled weights values.
- const uint8x16_t signbit = vdupq_n_u8(0x80);
- const int right_shift = output_shift > 0 ? output_shift : 0;
- const int left_shift = output_shift > 0 ? 0 : -output_shift;
- for (int c = 0; c < output_depth; c += 4) {
- // Accumulation loop.
- int32x4_t row_accum0 = vdupq_n_s32(0);
- int32x4_t row_accum1 = vdupq_n_s32(0);
- int32x4_t row_accum2 = vdupq_n_s32(0);
- int32x4_t row_accum3 = vdupq_n_s32(0);
- for (int d = 0; d < accum_depth; d += 16) {
- int8x16_t weights0 = vld1q_s8(shuffled_weights_ptr + 0);
- int8x16_t weights1 = vld1q_s8(shuffled_weights_ptr + 16);
- int8x16_t weights2 = vld1q_s8(shuffled_weights_ptr + 32);
- int8x16_t weights3 = vld1q_s8(shuffled_weights_ptr + 48);
- shuffled_weights_ptr += 64;
- int8x16_t input =
- vreinterpretq_s8_u8(veorq_u8(signbit, vld1q_u8(input_data + d)));
- int16x8_t local_accum0 =
- vmull_s8(vget_low_s8(weights0), vget_low_s8(input));
- int16x8_t local_accum1 =
- vmull_s8(vget_low_s8(weights1), vget_low_s8(input));
- int16x8_t local_accum2 =
- vmull_s8(vget_low_s8(weights2), vget_low_s8(input));
- int16x8_t local_accum3 =
- vmull_s8(vget_low_s8(weights3), vget_low_s8(input));
- local_accum0 =
- vmlal_s8(local_accum0, vget_high_s8(weights0), vget_high_s8(input));
- local_accum1 =
- vmlal_s8(local_accum1, vget_high_s8(weights1), vget_high_s8(input));
- local_accum2 =
- vmlal_s8(local_accum2, vget_high_s8(weights2), vget_high_s8(input));
- local_accum3 =
- vmlal_s8(local_accum3, vget_high_s8(weights3), vget_high_s8(input));
- row_accum0 = vpadalq_s16(row_accum0, local_accum0);
- row_accum1 = vpadalq_s16(row_accum1, local_accum1);
- row_accum2 = vpadalq_s16(row_accum2, local_accum2);
- row_accum3 = vpadalq_s16(row_accum3, local_accum3);
+
+ // Shuffling and xoring of input activations into the workspace buffer
+ if (batches == 1) {
+#ifdef USE_NEON
+ const uint8x16_t signbit = vdupq_n_u8(0x80);
+ for (int i = 0; i < accum_depth; i += 16) {
+ uint8x16_t val = vld1q_u8(input_data + i);
+ val = veorq_u8(val, signbit);
+ vst1q_u8(shuffled_input_workspace_data + i, val);
}
- // Horizontally reduce accumulators
- int32x2_t pairwise_reduced_acc_0, pairwise_reduced_acc_1,
- pairwise_reduced_acc_2, pairwise_reduced_acc_3;
- pairwise_reduced_acc_0 =
- vpadd_s32(vget_low_s32(row_accum0), vget_high_s32(row_accum0));
- pairwise_reduced_acc_1 =
- vpadd_s32(vget_low_s32(row_accum1), vget_high_s32(row_accum1));
- pairwise_reduced_acc_2 =
- vpadd_s32(vget_low_s32(row_accum2), vget_high_s32(row_accum2));
- pairwise_reduced_acc_3 =
- vpadd_s32(vget_low_s32(row_accum3), vget_high_s32(row_accum3));
- const int32x2_t reduced_lo =
- vpadd_s32(pairwise_reduced_acc_0, pairwise_reduced_acc_1);
- const int32x2_t reduced_hi =
- vpadd_s32(pairwise_reduced_acc_2, pairwise_reduced_acc_3);
- int32x4_t reduced = vcombine_s32(reduced_lo, reduced_hi);
- // Add bias values.
- int32x4_t bias_vec = vld1q_s32(bias_data + c);
- reduced = vaddq_s32(reduced, bias_vec);
- reduced = vshlq_s32(reduced, vdupq_n_s32(left_shift));
- // Multiply by the fixed-point multiplier.
- reduced = vqrdmulhq_n_s32(reduced, output_multiplier);
- // Rounding-shift-right.
- using gemmlowp::RoundingDivideByPOT;
- reduced = RoundingDivideByPOT(reduced, right_shift);
- // Narrow values down to 16 bit signed.
- const int16x4_t res16 = vqmovn_s32(reduced);
- vst1_s16(output_data + c, res16);
- }
#else
- for (int c = 0; c < output_depth; c += 4) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4] = {0};
- // Accumulation loop.
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
+ for (int i = 0; i < accum_depth; i++) {
+ shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
+ }
+#endif
+ } else if (batches == 4) {
+ uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
+ int c = 0;
+#ifdef USE_NEON
+ const uint8x16_t signbit = vdupq_n_u8(0x80);
+ for (c = 0; c < accum_depth; c += 16) {
+ const uint8* src_data_ptr = input_data + c;
+ uint8x16_t val0 = vld1q_u8(src_data_ptr + 0 * accum_depth);
+ uint8x16_t val1 = vld1q_u8(src_data_ptr + 1 * accum_depth);
+ uint8x16_t val2 = vld1q_u8(src_data_ptr + 2 * accum_depth);
+ uint8x16_t val3 = vld1q_u8(src_data_ptr + 3 * accum_depth);
+ val0 = veorq_u8(val0, signbit);
+ val1 = veorq_u8(val1, signbit);
+ val2 = veorq_u8(val2, signbit);
+ val3 = veorq_u8(val3, signbit);
+ vst1q_u8(shuffled_input_workspace_ptr + 0, val0);
+ vst1q_u8(shuffled_input_workspace_ptr + 16, val1);
+ vst1q_u8(shuffled_input_workspace_ptr + 32, val2);
+ vst1q_u8(shuffled_input_workspace_ptr + 48, val3);
+ shuffled_input_workspace_ptr += 64;
+ }
+#else
+ for (c = 0; c < accum_depth; c += 16) {
+ for (int b = 0; b < 4; b++) {
+ const uint8* src_data_ptr = input_data + b * accum_depth + c;
for (int j = 0; j < 16; j++) {
- int8 input_val = input_data[d + j] - 128;
- int8 weights_val = *shuffled_weights_ptr++;
- accum[i] += weights_val * input_val;
+ uint8 src_val = *src_data_ptr++;
+ // Flip the sign bit, so that the kernel will only need to
+ // reinterpret these uint8 values as int8, getting for free the
+ // subtraction of the zero_point value 128.
+ uint8 dst_val = src_val ^ 0x80;
+ *shuffled_input_workspace_ptr++ = dst_val;
}
}
}
- for (int i = 0; i < 4; i++) {
- // Add bias value
- int acc = accum[i] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc =
- MultiplyByQuantizedMultiplier(acc, output_multiplier, -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[c + i] = acc;
- }
- }
#endif
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+
+ static constexpr int kKernelRows = 4;
+ const int thread_count = gemmlowp::HowManyThreads<kKernelRows>(
+ gemm_context->max_num_threads(), output_depth, batches, accum_depth);
+ if (thread_count == 1) {
+ // Single-thread case: do the computation on the current thread, don't
+ // use a threadpool
+ ExperimentalShuffledFullyConnectedWorkerImpl(
+ shuffled_input_workspace_data, int8_shuffled_weights_data, batches,
+ output_depth, output_depth, accum_depth, bias_data, output_multiplier,
+ output_shift, output_data);
+ return;
+ }
+
+ // Multi-threaded case: use the gemmlowp context's threadpool.
+ TFLITE_DCHECK_GT(thread_count, 1);
+ std::vector<gemmlowp::Task*> tasks(thread_count);
+ const int kRowsPerWorker =
+ gemmlowp::RoundUp<kKernelRows>(output_depth / thread_count);
+ int row_start = 0;
+ for (int i = 0; i < thread_count; i++) {
+ int row_end = std::min(output_depth, row_start + kRowsPerWorker);
+ tasks[i] = new ExperimentalShuffledFullyConnectedWorkerTask(
+ shuffled_input_workspace_data,
+ int8_shuffled_weights_data + row_start * accum_depth, batches,
+ row_end - row_start, output_depth, accum_depth, bias_data + row_start,
+ output_multiplier, output_shift, output_data + row_start);
+ row_start = row_end;
+ }
+ TFLITE_DCHECK_EQ(row_start, output_depth);
+ gemm_context->workers_pool()->Execute(tasks);
}
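To make the threading change above easier to follow, here is a sketch in Python of how the output rows are divided between worker tasks; round_up_to_multiple stands in for gemmlowp's RoundUp<kKernelRows>, and the thread count is taken as a given here rather than computed by gemmlowp::HowManyThreads:

    def round_up_to_multiple(value, multiple):
      return ((value + multiple - 1) // multiple) * multiple

    def partition_output_rows(output_depth, thread_count, kernel_rows=4):
      rows_per_worker = round_up_to_multiple(output_depth // thread_count,
                                             kernel_rows)
      ranges, row_start = [], 0
      for _ in range(thread_count):
        row_end = min(output_depth, row_start + rows_per_worker)
        ranges.append((row_start, row_end))
        row_start = row_end
      assert row_start == output_depth
      return ranges

    # e.g. 100 output rows over 3 tasks -> [(0, 36), (36, 72), (72, 100)]
    print(partition_output_rows(100, 3))
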
template <typename T>
@@ -5474,6 +5774,9 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& right_paddings, T* output_data,
const Dims<4>& output_dims, const int32_t pad_value) {
gemmlowp::ScopedProfilingLabel label("Pad");
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+
const int output_batch = ArraySize(output_dims, 3);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
diff --git a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
index 4e324a5e10..ff15f3e3b1 100644
--- a/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
+++ b/tensorflow/contrib/lite/kernels/internal/optimized/tensor_utils_impl.h
@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
-#define TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+#ifndef TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+#define TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
// TODO(ghodrat): Remove this header file and the dependency to internal data
// structure.
@@ -135,4 +135,4 @@ void NeonReductionSumVector(const float* input_vector, float* output_vector,
} // namespace tensor_utils
} // namespace tflite
-#endif // TF_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
+#endif // TENSORFLOW_CONTRIB_LITE_KERNELS_INTERNAL_OPTIMIZED_TENSOR_UTILS_IMPL_H_
diff --git a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
index 9ad125b8eb..4c8cbe4275 100644
--- a/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
+++ b/tensorflow/contrib/lite/kernels/internal/reference/reference_ops.h
@@ -608,8 +608,9 @@ inline void ExperimentalShuffledFullyConnected(
const int32* bias_data, const Dims<4>& bias_dims, int32 output_multiplier,
int output_shift, int32 output_activation_min, int32 output_activation_max,
int16* output_data, const Dims<4>& output_dims,
- gemmlowp::GemmContext* gemm_context) {
+ uint8* shuffled_input_workspace_data, gemmlowp::GemmContext* gemm_context) {
(void)gemm_context; // only used in optimized code.
+
TFLITE_DCHECK_LE(output_activation_min, output_activation_max);
// TODO(benoitjacob): This really should be:
// const int batches = ArraySize(output_dims, 1);
@@ -622,44 +623,130 @@ inline void ExperimentalShuffledFullyConnected(
const int accum_depth = ArraySize(weights_dims, 0);
TFLITE_DCHECK(IsPackedWithoutStrides(input_dims));
TFLITE_DCHECK(IsPackedWithoutStrides(weights_dims));
- // The experimental shuffling is an optimization for matrix*vector product.
- // We aren't interested in supporting non-matrix*vector-product cases, i.e.
- // batches>1.
- TFLITE_DCHECK_EQ(batches, 1);
- // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
- // so that just reinterpreting them as int8 values is equivalent to
- // subtracting 128 from them, thus implementing for free the subtraction of
- // the zero_point value 128.
- const int8* shuffled_weights_ptr =
- reinterpret_cast<const int8*>(shuffled_weights_data);
- for (int c = 0; c < output_depth; c += 4) {
- // Internal accumulation.
- // Initialize accumulator with the bias-value.
- int32 accum[4] = {0};
- // Accumulation loop.
- for (int d = 0; d < accum_depth; d += 16) {
- for (int i = 0; i < 4; i++) {
+ TFLITE_DCHECK((accum_depth % 16) == 0);
+ TFLITE_DCHECK((output_depth % 4) == 0);
+
+ // Shuffling and xoring of input activations into the workspace buffer
+ uint8* shuffled_input_workspace_ptr = shuffled_input_workspace_data;
+ if (batches == 1) {
+ for (int i = 0; i < accum_depth; i++) {
+ shuffled_input_workspace_data[i] = input_data[i] ^ 0x80;
+ }
+ } else if (batches == 4) {
+ for (int c = 0; c < accum_depth; c += 16) {
+ for (int b = 0; b < 4; b++) {
+ const uint8* src_data_ptr = input_data + b * accum_depth + c;
for (int j = 0; j < 16; j++) {
- int8 input_val = input_data[d + j] - 128;
- int8 weights_val = *shuffled_weights_ptr++;
- accum[i] += weights_val * input_val;
+ uint8 src_val = *src_data_ptr++;
+ // Flip the sign bit, so that the kernel will only need to
+ // reinterpret these uint8 values as int8, getting for free the
+ // subtraction of the zero_point value 128.
+ uint8 dst_val = src_val ^ 0x80;
+ *shuffled_input_workspace_ptr++ = dst_val;
}
}
}
- for (int i = 0; i < 4; i++) {
- // Add bias value
- int acc = accum[i] + bias_data[c + i];
- // Down-scale the final int32 accumulator to the scale used by our
- // (16-bit, typically 3 integer bits) fixed-point format. The quantized
- // multiplier and shift here have been pre-computed offline
- // (e.g. by toco).
- acc =
- MultiplyByQuantizedMultiplier(acc, output_multiplier, -output_shift);
- // Saturate, cast to int16, and store to output array.
- acc = std::max(acc, output_activation_min);
- acc = std::min(acc, output_activation_max);
- output_data[c + i] = acc;
+ } else {
+ TFLITE_DCHECK(false);
+ return;
+ }
+
+ // Actual computation
+ if (batches == 1) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ // Internal accumulation.
+      // Initialize accumulators to zero; the bias is added after the
+      // accumulation loop.
+ int32 accum[4] = {0};
+ // Accumulation loop.
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_data[d + j];
+ int8 weights_val = *shuffled_weights_ptr++;
+ accum[i] += weights_val * input_val;
+ }
+ }
+ }
+ for (int i = 0; i < 4; i++) {
+ // Add bias value
+ int acc = accum[i] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The quantized
+ // multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[c + i] = acc;
+ }
}
+ } else if (batches == 4) {
+ int16* output_ptr = output_data;
+ // Shuffled weights have had their sign bit (0x80) pre-flipped (xor'd)
+ // so that just reinterpreting them as int8 values is equivalent to
+ // subtracting 128 from them, thus implementing for free the subtraction of
+ // the zero_point value 128.
+ const int8* shuffled_weights_ptr =
+ reinterpret_cast<const int8*>(shuffled_weights_data);
+ // Likewise, we preshuffled and pre-xored the input data above.
+ const int8* shuffled_input_data =
+ reinterpret_cast<const int8*>(shuffled_input_workspace_data);
+ for (int c = 0; c < output_depth; c += 4) {
+ const int8* shuffled_input_ptr = shuffled_input_data;
+      // Internal accumulation.
+      // Initialize accumulators to zero; the bias is added after the
+      // accumulation loop below.
+ int32 accum[4][4];
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ accum[i][b] = 0;
+ }
+ }
+ for (int d = 0; d < accum_depth; d += 16) {
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ for (int j = 0; j < 16; j++) {
+ int8 input_val = shuffled_input_ptr[16 * b + j];
+ int8 weights_val = shuffled_weights_ptr[16 * i + j];
+ accum[i][b] += weights_val * input_val;
+ }
+ }
+ }
+ shuffled_input_ptr += 64;
+ shuffled_weights_ptr += 64;
+ }
+ for (int i = 0; i < 4; i++) {
+ for (int b = 0; b < 4; b++) {
+ // Add bias value
+ int acc = accum[i][b] + bias_data[c + i];
+ // Down-scale the final int32 accumulator to the scale used by our
+ // (16-bit, typically 3 integer bits) fixed-point format. The
+ // quantized multiplier and shift here have been pre-computed offline
+ // (e.g. by toco).
+ acc = MultiplyByQuantizedMultiplier(acc, output_multiplier,
+ -output_shift);
+ // Saturate, cast to int16, and store to output array.
+ acc = std::max(acc, output_activation_min);
+ acc = std::min(acc, output_activation_max);
+ output_ptr[b * output_depth + c + i] = acc;
+ }
+ }
+ }
+ } else {
+ TFLITE_DCHECK(false);
+ return;
}
}
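
The comments above lean on a small identity: XOR'ing a uint8 value with 0x80 and then reinterpreting the result as int8 is the same as subtracting the zero_point 128. A minimal Python sketch that checks this for every uint8 value (illustrative only, not part of this patch):

    def reinterpret_uint8_as_int8(v):
      # Two's-complement reinterpretation of a uint8 value (0..255) as int8.
      return v - 256 if v >= 128 else v

    for x in range(256):
      assert reinterpret_uint8_as_int8(x ^ 0x80) == x - 128
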
@@ -2993,6 +3080,9 @@ inline void Pad(const T* input_data, const Dims<4>& input_dims,
const std::vector<int>& left_paddings,
const std::vector<int>& right_paddings, T* output_data,
const Dims<4>& output_dims, const int32_t pad_value) {
+ TFLITE_DCHECK_EQ(left_paddings.size(), 4);
+ TFLITE_DCHECK_EQ(right_paddings.size(), 4);
+
const int output_batch = ArraySize(output_dims, 3);
const int output_height = ArraySize(output_dims, 2);
const int output_width = ArraySize(output_dims, 1);
diff --git a/tensorflow/contrib/lite/kernels/test_util.h b/tensorflow/contrib/lite/kernels/test_util.h
index a9064d54e7..a5f345e98a 100644
--- a/tensorflow/contrib/lite/kernels/test_util.h
+++ b/tensorflow/contrib/lite/kernels/test_util.h
@@ -88,7 +88,9 @@ struct TensorData {
class SingleOpResolver : public OpResolver {
public:
SingleOpResolver(const BuiltinOperator op, TfLiteRegistration* registration)
- : op_(op), registration_(registration) {}
+ : op_(op), registration_(registration) {
+ registration_->builtin_code = op;
+ }
TfLiteRegistration* FindOp(BuiltinOperator op) const override {
if (op == op_) {
return registration_;
diff --git a/tensorflow/contrib/lite/model.cc b/tensorflow/contrib/lite/model.cc
index 2dd6d67e07..f45f39d1e6 100644
--- a/tensorflow/contrib/lite/model.cc
+++ b/tensorflow/contrib/lite/model.cc
@@ -194,7 +194,6 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
builtin_code);
status = kTfLiteError;
} else if (builtin_code != BuiltinOperator_CUSTOM) {
- flatbuffer_op_index_to_registration_types_.push_back(builtin_code);
registration = op_resolver_.FindOp(builtin_code);
if (registration == nullptr) {
error_reporter_->Report("Didn't find op for builtin opcode '%s'\n",
@@ -208,8 +207,6 @@ TfLiteStatus InterpreterBuilder::BuildLocalIndexToRegistrationMapping() {
} else {
const char* name = opcode->custom_code()->c_str();
registration = op_resolver_.FindOp(name);
- flatbuffer_op_index_to_registration_types_.push_back(
- BuiltinOperator_CUSTOM);
if (registration == nullptr) {
error_reporter_->Report("Didn't find custom op for name '%s'\n", name);
status = kTfLiteError;
@@ -702,8 +699,7 @@ TfLiteStatus InterpreterBuilder::ParseNodes(
continue;
}
- auto op_type =
- flatbuffer_op_index_to_registration_types_[op->opcode_index()];
+ BuiltinOperator op_type = static_cast<BuiltinOperator>(reg->builtin_code);
if (op_type != BuiltinOperator_CUSTOM && op->custom_options()) {
error_reporter_->Report(
"Found builtin operator %s with custom options.\n",
diff --git a/tensorflow/contrib/lite/model.h b/tensorflow/contrib/lite/model.h
index 5a55b031a8..a7d7f3ea10 100644
--- a/tensorflow/contrib/lite/model.h
+++ b/tensorflow/contrib/lite/model.h
@@ -188,7 +188,6 @@ class InterpreterBuilder {
ErrorReporter* error_reporter_;
std::vector<TfLiteRegistration*> flatbuffer_op_index_to_registration_;
- std::vector<BuiltinOperator> flatbuffer_op_index_to_registration_types_;
const Allocation* allocation_ = nullptr;
};
diff --git a/tensorflow/contrib/lite/optional_debug_tools.cc b/tensorflow/contrib/lite/optional_debug_tools.cc
index e1366639c7..dfdd80ea8a 100644
--- a/tensorflow/contrib/lite/optional_debug_tools.cc
+++ b/tensorflow/contrib/lite/optional_debug_tools.cc
@@ -72,7 +72,7 @@ const char* AllocTypeName(TfLiteAllocationType type) {
// Prints a dump of what tensors and what nodes are in the interpreter.
void PrintInterpreterState(Interpreter* interpreter) {
- printf("Interpreter has %d tensors and %d nodes\n",
+ printf("Interpreter has %zu tensors and %zu nodes\n",
interpreter->tensors_size(), interpreter->nodes_size());
printf("Inputs:");
PrintIntVector(interpreter->inputs());
diff --git a/tensorflow/contrib/lite/profiling/profile_buffer.h b/tensorflow/contrib/lite/profiling/profile_buffer.h
index 3bfe02571b..b2f565376c 100644
--- a/tensorflow/contrib/lite/profiling/profile_buffer.h
+++ b/tensorflow/contrib/lite/profiling/profile_buffer.h
@@ -37,9 +37,9 @@ struct ProfileEvent {
// Label of the event. This usually describes the event.
const char* tag;
// Timestamp in microseconds when the event began.
- int64_t begin_timestamp_ms;
+ int64_t begin_timestamp_us;
// Timestamp in microseconds when the event ended.
- int64_t end_timestamp_ms;
+ int64_t end_timestamp_us;
// The field containing the type of event. This must be one of the event types
// in EventType.
EventType event_type;
@@ -79,8 +79,8 @@ class ProfileBuffer {
event_buffer_[index].tag = tag;
event_buffer_[index].event_type = event_type;
event_buffer_[index].event_metadata = event_metadata;
- event_buffer_[index].begin_timestamp_ms = timestamp;
- event_buffer_[index].end_timestamp_ms = 0;
+ event_buffer_[index].begin_timestamp_us = timestamp;
+ event_buffer_[index].end_timestamp_us = 0;
current_index_++;
return index;
}
@@ -103,7 +103,7 @@ class ProfileBuffer {
}
int event_index = event_handle % max_size;
- event_buffer_[event_index].end_timestamp_ms = NowMicros();
+ event_buffer_[event_index].end_timestamp_us = NowMicros();
}
// Returns the size of the buffer.
diff --git a/tensorflow/contrib/lite/profiling/profile_buffer_test.cc b/tensorflow/contrib/lite/profiling/profile_buffer_test.cc
index 0c5f0cd314..b8784cca45 100644
--- a/tensorflow/contrib/lite/profiling/profile_buffer_test.cc
+++ b/tensorflow/contrib/lite/profiling/profile_buffer_test.cc
@@ -49,13 +49,13 @@ TEST(ProfileBufferTest, AddEvent) {
auto event = GetProfileEvents(buffer)[0];
EXPECT_EQ(event->tag, "hello");
- EXPECT_GT(event->begin_timestamp_ms, 0);
+ EXPECT_GT(event->begin_timestamp_us, 0);
EXPECT_EQ(event->event_type, ProfileEvent::EventType::DEFAULT);
EXPECT_EQ(event->event_metadata, 42);
buffer.EndEvent(event_handle);
EXPECT_EQ(1, buffer.Size());
- EXPECT_GE(event->end_timestamp_ms, event->begin_timestamp_ms);
+ EXPECT_GE(event->end_timestamp_us, event->begin_timestamp_us);
}
TEST(ProfileBufferTest, OverFlow) {
diff --git a/tensorflow/contrib/lite/profiling/profiler_test.cc b/tensorflow/contrib/lite/profiling/profiler_test.cc
index 994523a8fb..7914f36a31 100644
--- a/tensorflow/contrib/lite/profiling/profiler_test.cc
+++ b/tensorflow/contrib/lite/profiling/profiler_test.cc
@@ -30,7 +30,7 @@ namespace {
void AssertDurationOfEventAroundMs(const ProfileEvent* event,
double expected_ms, double eps_ms) {
double duration_ms =
- (event->end_timestamp_ms - event->begin_timestamp_ms) / 1e3;
+ (event->end_timestamp_us - event->begin_timestamp_us) / 1e3;
EXPECT_NEAR(expected_ms, duration_ms, eps_ms);
}
diff --git a/tensorflow/contrib/lite/python/BUILD b/tensorflow/contrib/lite/python/BUILD
index 926896d609..e6dcc7aa09 100644
--- a/tensorflow/contrib/lite/python/BUILD
+++ b/tensorflow/contrib/lite/python/BUILD
@@ -39,16 +39,35 @@ py_test(
py_library(
name = "lite",
srcs = ["lite.py"],
- # data = [
- # "//tensorflow/contrib/lite/toco/python:toco_from_protos",
- # ],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
+ ":convert",
+ ":convert_saved_model",
":op_hint",
+ ],
+)
+
+py_library(
+ name = "lite_constants",
+ srcs = ["lite_constants.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/contrib/lite/toco:toco_flags_proto_py",
+ ],
+)
+
+py_library(
+ name = "convert",
+ srcs = ["convert.py"],
+ srcs_version = "PY2AND3",
+ visibility = ["//visibility:public"],
+ deps = [
+ ":lite_constants",
"//tensorflow/contrib/lite/toco:model_flags_proto_py",
"//tensorflow/contrib/lite/toco:toco_flags_proto_py",
"//tensorflow/contrib/lite/toco/python:tensorflow_wrap_toco",
+ "//tensorflow/contrib/lite/toco/python:toco_from_protos",
"//tensorflow/python:platform",
],
)
@@ -66,15 +85,15 @@ py_library(
)
py_test(
- name = "lite_test",
- srcs = ["lite_test.py"],
+ name = "convert_test",
+ srcs = ["convert_test.py"],
srcs_version = "PY2AND3",
tags = [
"no-internal-py3",
"no_oss",
],
deps = [
- ":lite",
+ ":convert",
":op_hint",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
@@ -84,13 +103,14 @@ py_test(
],
)
-py_binary(
+py_library(
name = "convert_saved_model",
srcs = ["convert_saved_model.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:public"],
deps = [
- ":lite",
+ ":convert",
+ ":lite_constants",
"//tensorflow/contrib/saved_model:saved_model_py",
"//tensorflow/python:graph_util",
"//tensorflow/python/tools:freeze_graph_lib",
@@ -130,6 +150,15 @@ py_test(
],
)
+py_binary(
+ name = "convert_saved_model_to_frozen_graph",
+ srcs = ["convert_saved_model_to_frozen_graph.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":convert_saved_model",
+ ],
+)
+
# Transitive dependencies of this target will be included in the pip package.
py_library(
name = "tf_lite_py_pip",
diff --git a/tensorflow/contrib/lite/python/convert.py b/tensorflow/contrib/lite/python/convert.py
new file mode 100644
index 0000000000..c4200c879b
--- /dev/null
+++ b/tensorflow/contrib/lite/python/convert.py
@@ -0,0 +1,187 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Converts a frozen graph into a TFLite FlatBuffer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os as _os
+import subprocess as _subprocess
+import tempfile as _tempfile
+
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.python.framework import dtypes as _dtypes
+from tensorflow.python.platform import resource_loader as _resource_loader
+from tensorflow.python.util.lazy_loader import LazyLoader
+
+
+# Lazy load since some of the performance benchmark skylark rules
+# break dependencies.
+_toco_python = LazyLoader(
+ "tensorflow_wrap_toco", globals(),
+ "tensorflow.contrib.lite.toco.python."
+ "tensorflow_wrap_toco")
+del LazyLoader
+
+# Find the toco_from_protos binary using the resource loader if running from
+# bazel; otherwise assume a pip installation, where console_scripts already
+# provides the toco_from_protos tool.
+if lite_constants.EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
+ _toco_from_proto_bin = ""
+else:
+ _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
+ "../toco/python/toco_from_protos")
+
+if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
+ _toco_from_proto_bin = "toco_from_protos"
+
+
+def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
+ """Convert `input_data_str` according to model and toco parameters.
+
+  Unless you know what you are doing, consider using
+  the more friendly @{tf.contrib.lite.toco_convert}.
+
+ Args:
+ model_flags_str: Serialized proto describing model properties, see
+ `toco/model_flags.proto`.
+ toco_flags_str: Serialized proto describing conversion properties, see
+ `toco/toco_flags.proto`.
+ input_data_str: Input data in serialized form (e.g. a graphdef is common)
+ Returns:
+ Converted model in serialized form (e.g. a TFLITE model is common).
+ Raises:
+ RuntimeError: When conversion fails, an exception is raised with the error
+ message embedded.
+ """
+ # TODO(aselle): When toco does not use fatal errors for failure, we can
+ # switch this on.
+ if not _toco_from_proto_bin:
+ return _toco_python.TocoConvert(
+ model_flags_str, toco_flags_str, input_data_str)
+
+ with _tempfile.NamedTemporaryFile() as fp_toco, \
+ _tempfile.NamedTemporaryFile() as fp_model, \
+ _tempfile.NamedTemporaryFile() as fp_input, \
+ _tempfile.NamedTemporaryFile() as fp_output:
+ fp_model.write(model_flags_str)
+ fp_toco.write(toco_flags_str)
+ fp_input.write(input_data_str)
+ fp_model.flush()
+ fp_toco.flush()
+ fp_input.flush()
+
+ cmd = [
+ _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
+ fp_output.name
+ ]
+ cmdline = " ".join(cmd)
+ proc = _subprocess.Popen(
+ cmdline,
+ shell=True,
+ stdout=_subprocess.PIPE,
+ stderr=_subprocess.STDOUT,
+ close_fds=True)
+ stdout, stderr = proc.communicate()
+ exitcode = proc.returncode
+ if exitcode == 0:
+ stuff = fp_output.read()
+ return stuff
+ else:
+ raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
+ (stdout, stderr))
+
+
+def tensor_name(x):
+ return x.name.split(":")[0]
+
+
+def toco_convert(input_data,
+ input_tensors,
+ output_tensors,
+ inference_type=lite_constants.FLOAT,
+ input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ output_format=lite_constants.TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Convert a model using TOCO from `input_format` to `output_format`.
+
+ Typically this is to convert from TensorFlow GraphDef to TFLite, in which
+ case the default `input_format` and `output_format` are sufficient.
+
+ Args:
+ input_data: Input data (i.e. often `sess.graph_def`).
+ input_tensors: List of input tensors. Type and shape are computed using
+ `foo.get_shape()` and `foo.dtype`.
+ output_tensors: List of output tensors (only .name is used from this).
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT)
+ quantized_input_stats: For each member of input_tensors the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+    drop_control_dependency: Drops control dependencies silently. This is
+      because TensorFlow Lite does not support control dependencies.
+
+ Returns:
+    The converted data. For example, if TFLite was the destination, this
+    will be a TFLite FlatBuffer in a bytes array.
+
+ Raises:
+ ValueError: If the input tensor type is unknown
+ RuntimeError: If TOCO fails to convert (in which case the runtime error's
+ error text will contain the TOCO error log)
+ """
+ toco = _toco_flags_pb2.TocoFlags()
+ toco.input_format = input_format
+ toco.output_format = output_format
+ toco.drop_control_dependency = drop_control_dependency
+ model = _model_flags_pb2.ModelFlags()
+ toco.inference_type = inference_type
+ for idx, input_tensor in enumerate(input_tensors):
+ if input_tensor.dtype == _dtypes.float32:
+ tflite_input_type = lite_constants.FLOAT
+ elif input_tensor.dtype == _dtypes.int32:
+ tflite_input_type = lite_constants.INT32
+ elif input_tensor.dtype == _dtypes.int64:
+ tflite_input_type = lite_constants.INT64
+ # TODO(aselle): Insert strings when they are available
+ else:
+ raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
+ input_tensor.dtype))
+
+ input_array = model.input_arrays.add()
+
+ if inference_type == lite_constants.QUANTIZED_UINT8:
+ if tflite_input_type == lite_constants.FLOAT:
+ tflite_input_type = lite_constants.QUANTIZED_UINT8
+ input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
+
+ input_array.name = tensor_name(input_tensor)
+ input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
+
+ for output_tensor in output_tensors:
+ model.output_arrays.append(tensor_name(output_tensor))
+
+ # TODO(aselle): Consider handling the case of allowing quantized
+ # inputs to be converted to float (via the toco.inference_input_type field).
+ data = toco_convert_protos(model.SerializeToString(),
+ toco.SerializeToString(),
+ input_data.SerializeToString())
+ return data
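
A minimal usage sketch of convert.toco_convert, mirroring the testBasic case in convert_test.py later in this change; it assumes the TOCO tooling (tensorflow_wrap_toco or the toco_from_protos binary) is available, and the tensor shapes are illustrative:

    from tensorflow.contrib.lite.python import convert
    from tensorflow.python.client import session
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import array_ops

    with session.Session() as sess:
      # A trivial graph: a placeholder added to itself.
      in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3], dtype=dtypes.float32)
      out_tensor = in_tensor + in_tensor
      # Returns a TFLite FlatBuffer as a bytes array.
      tflite_model = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
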
diff --git a/tensorflow/contrib/lite/python/convert_saved_model.py b/tensorflow/contrib/lite/python/convert_saved_model.py
index a2b5ef488e..a7eddf3408 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model.py
@@ -12,52 +12,43 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-r"""TensorFlow Lite flatbuffer generation from saved_models.
+"""Functions to convert SavedModel to frozen GraphDefs."""
-Example:
-
-bazel run third_party/tensorflow/contrib/lite/python:convert_saved_model -- \
- --saved_model_dir=/tmp/test_saved_model/1519865537 \
- --output_tflite=/tmp/test.lite
-
-"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import lite
+from tensorflow.contrib.lite.python import convert
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.toco import model_flags_pb2
from tensorflow.contrib.saved_model.python.saved_model import reader
from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
from tensorflow.core.framework import types_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util as tf_graph_util
from tensorflow.python.framework import ops
-from tensorflow.python.platform import app
-from tensorflow.python.platform import flags
from tensorflow.python.platform import gfile
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
-flags.DEFINE_string("saved_model_dir", "", "Saved model directory to convert.")
-flags.DEFINE_string("output_tflite", None, "File path to write flatbuffer.")
-flags.DEFINE_string("output_arrays", None,
- "List of output tensor names, the default value is None, "
- "which means the conversion will keep all outputs.")
-flags.DEFINE_integer("batch_size", 1,
- "If input tensor shape has None at first dimension, "
- "e.g. (None,224,224,3), replace None with batch_size.")
-flags.DEFINE_string("tag_set", tag_constants.SERVING,
- "Group of tag(s) of the MetaGraphDef in the saved_model, "
- "in string format, separated by ','. For tag-set contains "
- "multiple tags, all tags must be passed in.")
-flags.DEFINE_string("signature_key",
- signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- "This is signature key to extract inputs, outputs.")
-
-
-def log_tensor_details(tensor_info):
+
+def _write_and_flush_file(file_path, data_str):
+ """Writes data to file path.
+
+ Args:
+ file_path: Full path of the file to store data in.
+ data_str: Data represented as a string.
+
+ Returns: None.
+ """
+ with gfile.Open(file_path, "wb") as data_file:
+ data_file.write(data_str)
+ data_file.flush()
+
+
+def _log_tensor_details(tensor_info):
"""Log tensor details: name, shape, and type."""
for key in tensor_info:
val = tensor_info[key]
@@ -73,7 +64,7 @@ def log_tensor_details(tensor_info):
dtype)
-def get_meta_graph_def(saved_model_dir, tag_set):
+def _get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
@@ -103,7 +94,7 @@ def get_meta_graph_def(saved_model_dir, tag_set):
"values are '{}'. ".format(tag_set, tag_sets))
-def get_signature_def(meta_graph, signature_key):
+def _get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
@@ -130,11 +121,11 @@ def get_signature_def(meta_graph, signature_key):
return signature_def
-def get_inputs_outputs(signature_def):
- """Get inputs and outputs from signature def.
+def _get_inputs_outputs(signature_def):
+ """Get inputs and outputs from SignatureDef.
Args:
- signature_def: signatuer def in the meta_graph_def for conversion.
+ signature_def: SignatureDef in the meta_graph_def for conversion.
Returns:
The inputs and outputs in the graph for conversion.
@@ -142,9 +133,9 @@ def get_inputs_outputs(signature_def):
inputs_tensor_info = signature_def.inputs
outputs_tensor_info = signature_def.outputs
logging.info("input tensors info: ")
- log_tensor_details(inputs_tensor_info)
+ _log_tensor_details(inputs_tensor_info)
logging.info("output tensors info: ")
- log_tensor_details(outputs_tensor_info)
+ _log_tensor_details(outputs_tensor_info)
def gather_names(tensor_info):
return [tensor_info[key].name for key in tensor_info]
@@ -154,109 +145,277 @@ def get_inputs_outputs(signature_def):
return inputs, outputs
-def convert(saved_model_dir,
- output_tflite=None,
- output_arrays=None,
- tag_set=None,
- signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
- batch_size=1):
- """Convert a saved_model to tflite flatbuffer.
+def _get_tensors(graph, signature_def_tensor_names=None,
+ user_tensor_names=None):
+ """Gets the tensors associated with the tensor names.
+
+  Either signature_def_tensor_names or user_tensor_names should be provided.
+  If user_tensor_names are provided, the tensors associated with those names
+  are returned. Otherwise, the tensors associated with the names in the
+  SignatureDef are returned.
Args:
- saved_model_dir: Saved model directory to convert.
- output_tflite: File path to write result flatbuffer.
- output_arrays: List of output tensor names, the default value is None, which
- means conversion keeps all output tensors. This is also used to filter
- tensors that are from Op currently not supported in tflite, e.g., Argmax).
- tag_set: This is the set of tags to get meta_graph_def in saved_model.
- signature_key: This is the signature key to extract inputs, outputs.
- batch_size: If input tensor shape has None at first dimension,
- e.g. (None,224,224,3), replace None with batch_size.
+ graph: GraphDef representing graph.
+ signature_def_tensor_names: Tensor names stored in either the inputs or
+ outputs of a SignatureDef. (default None)
+ user_tensor_names: Tensor names provided by the user. (default None)
Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
+ List of tensors.
+
+ Raises:
+ ValueError:
+      signature_def_tensor_names and user_tensor_names are undefined or empty.
+ user_tensor_names are not valid.
+ """
+ tensors = []
+ if user_tensor_names:
+ # Get the list of all of the tensors with and without the tensor index.
+ all_tensor_names = [
+ tensor.name for op in graph.get_operations() for tensor in op.outputs
+ ]
+ all_tensor_names_only = [name.split(":")[0] for name in all_tensor_names]
+
+ # Sort the tensor names.
+ user_tensor_names = sorted(user_tensor_names)
+
+ # Get the tensors associated with the tensor names.
+ tensors = []
+ invalid_tensors = []
+ for name in user_tensor_names:
+ if name not in all_tensor_names_only:
+ invalid_tensors.append(name)
+ else:
+ idx = all_tensor_names_only.index(name)
+ tensors.append(graph.get_tensor_by_name(all_tensor_names[idx]))
+
+ # Throw ValueError if any user input names are not valid tensors.
+ if invalid_tensors:
+ raise ValueError("Invalid tensors '{}' were found.".format(
+ ",".join(invalid_tensors)))
+ elif signature_def_tensor_names:
+ tensors = [
+ graph.get_tensor_by_name(name)
+ for name in sorted(signature_def_tensor_names)
+ ]
+ else:
+ # Throw ValueError if signature_def_tensors and user_tensor_names are both
+ # either undefined or empty.
+ raise ValueError(
+ "Specify either signature_def_tensor_names or user_tensor_names")
+
+ return tensors
+
+
+def _freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
+ output_arrays, tag_set, signature_key, batch_size):
+ """Converts a SavedModel to a frozen graph.
+
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+      Automatically determined when an input shape is None (e.g., {"foo" : None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+      input shape if it is undefined (None). (default 1)
+
+ Returns:
+ frozen_graph_def: Frozen GraphDef.
+ in_tensors: List of input tensors for the graph.
+ out_tensors: List of output tensors for the graph.
Raises:
- ValueError: If tag_set does not indicate any meta_graph_def in saved_model,
- or signature_key is not in relevant meta_graph_def,
- or input shape has None beyond 1st dimension, e.g., (1,None, None, 3),
- or given output_arrays are not valid causing empty outputs.
+ ValueError:
+ SavedModel doesn't contain a MetaGraphDef identified by tag_set.
+ signature_key is not in the MetaGraphDef.
+ input_shapes does not match the length of input_arrays.
+ input_shapes has a None value after the 1st dimension.
+ input_arrays or output_arrays are not valid.
+ Unable to load Session.
"""
+ # Set default values for inputs if they are set to None.
+ if signature_key is None:
+ signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
if tag_set is None:
tag_set = set([tag_constants.SERVING])
+ if batch_size is None:
+ batch_size = 1
- meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
- signature_def = get_signature_def(meta_graph, signature_key)
- inputs, outputs = get_inputs_outputs(signature_def)
+ # Read SignatureDef.
+ meta_graph = _get_meta_graph_def(saved_model_dir, tag_set)
+ signature_def = _get_signature_def(meta_graph, signature_key)
+ inputs, outputs = _get_inputs_outputs(signature_def)
graph = ops.Graph()
with session.Session(graph=graph) as sess:
-
+ # TODO(nupurgarg): Throw ValueError if SavedModel has assets/ directory.
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
- in_tensors = [graph.get_tensor_by_name(input_) for input_ in inputs]
-
- # Users can use output_arrays to filter output tensors for conversion.
- # If output_arrays is None, we keep all output tensors. In future, we may
- # use tflite supported Op list and check whether op is custom Op to
- # automatically filter output arrays.
- # TODO(zhixianyan): Use tflite supported Op list to filter outputs.
- if output_arrays is not None:
- output_arrays = output_arrays.split(",")
- out_tensors = [
- graph.get_tensor_by_name(output)
- for output in outputs
- if output.split(":")[0] in output_arrays
- ]
- else:
- out_tensors = [graph.get_tensor_by_name(output) for output in outputs]
+ # Gets input and output tensors.
+ # TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
+ in_tensors = _get_tensors(graph, inputs, input_arrays)
+ out_tensors = _get_tensors(graph, outputs, output_arrays)
- output_names = [node.split(":")[0] for node in outputs]
+ # Gets fully defined tensor shape. An input tensor with None in the first
+ # dimension, e.g. (None, 224, 224, 3), is replaced with the batch_size.
+ # Shapes with None after the first dimension result in a ValueError.
+ # TODO(zhixianyan): Add supports for input tensor with more None in shape.
+ for tensor in in_tensors:
+ if (input_shapes and tensor.name in input_shapes and
+ input_shapes[tensor.name] is not None):
+ shape = input_shapes[tensor.name]
+ else:
+ shape = tensor.get_shape().as_list()
- if not out_tensors:
- raise ValueError(
- "No valid output tensors for '{}', possible values are '{}'".format(
- output_arrays, output_names))
+ if None in shape[1:]:
+ raise ValueError(
+ "None is only supported in the 1st dimension. Tensor '{0}' has "
+ "invalid shape '{1}'.".format(tensor.name, shape))
+ elif shape[0] is None:
+ shape[0] = batch_size
+ tensor.set_shape(shape)
+ output_names = [node.split(":")[0] for node in outputs]
frozen_graph_def = tf_graph_util.convert_variables_to_constants(
sess, graph.as_graph_def(), output_names)
- # Toco requires fully defined tensor shape, for input tensor with None in
- # their shape, e.g., (None, 224, 224, 3), we need to replace first None with
- # a given batch size. For shape with more None, e.g. (None, None, None, 3),
- # still be able to replace and convert, but require further investigation.
- # TODO(zhixianyan): Add supports for input tensor with more None in shape.
- for i in range(len(in_tensors)):
- shape = in_tensors[i].get_shape().as_list()
- if shape[0] is None:
- shape[0] = batch_size
- if None in shape[1:]:
- raise ValueError(
- "Only support None shape at 1st dim as batch_size. But tensor "
- "'{}' 's shape '{}' has None at other dimension. ".format(
- inputs[i], shape))
- in_tensors[i].set_shape(shape)
+ return frozen_graph_def, in_tensors, out_tensors
+ raise ValueError("Unable to load Session.")
- result = lite.toco_convert(frozen_graph_def, in_tensors, out_tensors)
- if output_tflite is not None:
- with gfile.Open(output_tflite, "wb") as f:
- f.write(result)
- logging.info("Successfully converted to: %s", output_tflite)
+def saved_model_to_frozen_graphdef(
+ saved_model_dir,
+ output_file_model,
+ output_file_flags,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ batch_size=1):
+ """Converts a SavedModel to a frozen graph. Writes graph to tmp directory.
- return result
+  Stores the frozen GraphDef and the ModelFlags proto in the given file paths.
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ output_file_model: Full file path to save frozen graph.
+ output_file_flags: Full file path to save ModelFlags.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+      Automatically determined when an input shape is None (e.g., {"foo" : None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+      input shape if it is undefined (None). (default 1)
+
+ Returns: None.
-def main(_):
- convert(
- saved_model_dir=flags.FLAGS.saved_model_dir,
- output_tflite=flags.FLAGS.output_tflite,
- output_arrays=flags.FLAGS.output_arrays,
- batch_size=flags.FLAGS.batch_size,
- tag_set=set(flags.FLAGS.tag_set.split(",")),
- signature_key=flags.FLAGS.signature_key)
+ Raises:
+ ValueError: Unable to convert to frozen graph.
+ """
+ frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
+ saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
+ signature_key, batch_size)
+
+ # Initialize model flags.
+ model = model_flags_pb2.ModelFlags()
+
+ for input_tensor in in_tensors:
+ input_array = model.input_arrays.add()
+ input_array.name = convert.tensor_name(input_tensor)
+ input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
+
+ for output_tensor in out_tensors:
+ model.output_arrays.append(convert.tensor_name(output_tensor))
+
+ # Write model and ModelFlags to file. ModelFlags contain input array and
+ # output array information that is parsed from the SignatureDef and used for
+ # analysis by TOCO.
+ _write_and_flush_file(output_file_model, frozen_graph_def.SerializeToString())
+ _write_and_flush_file(output_file_flags, model.SerializeToString())
+
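A short usage sketch of saved_model_to_frozen_graphdef, modeled on testSimpleSavedModel in convert_saved_model_test.py below; the SavedModel directory and output paths are illustrative:

    from tensorflow.contrib.lite.python import convert_saved_model

    # Writes the frozen GraphDef to model.pb and the ModelFlags proto to model.pbtxt.
    convert_saved_model.saved_model_to_frozen_graphdef(
        saved_model_dir="/tmp/simple_savedmodel",
        output_file_model="/tmp/model.pb",
        output_file_flags="/tmp/model.pbtxt",
        input_arrays=["inputB", "inputA"])
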
+
+def tflite_from_saved_model(
+ saved_model_dir,
+ output_file=None,
+ input_arrays=None,
+ input_shapes=None,
+ output_arrays=None,
+ tag_set=None,
+ signature_key=signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY,
+ batch_size=1,
+ inference_type=lite_constants.FLOAT,
+ input_format=lite_constants.TENSORFLOW_GRAPHDEF,
+ output_format=lite_constants.TFLITE,
+ quantized_input_stats=None,
+ drop_control_dependency=True):
+ """Converts a SavedModel to TFLite FlatBuffer.
+ Args:
+ saved_model_dir: SavedModel directory to convert.
+ output_file: File path to write result TFLite FlatBuffer.
+ input_arrays: List of input tensors to freeze graph with. Uses input arrays
+ from SignatureDef when none are provided. (default None)
+ input_shapes: Map of strings representing input tensor names to list of
+      integers representing input shapes (e.g., {"foo" : [1, 16, 16, 3]}).
+      Automatically determined when an input shape is None (e.g., {"foo" : None}).
+ (default None)
+ output_arrays: List of output tensors to freeze graph with. Uses output
+ arrays from SignatureDef when none are provided. (default None)
+ tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
+ analyze. All tags in the tag set must be present. (default "serve")
+ signature_key: Key identifying SignatureDef containing inputs and outputs.
+ batch_size: Batch size for the model. Replaces the first dimension of an
+      input shape if it is undefined (None). (default 1)
+ inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
+ input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
+ output_format: Type of data to write (currently must be TFLITE or
+ GRAPHVIZ_DOT)
+ quantized_input_stats: For each member of input_tensors the mean and
+ std deviation of training data. Only needed if `inference_type` is
+ `QUANTIZED_UINT8`.
+    drop_control_dependency: Drops control dependencies silently. This is
+      because TensorFlow Lite does not support control dependencies.
-if __name__ == "__main__":
- app.run(main)
+ Returns:
+    The converted data. For example, if TFLite was the destination, this
+    will be a TFLite FlatBuffer in a bytes array.
+
+ Raises:
+ ValueError: Unable to convert to frozen graph.
+ """
+ frozen_graph_def, in_tensors, out_tensors = _freeze_saved_model(
+ saved_model_dir, input_arrays, input_shapes, output_arrays, tag_set,
+ signature_key, batch_size)
+
+ result = convert.toco_convert(
+ input_data=frozen_graph_def,
+ input_tensors=in_tensors,
+ output_tensors=out_tensors,
+ inference_type=inference_type,
+ input_format=input_format,
+ output_format=output_format,
+ quantized_input_stats=quantized_input_stats,
+ drop_control_dependency=drop_control_dependency)
+
+ if output_file is not None:
+ with gfile.Open(output_file, "wb") as f:
+ f.write(result)
+ logging.info("Successfully converted to: %s", output_file)
+
+ return result
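
An end-to-end sketch of tflite_from_saved_model, following the MNIST test below; the paths are illustrative, and limiting output_arrays to "Softmax" matches the test's workaround for ops not yet supported in TFLite:

    from tensorflow.contrib.lite.python import convert_saved_model

    tflite_model = convert_saved_model.tflite_from_saved_model(
        saved_model_dir="/tmp/mnist_savedmodel",
        output_arrays=["Softmax"],
        output_file="/tmp/mnist.lite")
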
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_test.py b/tensorflow/contrib/lite/python/convert_saved_model_test.py
index 734e42d619..db95fc8ad7 100644
--- a/tensorflow/contrib/lite/python/convert_saved_model_test.py
+++ b/tensorflow/contrib/lite/python/convert_saved_model_test.py
@@ -12,11 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
-"""TF Lite SavedModel Conversion test cases.
-
- - test on generated saved_models from simple graphs (sanity check)
- - test mnist savedmodel generated on-the-fly
+"""TFLite SavedModel conversion test cases.
+ - Tests converting simple SavedModel graph to TFLite FlatBuffer.
+ - Tests converting simple SavedModel graph to frozen graph.
+ - Tests converting MNIST SavedModel to TFLite FlatBuffer.
"""
from __future__ import absolute_import
@@ -25,6 +25,7 @@ from __future__ import print_function
import os
from tensorflow.contrib.lite.python import convert_saved_model
+from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
from tensorflow.python import keras
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib as estimator
@@ -37,6 +38,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import random_ops
from tensorflow.python.ops.losses import losses
+from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import saved_model
from tensorflow.python.training import training as train
@@ -45,7 +47,7 @@ from tensorflow.python.training import training as train
class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
def _createSimpleSavedModel(self, shape):
- """Create a simple savedmodel on the fly."""
+ """Create a simple SavedModel on the fly."""
saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
with session.Session() as sess:
in_tensor = array_ops.placeholder(shape=shape, dtype=dtypes.float32)
@@ -56,44 +58,78 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
return saved_model_dir
def testSimpleSavedModel(self):
- """Test a simple savedmodel created on the fly."""
- # Create a simple savedmodel
+ """Test a simple SavedModel created on the fly."""
+ # Create a simple SavedModel
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite
- result = convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
self.assertTrue(result)
def testSimpleSavedModelWithNoneBatchSizeInShape(self):
- """Test a simple savedmodel, with None in input tensor's shape."""
+ """Test a simple SavedModel, with None in input tensor's shape."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
- result = convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
self.assertTrue(result)
def testSimpleSavedModelWithMoreNoneInShape(self):
- """Test a simple savedmodel, fail as more None in input shape."""
+ """Test a simple SavedModel, fail as more None in input shape."""
saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, None, 3])
# Convert to tflite: this should raise ValueError, as 3rd dim is None.
with self.assertRaises(ValueError):
- convert_saved_model.convert(saved_model_dir=saved_model_dir)
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir)
def testSimpleSavedModelWithWrongSignatureKey(self):
- """Test a simple savedmodel, fail as given signature is invalid."""
+ """Test a simple SavedModel, fail as given signature is invalid."""
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite: this should raise ValueError, as
    # signature_key does not exist in the saved_model.
with self.assertRaises(ValueError):
- convert_saved_model.convert(
+ convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_dir, signature_key="wrong-key")
def testSimpleSavedModelWithWrongOutputArray(self):
- """Test a simple savedmodel, fail as given output_arrays is invalid."""
- # Create a simple savedmodel
+ """Test a simple SavedModel, fail as given output_arrays is invalid."""
+ # Create a simple SavedModel
saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
# Convert to tflite: this should raise ValueError, as
# output_arrays is not valid for the saved_model.
with self.assertRaises(ValueError):
- convert_saved_model.convert(
- saved_model_dir=saved_model_dir, output_arrays="wrong-output")
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir, output_arrays=["wrong-output"])
+
+ def testSimpleSavedModelWithWrongInputArrays(self):
+ """Test a simple SavedModel, fail as given input_arrays is invalid."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ # Checks invalid input_arrays.
+ with self.assertRaises(ValueError):
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir, input_arrays=["wrong-input"])
+ # Checks valid and invalid input_arrays.
+ with self.assertRaises(ValueError):
+ convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder", "wrong-input"])
+
+ def testSimpleSavedModelWithCorrectArrays(self):
+ """Test a simple SavedModel, with correct input_arrays and output_arrays."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[None, 16, 16, 3])
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder"],
+ output_arrays=["add"])
+ self.assertTrue(result)
+
+ def testSimpleSavedModelWithCorrectInputArrays(self):
+ """Test a simple SavedModel, with correct input_arrays and input_shapes."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ result = convert_saved_model.tflite_from_saved_model(
+ saved_model_dir=saved_model_dir,
+ input_arrays=["Placeholder"],
+ input_shapes={"Placeholder": [1, 16, 16, 3]})
+ self.assertTrue(result)
def testMultipleMetaGraphDef(self):
"""Test saved model with multiple MetaGraphDef."""
@@ -119,20 +155,103 @@ class ConvertSavedModelTestBasicGraph(test_util.TensorFlowTestCase):
sess,
tags=[saved_model.tag_constants.SERVING, "additional_test_tag"],
signature_def_map=signature_def_map)
+
# MetaGraphDef 2
builder.add_meta_graph(tags=["tflite"])
builder.save(True)
# Convert to tflite
- convert_saved_model.convert(
+ convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_dir,
tag_set=set([saved_model.tag_constants.SERVING, "additional_test_tag"]))
+class ConvertSavedModelTestBasicGraphToText(test_util.TensorFlowTestCase):
+
+ def _createSimpleSavedModel(self, shape):
+ """Create a simple SavedModel."""
+ saved_model_dir = os.path.join(self.get_temp_dir(), "simple_savedmodel")
+ with session.Session() as sess:
+ in_tensor_1 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputB")
+ in_tensor_2 = array_ops.placeholder(
+ shape=shape, dtype=dtypes.float32, name="inputA")
+ out_tensor = in_tensor_1 + in_tensor_2
+ inputs = {"x": in_tensor_1, "y": in_tensor_2}
+ outputs = {"z": out_tensor}
+ saved_model.simple_save(sess, saved_model_dir, inputs, outputs)
+ return saved_model_dir
+
+ def _getInputArrayNames(self, model_proto):
+ return [data.name for data in model_proto.input_arrays]
+
+ def _getInputArrayShapes(self, model_proto):
+ return [
+ [dim for dim in data.shape.dims] for data in model_proto.input_arrays
+ ]
+
+ def _get_model_flags_proto_from_file(self, filename):
+ proto = _model_flags_pb2.ModelFlags()
+ with gfile.Open(filename, "rb") as output_file:
+ proto.ParseFromString(output_file.read())
+ output_file.close()
+ return proto
+
+ def testSimpleSavedModel(self):
+ """Test a simple SavedModel."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
+ output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
+
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputB", "inputA"])
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA", "inputB"])
+ self.assertEqual(
+ self._getInputArrayShapes(proto), [[1, 16, 16, 3], [1, 16, 16, 3]])
+
+ def testSimpleSavedModelWithDifferentInputNames(self):
+ """Test a simple SavedModel."""
+ saved_model_dir = self._createSimpleSavedModel(shape=[1, 16, 16, 3])
+ output_file_model = os.path.join(self.get_temp_dir(), "model.pb")
+ output_file_flags = os.path.join(self.get_temp_dir(), "model.pbtxt")
+
+ # Check case where input shape is given.
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputA"],
+ input_shapes={"inputA": [1, 16, 16, 3]})
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
+ self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
+
+ # Check case where input shape is None.
+ convert_saved_model.saved_model_to_frozen_graphdef(
+ saved_model_dir=saved_model_dir,
+ output_file_model=output_file_model,
+ output_file_flags=output_file_flags,
+ input_arrays=["inputA"],
+ input_shapes={"inputA": None})
+
+ proto = self._get_model_flags_proto_from_file(output_file_flags)
+ self.assertEqual(proto.output_arrays, ["add"])
+ self.assertEqual(self._getInputArrayNames(proto), ["inputA"])
+ self.assertEqual(self._getInputArrayShapes(proto), [[1, 16, 16, 3]])
+
+
class Model(keras.Model):
"""Model to recognize digits in the MNIST dataset.
- Train and export savedmodel, used for testOnflyTrainMnistSavedModel
+ Train and export SavedModel, used for testOnflyTrainMnistSavedModel
Network structure is equivalent to:
https://github.com/tensorflow/tensorflow/blob/r1.5/tensorflow/examples/tutorials/mnist/mnist_deep.py
@@ -238,7 +357,7 @@ def dummy_input_fn():
class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
def testTrainedMnistSavedModel(self):
- """Test mnist savedmodel, trained with dummy data and small steps."""
+ """Test mnist SavedModel, trained with dummy data and small steps."""
# Build classifier
classifier = estimator.Estimator(
model_fn=model_fn,
@@ -253,21 +372,20 @@ class ConvertSavedModelTestTrainGraph(test_util.TensorFlowTestCase):
"image": image,
})
- # Export savedmodel
+ # Export SavedModel
saved_model_dir = os.path.join(self.get_temp_dir(), "mnist_savedmodel")
classifier.export_savedmodel(saved_model_dir, pred_input_fn)
# Convert to tflite and test output
saved_model_name = os.listdir(saved_model_dir)[0]
saved_model_final_dir = os.path.join(saved_model_dir, saved_model_name)
- output_tflite = os.path.join(saved_model_dir,
- saved_model_final_dir + ".lite")
+ output_file = os.path.join(saved_model_dir, saved_model_final_dir + ".lite")
# TODO(zhixianyan): no need to limit output_arrays to `Softmax'
# once b/74205001 fixed and argmax implemented in tflite.
- result = convert_saved_model.convert(
+ result = convert_saved_model.tflite_from_saved_model(
saved_model_dir=saved_model_final_dir,
- output_arrays="Softmax",
- output_tflite=output_tflite)
+ output_arrays=["Softmax"],
+ output_file=output_file)
self.assertTrue(result)
diff --git a/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
new file mode 100644
index 0000000000..4d9782f4a6
--- /dev/null
+++ b/tensorflow/contrib/lite/python/convert_saved_model_to_frozen_graph.py
@@ -0,0 +1,106 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Python console command for generating frozen models from SavedModels.
+
+This exists to add SavedModel compatibility to TOCO.
+"""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import argparse
+import sys
+from tensorflow.contrib.lite.python.convert_saved_model import saved_model_to_frozen_graphdef
+from tensorflow.python.platform import app
+
+FLAGS = None
+
+
+def execute(unused_args):
+ """Calls function to convert the SavedModel to a frozen graph."""
+ # Error handling.
+ if FLAGS.input_shapes and not FLAGS.input_arrays:
+ raise ValueError("Input shapes requires input arrays to be specified.")
+
+ # Calls saved_model_to_frozen_graphdef function to generate frozen graph.
+ input_arrays = (FLAGS.input_arrays.split(",") if FLAGS.input_arrays else None)
+ input_shapes = None
+ if FLAGS.input_shapes:
+ input_shapes = {
+ input_arrays[idx]: shape.split(",")
+ for idx, shape in enumerate(FLAGS.input_shapes.split(":"))
+ }
+ output_arrays = (
+ FLAGS.output_arrays.split(",") if FLAGS.output_arrays else None)
+ tag_set = set(FLAGS.tag_set.split(",")) if FLAGS.tag_set else None
+
+ saved_model_to_frozen_graphdef(
+ saved_model_dir=FLAGS.saved_model_directory,
+ output_file_model=FLAGS.output_file_model,
+ output_file_flags=FLAGS.output_file_flags,
+ input_arrays=input_arrays,
+ input_shapes=input_shapes,
+ output_arrays=output_arrays,
+ tag_set=tag_set,
+ signature_key=FLAGS.signature_key,
+ batch_size=FLAGS.batch_size)
+
+
+def main():
+ global FLAGS
+ # Parses flags.
+ parser = argparse.ArgumentParser(
+ description="Invoke SavedModel to frozen model converter.")
+ parser.add_argument(
+ "saved_model_directory",
+ type=str,
+ help="Full path to directory containing the SavedModel.")
+ parser.add_argument(
+ "output_file_model",
+ type=str,
+ help="Full file path to save frozen graph.")
+ parser.add_argument(
+ "output_file_flags", type=str, help="Full file path to save ModelFlags.")
+ parser.add_argument(
+ "--input_arrays",
+ type=str,
+ help="Name of the input arrays, comma-separated.")
+ parser.add_argument(
+ "--input_shapes",
+ type=str,
+ help="Shapes corresponding to --input_arrays, colon-separated.")
+ parser.add_argument(
+ "--output_arrays",
+ type=str,
+ help="Name of the output arrays, comma-separated.")
+ parser.add_argument(
+ "--tag_set", type=str, help="Name of output arrays, comma-separated.")
+ parser.add_argument(
+ "--signature_key",
+ type=str,
+ help="Key identifying SignatureDef containing inputs and outputs.")
+ parser.add_argument(
+ "--batch_size",
+ type=int,
+ help="Batch size for the model. Replaces the first dimension of an "
+ "input size array if undefined.")
+
+ FLAGS, unparsed = parser.parse_known_args()
+
+ app.run(main=execute, argv=[sys.argv[0]] + unparsed)
+
+
+if __name__ == "__main__":
+ main()
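
A small sketch of how the comma-separated --input_arrays and colon-separated --input_shapes flag values map to the input_shapes dictionary built in execute() above; the flag values are illustrative:

    input_arrays_flag = "inputA,inputB"
    input_shapes_flag = "1,16,16,3:1,16,16,3"

    input_arrays = input_arrays_flag.split(",")
    # Shapes are colon-separated per input array, comma-separated per dimension.
    input_shapes = {
        input_arrays[idx]: shape.split(",")
        for idx, shape in enumerate(input_shapes_flag.split(":"))
    }
    # -> {"inputA": ["1", "16", "16", "3"], "inputB": ["1", "16", "16", "3"]}
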
diff --git a/tensorflow/contrib/lite/python/lite_test.py b/tensorflow/contrib/lite/python/convert_test.py
index b8b4510188..dc21a9b669 100644
--- a/tensorflow/contrib/lite/python/lite_test.py
+++ b/tensorflow/contrib/lite/python/convert_test.py
@@ -17,8 +17,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.contrib.lite.python import lite
-from tensorflow.contrib.lite.python.op_hint import _tensor_name_base as _tensor_name_base
+from tensorflow.contrib.lite.python import convert
+from tensorflow.contrib.lite.python import lite_constants
+from tensorflow.contrib.lite.python import op_hint
from tensorflow.python.client import session
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
@@ -29,7 +30,7 @@ from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test
-class LiteTest(test_util.TensorFlowTestCase):
+class ConvertTest(test_util.TensorFlowTestCase):
def testBasic(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
@@ -37,13 +38,13 @@ class LiteTest(test_util.TensorFlowTestCase):
out_tensor = in_tensor + in_tensor
sess = session.Session()
# Try running on valid graph
- result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
+ result = convert.toco_convert(sess.graph_def, [in_tensor], [out_tensor])
self.assertTrue(result)
# TODO(aselle): remove tests that fail (we must get TOCO to not fatal
# all the time).
# Try running on identity graph (known fail)
# with self.assertRaisesRegexp(RuntimeError, "!model->operators.empty()"):
- # result = lite.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
+ # result = convert.toco_convert(sess.graph_def, [in_tensor], [in_tensor])
def testQuantization(self):
in_tensor = array_ops.placeholder(shape=[1, 16, 16, 3],
@@ -51,13 +52,14 @@ class LiteTest(test_util.TensorFlowTestCase):
out_tensor = array_ops.fake_quant_with_min_max_args(in_tensor + in_tensor,
min=0., max=1.)
sess = session.Session()
- result = lite.toco_convert(sess.graph_def, [in_tensor], [out_tensor],
- inference_type=lite.QUANTIZED_UINT8,
- quantized_input_stats=[(0., 1.)])
+ result = convert.toco_convert(
+ sess.graph_def, [in_tensor], [out_tensor],
+ inference_type=lite_constants.QUANTIZED_UINT8,
+ quantized_input_stats=[(0., 1.)])
self.assertTrue(result)
-class LiteTestOpHint(test_util.TensorFlowTestCase):
+class ConvertTestOpHint(test_util.TensorFlowTestCase):
"""Test the hint to stub functionality."""
def _getGraphOpTypes(self, graphdef, output_nodes):
@@ -99,7 +101,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
swish_scale = array_ops.constant(1.0)
def _swish(input_tensor, scale):
- custom = lite.OpHint("cool_activation")
+ custom = op_hint.OpHint("cool_activation")
input_tensor, scale = custom.add_inputs(input_tensor, scale)
output = math_ops.sigmoid(input_tensor) * input_tensor * scale
output, = custom.add_outputs(output)
@@ -111,11 +113,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# and 1 final output).
self.assertEqual(self._countIdentities(sess.graph_def.node), 4)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["cool_activation", "Const", "Identity"])
def testScaleAndBiasAndIdentity(self):
@@ -125,7 +128,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
b = array_ops.constant([4., 5.])
def _scaled_and_bias_and_identity(a, x, b):
- custom = lite.OpHint("scale_and_bias_and_identity")
+ custom = op_hint.OpHint("scale_and_bias_and_identity")
a, x, b = custom.add_inputs(a, x, b)
return custom.add_outputs(a * x + b, x)
output = array_ops.identity(_scaled_and_bias_and_identity(a, x, b),
@@ -136,11 +139,12 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 6)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["scale_and_bias_and_identity", "Const", "Identity", "Pack"])
def testTwoFunctions(self):
@@ -148,7 +152,7 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
a = array_ops.constant([1.])
b = array_ops.constant([1.])
def _double_values(x):
- custom = lite.OpHint("add_test")
+ custom = op_hint.OpHint("add_test")
x = custom.add_inputs(x)
output = math_ops.multiply(x, x)
output, = custom.add_outputs(output)
@@ -160,10 +164,11 @@ class LiteTestOpHint(test_util.TensorFlowTestCase):
# make sure one identity for each input (2) and output (2) => 2 + 2
# +1 for the final output
self.assertEqual(self._countIdentities(sess.graph_def.node), 5)
- stubbed_graphdef = lite.convert_op_hints_to_stubs(sess)
+ stubbed_graphdef = op_hint.convert_op_hints_to_stubs(sess)
self.assertCountEqual(
self._getGraphOpTypes(
- stubbed_graphdef, output_nodes=[_tensor_name_base(output)]),
+ stubbed_graphdef,
+ output_nodes=[op_hint._tensor_name_base(output)]),
["add_test", "Const", "Identity", "Add"])
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index cf50f9d4d6..4ea40201f7 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -18,6 +18,7 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
@@toco_convert
@@toco_convert_protos
+@@tflite_from_saved_model
@@OpHint
@@convert_op_hints_to_stubs
@@ -25,208 +26,11 @@ EXPERIMENTAL: APIs here are unstable and likely to change without notice.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import os as _os
-import subprocess as _subprocess
-import tempfile as _tempfile
# pylint: disable=unused-import
+from tensorflow.contrib.lite.python.convert import toco_convert
+from tensorflow.contrib.lite.python.convert import toco_convert_protos
+from tensorflow.contrib.lite.python.convert_saved_model import tflite_from_saved_model
from tensorflow.contrib.lite.python.op_hint import convert_op_hints_to_stubs
from tensorflow.contrib.lite.python.op_hint import OpHint
# pylint: enable=unused-import
-from tensorflow.contrib.lite.toco import model_flags_pb2 as _model_flags_pb2
-from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
-from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
-from tensorflow.python.framework import dtypes as _dtypes
-from tensorflow.python.platform import resource_loader as _resource_loader
-from tensorflow.python.util.all_util import remove_undocumented
-from tensorflow.python.util.lazy_loader import LazyLoader
-
-# Lazy load since some of the performance benchmark skylark rules
-# break dependencies.
-_toco_python = LazyLoader(
- "tensorflow_wrap_toco", globals(),
- "tensorflow.contrib.lite.toco.python."
- "tensorflow_wrap_toco")
-del LazyLoader
-
-# Enum types from the protobuf promoted to the API
-FLOAT = _types_pb2.FLOAT
-INT32 = _types_pb2.INT32
-INT64 = _types_pb2.INT64
-STRING = _types_pb2.STRING
-QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
-TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
-TFLITE = _toco_flags_pb2.TFLITE
-GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
-
-# Currently the default mode of operation is to shell to another python process
-# to protect against crashes. However, it breaks some dependent targets because
-# it forces us to depend on an external py_binary. The experimental API doesn't
-# have that drawback.
-EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
-
-# Find the toco_from_protos binary using the resource loader if using from
-# bazel, otherwise we are in a pip where console_scripts already has
-# the toco_from_protos tool.
-if EXPERIMENTAL_USE_TOCO_API_DIRECTLY:
- _toco_from_proto_bin = ""
-else:
- _toco_from_proto_bin = _resource_loader.get_path_to_datafile(
- "../toco/python/toco_from_protos")
-
-if _toco_from_proto_bin and not _os.path.exists(_toco_from_proto_bin):
- _toco_from_proto_bin = "toco_from_protos"
-
-
-def toco_convert_protos(model_flags_str, toco_flags_str, input_data_str):
- """Convert `input_data_str` according to model and toco parameters.
-
- Unless you know what you are doing consider using
- the more friendly @{tf.contrib.lite.toco_convert}}.
-
- Args:
- model_flags_str: Serialized proto describing model properties, see
- `toco/model_flags.proto`.
- toco_flags_str: Serialized proto describing conversion properties, see
- `toco/toco_flags.proto`.
- input_data_str: Input data in serialized form (e.g. a graphdef is common)
- Returns:
- Converted model in serialized form (e.g. a TFLITE model is common).
- Raises:
- RuntimeError: When conversion fails, an exception is raised with the error
- message embedded.
- """
- # TODO(aselle): When toco does not use fatal errors for failure, we can
- # switch this on.
- if not _toco_from_proto_bin:
- return _toco_python.TocoConvert(
- model_flags_str, toco_flags_str, input_data_str)
-
- with _tempfile.NamedTemporaryFile() as fp_toco, \
- _tempfile.NamedTemporaryFile() as fp_model, \
- _tempfile.NamedTemporaryFile() as fp_input, \
- _tempfile.NamedTemporaryFile() as fp_output:
- fp_model.write(model_flags_str)
- fp_toco.write(toco_flags_str)
- fp_input.write(input_data_str)
- fp_model.flush()
- fp_toco.flush()
- fp_input.flush()
-
- cmd = [
- _toco_from_proto_bin, fp_model.name, fp_toco.name, fp_input.name,
- fp_output.name
- ]
- cmdline = " ".join(cmd)
- proc = _subprocess.Popen(
- cmdline,
- shell=True,
- stdout=_subprocess.PIPE,
- stderr=_subprocess.STDOUT,
- close_fds=True)
- stdout, stderr = proc.communicate()
- exitcode = proc.returncode
- if exitcode == 0:
- stuff = fp_output.read()
- return stuff
- else:
- raise RuntimeError("TOCO failed see console for info.\n%s\n%s\n" %
- (stdout, stderr))
-
-
-def _tensor_name(x):
- return x.name.split(":")[0]
-
-
-def toco_convert(input_data,
- input_tensors,
- output_tensors,
- inference_type=FLOAT,
- input_format=TENSORFLOW_GRAPHDEF,
- output_format=TFLITE,
- quantized_input_stats=None,
- drop_control_dependency=True,
- allow_custom_ops=None):
- """Convert a model using TOCO from `input_format` to `output_format`.
-
- Typically this is to convert from TensorFlow GraphDef to TFLite, in which
- case the default `input_format` and `output_format` are sufficient.
-
- Args:
- input_data: Input data (i.e. often `sess.graph_def`).
- input_tensors: List of input tensors. Type and shape are computed using
- `foo.get_shape()` and `foo.dtype`.
- output_tensors: List of output tensors (only .name is used from this).
- inference_type: Currently must be `{FLOAT, QUANTIZED_UINT8}`.
- input_format: Type of data to read (currently must be TENSORFLOW_GRAPHDEF).
- output_format: Type of data to write (currently must be TFLITE or
- GRAPHVIZ_DOT)
- quantized_input_stats: For each member of input_tensors the mean and
- std deviation of training data. Only needed if `inference_type` is
- `QUANTIZED_UINT8`.
- drop_control_dependency: Drops control dependencies silently. This is due
- to tf lite not supporting control dependencies.
-
- Returns:
- The converted data. For example if tflite was the destination, then
- this will be a tflite flatbuffer in a bytes array.
-
- Raises:
- ValueError: If the input tensor type is unknown
- RuntimeError: If TOCO fails to convert (in which case the runtime error's
- error text will contain the TOCO error log)
- """
- toco = _toco_flags_pb2.TocoFlags()
- toco.input_format = input_format
- toco.output_format = output_format
- toco.inference_type = inference_type
- toco.drop_control_dependency = drop_control_dependency
- if allow_custom_ops is not None:
- toco.allow_custom_ops = allow_custom_ops
-
- model = _model_flags_pb2.ModelFlags()
- for idx, input_tensor in enumerate(input_tensors):
- if input_tensor.dtype == _dtypes.float32:
- tflite_input_type = FLOAT
- elif input_tensor.dtype == _dtypes.int32:
- tflite_input_type = INT32
- elif input_tensor.dtype == _dtypes.int64:
- tflite_input_type = INT64
- # TODO(aselle): Insert strings when they are available
- else:
- raise ValueError("Tensors %s not known type %r" % (input_tensor.name,
- input_tensor.dtype))
-
- input_array = model.input_arrays.add()
-
- if inference_type == QUANTIZED_UINT8:
- if tflite_input_type == FLOAT:
- tflite_input_type = QUANTIZED_UINT8
- input_array.mean_value, input_array.std_value = quantized_input_stats[idx]
-
- input_array.name = _tensor_name(input_tensor)
- input_array.shape.dims.extend(map(int, input_tensor.get_shape()))
-
- for output_tensor in output_tensors:
- model.output_arrays.append(_tensor_name(output_tensor))
-
- # TODO(aselle): Consider handling the case of allowing quantized
- # inputs to be converted to float (via the toco.inference_input_type field).
- data = toco_convert_protos(model.SerializeToString(),
- toco.SerializeToString(),
- input_data.SerializeToString())
- return data
-
-
-_allowed_symbols = [
- "FLOAT",
- "INT32",
- "INT64",
- "STRING",
- "QUANTIZED_UINT8",
- "TENSORFLOW_GRAPHDEF",
- "TFLITE",
- "GRAPHVIZ_DOT",
- "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/lite/python/lite_constants.py b/tensorflow/contrib/lite/python/lite_constants.py
new file mode 100644
index 0000000000..195d7a732f
--- /dev/null
+++ b/tensorflow/contrib/lite/python/lite_constants.py
@@ -0,0 +1,53 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Constants for TFLite."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.lite.toco import toco_flags_pb2 as _toco_flags_pb2
+from tensorflow.contrib.lite.toco import types_pb2 as _types_pb2
+from tensorflow.python.util.all_util import remove_undocumented
+
+# Enum types from the protobuf promoted to the API
+FLOAT = _types_pb2.FLOAT
+INT32 = _types_pb2.INT32
+INT64 = _types_pb2.INT64
+STRING = _types_pb2.STRING
+QUANTIZED_UINT8 = _types_pb2.QUANTIZED_UINT8
+TENSORFLOW_GRAPHDEF = _toco_flags_pb2.TENSORFLOW_GRAPHDEF
+TFLITE = _toco_flags_pb2.TFLITE
+GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
+
+# Currently the default mode of operation is to shell to another python process
+# to protect against crashes. However, it breaks some dependent targets because
+# it forces us to depend on an external py_binary. The experimental API doesn't
+# have that drawback.
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
+
+
+_allowed_symbols = [
+ "FLOAT",
+ "INT32",
+ "INT64",
+ "STRING",
+ "QUANTIZED_UINT8",
+ "TENSORFLOW_GRAPHDEF",
+ "TFLITE",
+ "GRAPHVIZ_DOT",
+ "EXPERIMENTAL_USE_TOCO_API_DIRECTLY",
+]
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/lite/schema/schema.fbs b/tensorflow/contrib/lite/schema/schema.fbs
index 2b62c257d8..a65c2e0c70 100644
--- a/tensorflow/contrib/lite/schema/schema.fbs
+++ b/tensorflow/contrib/lite/schema/schema.fbs
@@ -435,21 +435,25 @@ table Operator {
custom_options_format:CustomOptionsFormat;
}
-// The root type, defining a model.
+// The root type, defining a subgraph, which typically represents an entire
+// model.
table SubGraph {
- // A list of all tensors used in this model.
+ // A list of all tensors used in this subgraph.
tensors:[Tensor];
- // Indices of the input tensors.
+ // Indices of the tensors that are inputs into this subgraph. Note this is
+ // the list of non-static tensors that feed into the subgraph for inference.
inputs:[int];
- // Indices of the output tensors.
+ // Indices of the tensors that are outputs out of this subgraph. Note this is
+ // the list of output tensors that are considered the product of the
+ // subgraph's inference.
outputs:[int];
// All operators, in execution order.
operators:[Operator];
- // Name of subgraph (used for debugging).
+ // Name of this subgraph (used for debugging).
name:string;
}
@@ -475,7 +479,10 @@ table Model {
// A description of the model.
description:string;
- // Buffers of the model
+ // Buffers of the model.
+ // Note the 0th entry of this array must be an empty buffer (sentinel).
+ // This is a convention so that tensors without a buffer can provide 0 as
+ // their buffer.
buffers:[Buffer];
}
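The buffer-sentinel convention documented above can be illustrated with plain Python lists standing in for the generated flatbuffer objects (tensor names and data below are hypothetical):

buffers = [b""]                      # entry 0: the mandatory empty sentinel buffer
buffers.append(b"\x00\x01\x02\x03")  # entry 1: constant data for a weights tensor

tensors = [
    {"name": "input", "buffer": 0},    # no constant data, so it points at the sentinel
    {"name": "weights", "buffer": 1},  # constant data lives in buffers[1]
]

for tensor in tensors:
    data = buffers[tensor["buffer"]]
    print(tensor["name"], "carries", len(data), "bytes of constant data")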
diff --git a/tensorflow/contrib/lite/string_util.cc b/tensorflow/contrib/lite/string_util.cc
index cd41299d38..a89776b29f 100644
--- a/tensorflow/contrib/lite/string_util.cc
+++ b/tensorflow/contrib/lite/string_util.cc
@@ -24,7 +24,10 @@ namespace tflite {
namespace {
// Convenient method to get pointer to int32_t.
-int32_t* GetIntPtr(char* ptr) { return reinterpret_cast<int32_t*>(ptr); }
+const int32_t* GetIntPtr(const char* ptr) {
+ return reinterpret_cast<const int32_t*>(ptr);
+}
+
} // namespace
void DynamicBuffer::AddString(const char* str, size_t len) {
@@ -64,7 +67,7 @@ void DynamicBuffer::AddJoinedString(const std::vector<StringRef>& strings,
offset_.push_back(offset_.back() + total_len);
}
-void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+int DynamicBuffer::WriteToBuffer(char** buffer) {
// Allocate sufficient memory to tensor buffer.
int32_t num_strings = offset_.size() - 1;
// Total bytes include:
@@ -75,43 +78,57 @@ void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
int32_t bytes = data_.size() // size of content
+ sizeof(int32_t) * (num_strings + 2); // size of header
- // Output tensor will take over the ownership of tensor_buffer, and free it
- // during Interpreter destruction.
- char* tensor_buffer = static_cast<char*>(malloc(bytes));
+ // Caller will take ownership of buffer.
+ *buffer = reinterpret_cast<char*>(malloc(bytes));
// Set num of string
- memcpy(tensor_buffer, &num_strings, sizeof(int32_t));
+ memcpy(*buffer, &num_strings, sizeof(int32_t));
// Set offset of strings.
int32_t start = sizeof(int32_t) * (num_strings + 2);
for (int i = 0; i < offset_.size(); i++) {
int32_t offset = start + offset_[i];
- memcpy(tensor_buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
+ memcpy(*buffer + sizeof(int32_t) * (i + 1), &offset, sizeof(int32_t));
}
// Copy data of strings.
- memcpy(tensor_buffer + start, data_.data(), data_.size());
+ memcpy(*buffer + start, data_.data(), data_.size());
+ return bytes;
+}
+
+void DynamicBuffer::WriteToTensor(TfLiteTensor* tensor) {
+ char* tensor_buffer;
+ int bytes = WriteToBuffer(&tensor_buffer);
// Set tensor content pointer to tensor_buffer, and release original data.
auto dims = TfLiteIntArrayCreate(1);
- dims->data[0] = num_strings;
+ dims->data[0] = offset_.size() - 1; // Store number of strings.
TfLiteTensorReset(tensor->type, tensor->name, dims, tensor->params,
tensor_buffer, bytes, kTfLiteDynamic, tensor->allocation,
tensor);
}
+int GetStringCount(const char* raw_buffer) {
+ // The first integer in the raw buffer is the number of strings.
+ return *GetIntPtr(raw_buffer);
+}
+
int GetStringCount(const TfLiteTensor* tensor) {
// The first integers in the raw buffer is the number of strings.
- return *GetIntPtr(tensor->data.raw);
+ return GetStringCount(tensor->data.raw);
}
-StringRef GetString(const TfLiteTensor* tensor, int string_index) {
- int32_t* offset =
- GetIntPtr(tensor->data.raw + sizeof(int32_t) * (string_index + 1));
+StringRef GetString(const char* raw_buffer, int string_index) {
+ const int32_t* offset =
+ GetIntPtr(raw_buffer + sizeof(int32_t) * (string_index + 1));
return {
- tensor->data.raw + (*offset),
+ raw_buffer + (*offset),
(*(offset + 1)) - (*offset),
};
}
+StringRef GetString(const TfLiteTensor* tensor, int string_index) {
+ return GetString(tensor->data.raw, string_index);
+}
+
} // namespace tflite
diff --git a/tensorflow/contrib/lite/string_util.h b/tensorflow/contrib/lite/string_util.h
index c35a2fff3c..57f129bf5e 100644
--- a/tensorflow/contrib/lite/string_util.h
+++ b/tensorflow/contrib/lite/string_util.h
@@ -49,7 +49,7 @@ namespace tflite {
// Convenient structure to store string pointer and length.
typedef struct {
- char* str;
+ const char* str;
int len;
} StringRef;
@@ -70,6 +70,10 @@ class DynamicBuffer {
// buffer.
void AddJoinedString(const std::vector<StringRef>& strings, char separator);
+ // Fills content into a buffer and returns the number of bytes stored.
+ // The function allocates space for the buffer but does NOT take ownership.
+ int WriteToBuffer(char** buffer);
+
// Fill content into a string tensor.
void WriteToTensor(TfLiteTensor* tensor);
@@ -81,10 +85,12 @@ class DynamicBuffer {
};
// Return num of strings in a String tensor.
+int GetStringCount(const char* raw_buffer);
int GetStringCount(const TfLiteTensor* tensor);
// Get String pointer and length of index-th string in tensor.
// NOTE: This will not create a copy of string data.
+StringRef GetString(const char* raw_buffer, int string_index);
StringRef GetString(const TfLiteTensor* tensor, int string_index);
} // namespace tflite
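The raw layout produced by DynamicBuffer::WriteToBuffer and consumed by GetStringCount/GetString can be sketched in standalone Python (assuming little-endian int32, as on common platforms; this is an illustration of the format, not part of the patch):

import struct

def pack_strings(strings):
    # Header of (n + 2) int32s: the string count followed by n + 1 absolute
    # offsets, then the concatenated string bytes.
    data = b"".join(s.encode() for s in strings)
    n = len(strings)
    start = 4 * (n + 2)
    offsets = [start]
    for s in strings:
        offsets.append(offsets[-1] + len(s.encode()))
    return struct.pack("<%di" % (n + 2), n, *offsets) + data

def get_string(raw, index):
    # Adjacent offsets bound the index-th string, as in GetString().
    begin, end = struct.unpack_from("<2i", raw, 4 * (index + 1))
    return raw[begin:end].decode()

buf = pack_strings(["AA", "BBB", "Best. String. Ever."])
print(struct.unpack_from("<i", buf, 0)[0])  # 3, as GetStringCount() would report
print(get_string(buf, 2))                   # Best. String. Ever.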
diff --git a/tensorflow/contrib/lite/toco/dump_graphviz.cc b/tensorflow/contrib/lite/toco/dump_graphviz.cc
index c289ddcd92..5bb0e3ba4d 100644
--- a/tensorflow/contrib/lite/toco/dump_graphviz.cc
+++ b/tensorflow/contrib/lite/toco/dump_graphviz.cc
@@ -259,6 +259,19 @@ NodeProperties GetPropertiesForOperator(const Operator& op) {
node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
break;
}
+ case OperatorType::kFakeQuant: {
+ const auto& fakequant_op = static_cast<const FakeQuantOperator&>(op);
+ node_properties.color = Color(0xC5, 0x39, 0x29); // Bolder color
+ if (fakequant_op.minmax) {
+ AppendF(&node_properties.label, "\\n%dbit [%g,%g]",
+ fakequant_op.num_bits, fakequant_op.minmax->min,
+ fakequant_op.minmax->max);
+ } else {
+ AppendF(&node_properties.label, "\\n%dbit [?,?]",
+ fakequant_op.num_bits);
+ }
+ break;
+ }
default:
node_properties.color = Color(0xDB, 0x44, 0x37);
break;
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc b/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc
index f098981a5c..c00cdcb944 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/experimental_shuffle_fc_weights.cc
@@ -55,17 +55,26 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) {
// Exit if, based on the known shapes, this FC op is not a GEMV.
// The shuffling of FC weights is only useful to enable fast GEMV paths.
const Shape& input_shape = input_array.shape();
- for (int i = 0; i < input_shape.dimensions_count() - 1; i++) {
+ for (int i = 1; i < input_shape.dimensions_count() - 1; i++) {
if (input_shape.dims(i) != 1) {
// The input activations, shaped as a matrix, have multiple columns.
// This FC op isn't a matrix*vector multiplication.
AddMessageF(
"Not applying experimental shuffling to the weights of %s because "
- "it's not a matrix*vector product",
+ "the input shape is not 1D or 2D (possibly with additional inner "
+ "dimensions of size 1)",
LogName(*op));
return false;
}
}
+ if (input_shape.dims(0) != 1 && input_shape.dims(0) != 4) {
+ AddMessageF(
+ "Not applying experimental shuffling to the weights of %s because "
+ "the input shape's leading dimension, i.e. the 'batch size', is not "
+ "equal to 1 or 4",
+ LogName(*op));
+ return false;
+ }
// Exit if the weights shape isn't an integral multiple of the shuffled
// block shape, 4x16. We don't want to have to write code dealing with
// odd sizes, that would go un-exercised at the moment as the models
@@ -129,6 +138,20 @@ bool ExperimentalShuffleFCWeights::Run(Model* model, std::size_t op_index) {
fc_op->experimental_shuffled_weights = true;
AddMessageF("Applied experimental shuffling to the weights of %s",
LogName(*op));
+ // Add a second output array to this FC op, serving as a workspace to perform
+ // runtime shuffling/xoring of its input activations.
+ CHECK_EQ(fc_op->outputs.size(), 1);
+ const string& shuffled_input_workspace_array_name =
+ AvailableArrayName(*model, fc_op->inputs[0] + "_shuffled");
+ fc_op->outputs.push_back(shuffled_input_workspace_array_name);
+ auto& shuffled_input_workspace_array =
+ model->GetOrCreateArray(shuffled_input_workspace_array_name);
+ shuffled_input_workspace_array.data_type = input_array.data_type;
+ *shuffled_input_workspace_array.mutable_shape() = input_array.shape();
+ shuffled_input_workspace_array.GetOrCreateMinMax() = input_array.GetMinMax();
+ shuffled_input_workspace_array.GetOrCreateQuantizationParams() =
+ input_array.GetQuantizationParams();
+
return true;
}
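As a quick aid, the two input-shape conditions checked above (all dimensions between the leading batch dimension and the innermost one must be 1, and the batch dimension itself must be 1 or 4) can be restated as a small Python predicate; the shapes below are illustrative only:

def fc_input_eligible_for_shuffling(input_shape):
    # Every dimension between the leading (batch) dim and the innermost dim
    # must be 1, i.e. the activations are effectively a vector per batch row.
    if any(dim != 1 for dim in input_shape[1:-1]):
        return False
    # The leading dimension (the batch size) must be 1 or 4.
    return input_shape[0] in (1, 4)

print(fc_input_eligible_for_shuffling([1, 1, 1, 1024]))  # True
print(fc_input_eligible_for_shuffling([2, 1024]))        # False: batch size is 2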
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
index ba244cf5ef..be6e0e07dd 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/propagate_fixed_sizes.cc
@@ -168,7 +168,9 @@ void ProcessConvOperator(Model* model, ConvOperator* op) {
return;
}
const auto& input_shape = input_array.shape();
- CHECK_EQ(input_shape.dimensions_count(), 4);
+ CHECK(input_shape.dimensions_count() == 4)
+ << "Conv ops require 4D inputs. Input array \"" << op->inputs[0]
+ << "\" is " << input_shape.dimensions_count() << "D.";
const auto& weights_array = model->GetArray(op->inputs[1]);
// Yield until weights dims have been resolved.
@@ -249,12 +251,6 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
- CHECK(weights_shape.dims(0) == 1 && weights_shape.dims(3) == 1)
- << "TransposeConv weights dimensions must begin and end with 1. Input "
- "weights \""
- << op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
- << toco::ShapeToString(weights_shape) << ".";
-
// Compute padding
const int kheight = weights_shape.dims(1);
const int kwidth = weights_shape.dims(2);
@@ -269,9 +265,7 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
LOG(FATAL) << "TransposeConv only supports SAME or VALID padding";
}
- // VALIDATE OUTPUT SHAPE
- // Compute the output shape from the input and weights shapes to verify it
- // agrees with the specified output shape.
+ // VALIDATE some dimensions and set the output shape.
const auto& input_array =
model->GetArray(op->inputs[TransposeConvOperator::DATA_INPUT]);
if (!input_array.has_shape()) {
@@ -283,31 +277,13 @@ void ProcessTransposeConvOperator(Model* model, TransposeConvOperator* op) {
<< "TransposeConv input shape must have 4 dimensions. Input \""
<< op->inputs[TransposeConvOperator::WEIGHTS] << "\" had shape "
<< toco::ShapeToString(weights_shape) << ".";
+ CHECK_EQ(input_shape.dims(3), weights_shape.dims(0))
+ << "Input shape depth and weight depth do not agree";
- // Compute output shape
- const int input_width = input_shape.dims(2);
- const int input_height = input_shape.dims(1);
- int output_height = op->stride_height * (input_height - 1);
- int output_width = op->stride_width * (input_width - 1);
- if (op->padding.type == PaddingType::kValid) {
- output_height += kheight;
- output_width += kwidth;
- } else if (op->padding.type == PaddingType::kSame) {
- output_height += 1;
- output_width += 1;
- }
-
- CHECK(specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data ==
- std::vector<int32>({input_shape.dims(0), output_height, output_width,
- weights_shape.dims(3)}))
- << "Specified output shape: " << ShapeToString(output_array.shape())
- << ", does not agree with shape computed from input data and weights: ["
- << input_shape.dims(0) << ", " << output_height << ", " << output_width
- << ", " << weights_shape.dims(3) << "].";
-
- // SUCCESS: Set the op's output shape according to the specified output shape.
- *(output_array.mutable_shape()->mutable_dims()) =
+ // Set the output shape according to the specified output shape.
+ std::vector<int32> const& specified_output_shape =
specified_output_shape_array.GetBuffer<ArrayDataType::kInt32>().data;
+ *(output_array.mutable_shape()->mutable_dims()) = specified_output_shape;
}
void ProcessDepthwiseConvOperator(Model* model, DepthwiseConvOperator* op) {
@@ -1179,6 +1155,11 @@ void ProcessRankOperator(Model* model, RankOperator* op) {
return;
}
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
@@ -1200,6 +1181,11 @@ void ProcessShapeOperator(Model* model, TensorFlowShapeOperator* op) {
return;
}
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return;
+ }
+
const auto& input_array = model->GetArray(op->inputs[0]);
if (!input_array.has_shape()) {
// Yield until input dims have been resolved.
@@ -1230,10 +1216,6 @@ void ProcessStackOperator(Model* model, StackOperator* op) {
}
Shape shape = input_array.shape();
- if (shape.dimensions_count() == 0) {
- // Convert 0D scalars to 1D scalars of shape {1}.
- shape.mutable_dims()->push_back(1);
- }
if (!stacked_shape) {
stacked_shape.reset(new Shape(shape));
} else {
@@ -1519,7 +1501,7 @@ void ProcessArgMaxOperator(Model* model, ArgMaxOperator* op) {
const std::vector<int>& input_dims = input_array.shape().dims();
std::vector<int> output_dims;
- output_dims.reserve(input_dims.size() - 1);
+ output_dims.reserve(input_dims.size());
for (int i = 0; i < input_dims.size() - 1; ++i) {
output_dims.push_back(input_dims[i]);
}
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
index 5e779f6765..6e78653fad 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_binary.cc
@@ -233,7 +233,12 @@ bool ResolveConstantBinaryOperator::Run(Model* model, std::size_t op_index) {
}
// Check that input data types agree.
- CHECK(input0_array.data_type == input1_array.data_type);
+ CHECK(input0_array.data_type == input1_array.data_type)
+ << "Dissimilar data types given to op outputting \""
+ << binary_op->outputs[0] << "\". 0:\"" << binary_op->inputs[0] << "\"("
+ << static_cast<int>(input0_array.data_type) << ") 1:\""
+ << binary_op->inputs[1] << "\"("
+ << static_cast<int>(input1_array.data_type) << ").";
// Do the actual constants propagation
EvaluateBinaryOperatorOnConstantInputs(model, binary_op);
diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
index 37beb41dfc..4bb1217828 100644
--- a/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
+++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_multiply_by_zero.cc
@@ -60,6 +60,11 @@ bool ResolveMultiplyByZero::Run(Model* model, std::size_t op_index) {
const auto& output_array_name = mul_op->outputs[0];
auto& output_array = model->GetArray(output_array_name);
+ if (output_array.data_type == ArrayDataType::kNone) {
+ // Yield until the output type has been set by PropagateArrayDataTypes
+ return false;
+ }
+
// Yield if the output shape is not known yet.
if (!output_array.has_shape()) {
return false;
diff --git a/tensorflow/contrib/lite/toco/import_tensorflow.cc b/tensorflow/contrib/lite/toco/import_tensorflow.cc
index 155d890c9f..2ed05cb372 100644
--- a/tensorflow/contrib/lite/toco/import_tensorflow.cc
+++ b/tensorflow/contrib/lite/toco/import_tensorflow.cc
@@ -1093,8 +1093,10 @@ void ConvertMatMulOperator(const NodeDef& node,
// Transpose flags should be easy to support, but we don't have a
// GraphDef with them to test on at the moment.
- CHECK_EQ(GetBoolAttr(node, "transpose_a"), false);
- CHECK_EQ(GetBoolAttr(node, "transpose_b"), false);
+ CHECK_EQ(HasAttr(node, "transpose_a") && GetBoolAttr(node, "transpose_a"),
+ false);
+ CHECK_EQ(HasAttr(node, "transpose_b") && GetBoolAttr(node, "transpose_b"),
+ false);
CHECK(!HasAttr(node, "adjoint_a") ||
(GetBoolAttr(node, "adjoint_a") == false));
CHECK(!HasAttr(node, "adjoint_b") ||
@@ -1300,11 +1302,17 @@ void ConvertStridedSliceOperator(const NodeDef& node,
}
op->outputs.push_back(node.name());
- op->begin_mask = GetIntAttr(node, "begin_mask");
- op->ellipsis_mask = GetIntAttr(node, "ellipsis_mask");
- op->end_mask = GetIntAttr(node, "end_mask");
- op->new_axis_mask = GetIntAttr(node, "new_axis_mask");
- op->shrink_axis_mask = GetIntAttr(node, "shrink_axis_mask");
+ op->begin_mask =
+ HasAttr(node, "begin_mask") ? GetIntAttr(node, "begin_mask") : 0;
+ op->ellipsis_mask =
+ HasAttr(node, "ellipsis_mask") ? GetIntAttr(node, "ellipsis_mask") : 0;
+ op->end_mask = HasAttr(node, "end_mask") ? GetIntAttr(node, "end_mask") : 0;
+ op->new_axis_mask =
+ HasAttr(node, "new_axis_mask") ? GetIntAttr(node, "new_axis_mask") : 0;
+ op->shrink_axis_mask = HasAttr(node, "shrink_axis_mask")
+ ? GetIntAttr(node, "shrink_axis_mask")
+ : 0;
+
model->operators.emplace_back(op);
}
@@ -1394,8 +1402,11 @@ void ConvertArgMaxOperator(const NodeDef& node,
Model* model) {
CHECK_EQ(node.op(), "ArgMax");
CheckInputsCount(node, tf_import_flags, 2);
- const auto axis_data_type = GetDataTypeAttr(node, "Tidx");
- const auto output_type = GetDataTypeAttr(node, "output_type");
+ const auto axis_data_type =
+ HasAttr(node, "Tidx") ? GetDataTypeAttr(node, "Tidx") : DT_INT32;
+ const auto output_type = HasAttr(node, "output_type")
+ ? GetDataTypeAttr(node, "output_type")
+ : DT_INT64;
CHECK(axis_data_type == DT_INT64 || axis_data_type == DT_INT32);
CHECK(output_type == DT_INT64 || output_type == DT_INT32);
auto* op = new ArgMaxOperator;
@@ -1772,7 +1783,7 @@ void ConvertStackOperator(const NodeDef& node,
op->inputs.push_back(node.input(i));
}
// Both "Stack" and "Pack" have the "axis" attribute.
- op->axis = GetIntAttr(node, "axis");
+ op->axis = HasAttr(node, "axis") ? GetIntAttr(node, "axis") : 0;
op->outputs.push_back(node.name());
model->operators.emplace_back(op);
}
diff --git a/tensorflow/contrib/lite/toco/tflite/BUILD b/tensorflow/contrib/lite/toco/tflite/BUILD
index e0191801a0..e1025c6664 100644
--- a/tensorflow/contrib/lite/toco/tflite/BUILD
+++ b/tensorflow/contrib/lite/toco/tflite/BUILD
@@ -54,6 +54,7 @@ cc_library(
"types.h",
],
deps = [
+ "//tensorflow/contrib/lite:string_util",
"//tensorflow/contrib/lite/schema:schema_fbs",
"//tensorflow/contrib/lite/toco:model",
],
diff --git a/tensorflow/contrib/lite/toco/tflite/types.cc b/tensorflow/contrib/lite/toco/tflite/types.cc
index 0afd2f3df5..c9c2e9ba01 100644
--- a/tensorflow/contrib/lite/toco/tflite/types.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types.cc
@@ -13,12 +13,29 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/contrib/lite/toco/tflite/types.h"
+#include "tensorflow/contrib/lite/string_util.h"
namespace toco {
namespace tflite {
namespace {
+
+DataBuffer::FlatBufferOffset CopyStringToBuffer(
+ const Array& array, flatbuffers::FlatBufferBuilder* builder) {
+ const auto& src_data = array.GetBuffer<ArrayDataType::kString>().data;
+ ::tflite::DynamicBuffer dyn_buffer;
+ for (const string& str : src_data) {
+ dyn_buffer.AddString(str.c_str(), str.length());
+ }
+ char* tensor_buffer;
+ int bytes = dyn_buffer.WriteToBuffer(&tensor_buffer);
+ std::vector<uint8_t> dst_data(bytes);
+ memcpy(dst_data.data(), tensor_buffer, bytes);
+ free(tensor_buffer);
+ return builder->CreateVector(dst_data.data(), bytes);
+}
+
template <ArrayDataType T>
DataBuffer::FlatBufferOffset CopyBuffer(
const Array& array, flatbuffers::FlatBufferBuilder* builder) {
@@ -29,6 +46,18 @@ DataBuffer::FlatBufferOffset CopyBuffer(
return builder->CreateVector(dst_data, size);
}
+void CopyStringFromBuffer(const ::tflite::Buffer& buffer, Array* array) {
+ auto* src_data = reinterpret_cast<const char*>(buffer.data()->data());
+ std::vector<string>* dst_data =
+ &array->GetMutableBuffer<ArrayDataType::kString>().data;
+ int32_t num_strings = ::tflite::GetStringCount(src_data);
+ for (int i = 0; i < num_strings; i++) {
+ ::tflite::StringRef str_ref = ::tflite::GetString(src_data, i);
+ string this_str(str_ref.str, str_ref.len);
+ dst_data->push_back(this_str);
+ }
+}
+
template <ArrayDataType T>
void CopyBuffer(const ::tflite::Buffer& buffer, Array* array) {
using NativeT = ::toco::DataType<T>;
@@ -93,7 +122,7 @@ flatbuffers::Offset<flatbuffers::Vector<uint8_t>> DataBuffer::Serialize(
case ArrayDataType::kInt64:
return CopyBuffer<ArrayDataType::kInt64>(array, builder);
case ArrayDataType::kString:
- return CopyBuffer<ArrayDataType::kString>(array, builder);
+ return CopyStringToBuffer(array, builder);
case ArrayDataType::kUint8:
return CopyBuffer<ArrayDataType::kUint8>(array, builder);
default:
@@ -114,7 +143,7 @@ void DataBuffer::Deserialize(const ::tflite::Tensor& tensor,
case ::tflite::TensorType_INT64:
return CopyBuffer<ArrayDataType::kInt64>(buffer, array);
case ::tflite::TensorType_STRING:
- return CopyBuffer<ArrayDataType::kString>(buffer, array);
+ return CopyStringFromBuffer(buffer, array);
case ::tflite::TensorType_UINT8:
return CopyBuffer<ArrayDataType::kUint8>(buffer, array);
default:
diff --git a/tensorflow/contrib/lite/toco/tflite/types_test.cc b/tensorflow/contrib/lite/toco/tflite/types_test.cc
index a040fe1358..29fb0b2af2 100644
--- a/tensorflow/contrib/lite/toco/tflite/types_test.cc
+++ b/tensorflow/contrib/lite/toco/tflite/types_test.cc
@@ -151,6 +151,13 @@ TEST(DataBuffer, Int32) {
::testing::ElementsAre(1, 1 << 30));
}
+TEST(DataBuffer, String) {
+ Array recovered = ToFlatBufferAndBack<ArrayDataType::kString>(
+ {"AA", "BBB", "Best. String. Ever."});
+ EXPECT_THAT(recovered.GetBuffer<ArrayDataType::kString>().data,
+ ::testing::ElementsAre("AA", "BBB", "Best. String. Ever."));
+}
+
TEST(Padding, All) {
EXPECT_EQ(::tflite::Padding_SAME, Padding::Serialize(PaddingType::kSame));
EXPECT_EQ(PaddingType::kSame, Padding::Deserialize(::tflite::Padding_SAME));
diff --git a/tensorflow/contrib/lite/toco/tooling_util.cc b/tensorflow/contrib/lite/toco/tooling_util.cc
index cf2cbeedc7..5a341294db 100644
--- a/tensorflow/contrib/lite/toco/tooling_util.cc
+++ b/tensorflow/contrib/lite/toco/tooling_util.cc
@@ -1405,20 +1405,7 @@ void ResolveModelFlags(const ModelFlags& model_flags, Model* model) {
}
input_minmax.min = (qmin - mean_value) / std_value;
input_minmax.max = (qmax - mean_value) / std_value;
- if (input_array.minmax) {
- if (input_array_proto.has_mean_value() ||
- input_array_proto.has_std_value()) {
- const double width = input_minmax.max - input_minmax.min;
- const double kMinMaxAllowedDiff = 1e-6 * width;
- CHECK(std::abs(input_minmax.min - input_array.minmax->min) <
- kMinMaxAllowedDiff &&
- std::abs(input_minmax.max - input_array.minmax->max) <
- kMinMaxAllowedDiff)
- << input_minmax.min << ", " << input_minmax.max
- << " != " << input_array.minmax->min << ", "
- << input_array.minmax->max;
- }
- } else {
+ if (!input_array.minmax) {
input_array.GetOrCreateMinMax() = input_minmax;
}
}
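The min/max derived from the mean/std flags in the hunk above follow input_min = (qmin - mean_value) / std_value and input_max = (qmax - mean_value) / std_value. A quick numeric check in Python, assuming the usual uint8 range qmin = 0 and qmax = 255 (the actual qmin/qmax are set outside this hunk):

qmin, qmax = 0.0, 255.0            # assumed uint8 quantization range
mean_value, std_value = 128.0, 127.5

input_min = (qmin - mean_value) / std_value
input_max = (qmax - mean_value) / std_value
print(input_min, input_max)        # roughly -1.0039 and 0.9961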
diff --git a/tensorflow/contrib/lookup/lookup_ops_test.py b/tensorflow/contrib/lookup/lookup_ops_test.py
index f681b7b132..5d4682ec9f 100644
--- a/tensorflow/contrib/lookup/lookup_ops_test.py
+++ b/tensorflow/contrib/lookup/lookup_ops_test.py
@@ -58,6 +58,12 @@ class HashTableOpTest(test.TestCase):
result = output.eval()
self.assertAllEqual([0, 1, -1], result)
+ exported_keys_tensor, exported_values_tensor = table.export()
+
+ self.assertItemsEqual([b"brain", b"salad", b"surgery"],
+ exported_keys_tensor.eval())
+ self.assertItemsEqual([0, 1, 2], exported_values_tensor.eval())
+
def testHashTableFindHighRank(self):
with self.test_session():
default_val = -1
diff --git a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
index f37a2593e2..c35e60a554 100644
--- a/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
+++ b/tensorflow/contrib/meta_graph_transform/meta_graph_transform.py
@@ -13,7 +13,10 @@
# limitations under the License.
# ==============================================================================
-"""Apply graph_transforms tool to MetaGraphDefs."""
+"""Apply graph_transforms tool to MetaGraphDefs.
+
+@@meta_graph_transform
+"""
from __future__ import absolute_import
from __future__ import division
diff --git a/tensorflow/contrib/metrics/BUILD b/tensorflow/contrib/metrics/BUILD
index 5ca42f41c1..e050f3c8d4 100644
--- a/tensorflow/contrib/metrics/BUILD
+++ b/tensorflow/contrib/metrics/BUILD
@@ -77,7 +77,7 @@ py_test(
py_test(
name = "metric_ops_test",
srcs = ["python/ops/metric_ops_test.py"],
- shard_count = 3,
+ shard_count = 8,
srcs_version = "PY2AND3",
tags = ["noasan"], # times out b/63678675
deps = [
diff --git a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
index 8dca90a1e3..ed22ee667f 100644
--- a/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
+++ b/tensorflow/contrib/mpi_collectives/kernels/mpi_ops.cc
@@ -73,7 +73,7 @@ limitations under the License.
*/
template <class T>
-using StatusOr = perftools::gputools::port::StatusOr<T>;
+using StatusOr = se::port::StatusOr<T>;
using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;
diff --git a/tensorflow/contrib/mpi_collectives/mpi_ops.cc b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
new file mode 100644
index 0000000000..475297ca92
--- /dev/null
+++ b/tensorflow/contrib/mpi_collectives/mpi_ops.cc
@@ -0,0 +1,1236 @@
+/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifdef TENSORFLOW_USE_MPI
+
+#include <queue>
+#include <thread>
+#include <unordered_map>
+
+#include "tensorflow/core/framework/op.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/shape_inference.h"
+#include "tensorflow/core/framework/types.pb.h"
+#include "tensorflow/core/platform/mutex.h"
+
+#define EIGEN_USE_THREADS
+
+#if GOOGLE_CUDA
+#include <cuda_runtime.h>
+#include "tensorflow/stream_executor/stream.h"
+#endif
+
+#include "tensorflow/stream_executor/lib/statusor.h"
+
+#define OMPI_SKIP_MPICXX
+#include "third_party/mpi/mpi.h"
+#include "tensorflow/contrib/mpi_collectives/mpi_message.pb.h"
+#include "tensorflow/contrib/mpi_collectives/ring.h"
+
+/*
+ * MPI Allreduce and Allgather Ops for TensorFlow.
+ *
+ * TensorFlow natively provides inter-device communication through send and
+ * receive ops and inter-node communication through Distributed TensorFlow,
+ * based on the same send and receive abstractions. These end up being
+ * insufficient for synchronous data-parallel training on HPC clusters where
+ * Infiniband or other high-speed interconnects are available. This module
+ * implements MPI ops for allgather and allreduce, which do bandwidth-optimal
+ * gathers and reductions and can take advantage of hardware-optimized
+ * communication libraries through the MPI implementation.
+ *
+ * The primary logic of the allreduce and allgather are in RingAllgather() and
+ * RingAllreduce(). The background thread which facilitates MPI operations is
+ * run in BackgroundThreadLoop(). The provided MPI ops are:
+ * – MPIInit:
+ * Initialize MPI on a given device (CPU or GPU).
+ * Should only be run on a single device in every process.
+ * – MPISize:
+ * Get the number of MPI processes in the global communicator.
+ * – MPIRank:
+ * Get the rank of the current MPI process in the global communicator.
+ * – MPILocalRank:
+ * Get the local rank of the current MPI process within its node.
+ * – MPIAllreduce:
+ * Perform an allreduce on a Tensor, returning the sum
+ * across all MPI processes in the global communicator.
+ * – MPIAllgather:
+ * Perform an allgather on a Tensor, returning the concatenation of
+ * the tensor on the first dimension across all MPI processes in the
+ * global communicator.
+ *
+ */
+
+template <class T>
+using StatusOr = se::port::StatusOr<T>;
+
+using CPUDevice = Eigen::ThreadPoolDevice;
+using GPUDevice = Eigen::GpuDevice;
+
+namespace tensorflow {
+namespace contrib {
+namespace mpi {
+
+// Make sure template specializations are generated in the ring.cu.cc and the
+// ring.cc file, not in this file.
+extern template Status RingAllreduce<GPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<GPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<GPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<GPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<GPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllreduce<CPUDevice, int>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllreduce<CPUDevice, long long>(OpKernelContext*,
+ const Tensor*,
+ Tensor*, Tensor*);
+extern template Status RingAllreduce<CPUDevice, float>(OpKernelContext*,
+ const Tensor*, Tensor*,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, int>(OpKernelContext*,
+ const Tensor*,
+ const std::vector<size_t>&,
+ Tensor*);
+extern template Status RingAllgather<CPUDevice, long long>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+extern template Status RingAllgather<CPUDevice, float>(
+ OpKernelContext*, const Tensor*, const std::vector<size_t>&, Tensor*);
+
+namespace {
+
+// Return true if the templated type is GPUDevice, otherwise false.
+template <typename T>
+bool IsGPUDevice();
+template <>
+bool IsGPUDevice<GPUDevice>() {
+ return true;
+};
+template <>
+bool IsGPUDevice<CPUDevice>() {
+ return false;
+};
+
+// A callback to call after the MPI communication completes. Since the
+// allreduce and allgather ops are asynchronous, this callback is what resumes
+// computation after the reduction is completed.
+typedef std::function<void(StatusOr<Tensor>)> CommunicationDoneCallback;
+
+struct CollectiveOpRecord {
+ // The rank performing this piece of the op
+ int rank;
+
+ // The name of the op/tensor to be reduced
+ std::string name;
+
+ // The op's kernel context
+ OpKernelContext* context;
+
+ // Data type of the op
+ DataType dtype;
+
+ // The input tensor
+ const Tensor* in_t;
+
+ // Allgather: Vector of per-rank first-dimension sizes
+ std::vector<size_t> sizes_vec;
+
+ // The temp tensor for intermediate results
+ Tensor temp_t;
+
+ // The output tensor
+ Tensor* out_t;
+
+ // Whether to run this op on the gpu
+ bool on_gpu;
+
+ // The callback to call after the op has completed
+ CommunicationDoneCallback callback;
+};
+
+// Table storing Tensors to be reduced, keyed by unique name.
+// This table contains everything necessary to do the reduction
+typedef std::unordered_map<std::string, CollectiveOpRecord> TensorTable;
+
+// Table for storing Tensor metadata on rank zero. This is used for error
+// checking and size calculations, as well as determining when a reduction is
+// ready to be done (when all nodes are ready to do it).
+typedef std::unordered_map<std::string, std::vector<MPIRequest> > MessageTable;
+
+// The global state required for the MPI ops.
+//
+// MPI is a library that stores a lot of global per-program state and often
+// requires running on a single thread. As a result, we have to have a single
+// background thread responsible for all MPI operations, and communicate with
+// that background thread through global state.
+struct MPIGlobalState {
+ // An atomic boolean which is set to true when MPI is initialized.
+ // This ensures that MPI_Init is never called twice.
+ std::atomic_flag initialized_flag = ATOMIC_FLAG_INIT;
+
+ // Condition variable to wait for initialization
+ condition_variable cv;
+
+ // Whether MPI_Init has been completed on the background thread.
+ bool initialization_done = false;
+
+ // Whether MPI_Init succeeded on the background thread.
+ Status init_status;
+
+ // A mutex that needs to be used whenever MPI operations touch
+ // shared structures.
+ mutex mu;
+
+ // Tensors waiting to be allreduced or allgathered.
+ TensorTable tensor_table;
+
+ // Queue of MPI requests waiting to be sent to the coordinator node.
+ std::queue<MPIRequest> message_queue;
+
+ // Background thread running MPI communication.
+ std::thread background_thread;
+
+ // Whether the background thread should shutdown.
+ bool shut_down = false;
+
+ // Only exists on the coordinator node (rank zero). Maintains a count of
+ // how many nodes are ready to allreduce every tensor (keyed by tensor
+ // name).
+ std::unique_ptr<MessageTable> message_table;
+
+ // The MPI rank, local rank, and size.
+ int rank = 0;
+ int local_rank = 0;
+ int size = 1;
+
+ // The device that MPI was initialized on. (-1 for no GPU)
+ int device = -1;
+
+ // The CUDA stream used for data transfers and within-allreduce operations.
+ // A naive implementation would use the TensorFlow StreamExecutor CUDA
+ // stream. However, the allreduce and allgather require doing memory copies
+ // and kernel executions (for accumulation of values on the GPU), and the
+ // subsequent operations must wait for those operations to complete;
+ // otherwise MPI (which uses its own stream internally) will begin the data
+ // transfers before the CUDA calls are complete. In order to wait for those
+ // CUDA operations, if we were using the TensorFlow stream, we would have
+ // to synchronize that stream; however, other TensorFlow threads may be
+ // submitting more work to that stream, so synchronizing on it can cause
+ // the allreduce to be delayed, waiting for compute totally unrelated to it
+ // in other parts of the graph. Overlaying memory transfers and compute
+ // during backpropagation is crucial for good performance, so we cannot use
+ // the TensorFlow stream, and must use our own stream.
+#if GOOGLE_CUDA
+ cudaStream_t stream;
+ std::atomic_flag stream_created_flag = ATOMIC_FLAG_INIT;
+#endif
+
+ ~MPIGlobalState() {
+ // Make sure that the destructor of the background thread is safe to
+ // call. If a thread is still joinable (not detached or complete) its
+ // destructor cannot be called.
+ if (background_thread.joinable()) {
+ shut_down = true;
+ background_thread.join();
+ }
+ }
+};
+
+// All the MPI state that must be stored globally per-process.
+static MPIGlobalState mpi_global;
+
+// For clarity in argument lists.
+#define RANK_ZERO 0
+
+// A tag used for all coordinator messaging.
+#define TAG_NOTIFY 1
+
+// Store the MPIRequest for a name, and return whether the total count of
+// MPIRequests for that tensor is now equal to the MPI size (and thus we are
+// ready to reduce the tensor).
+bool IncrementTensorCount(std::unique_ptr<MessageTable>& message_table,
+ MPIRequest msg, int mpi_size) {
+ auto name = msg.tensor_name();
+ auto table_iter = message_table->find(name);
+ if (table_iter == message_table->end()) {
+ message_table->emplace(name, std::vector<MPIRequest>({msg}));
+ table_iter = message_table->find(name);
+ } else {
+ table_iter->second.push_back(msg);
+ }
+
+ int count = table_iter->second.size();
+ return count == mpi_size;
+}
+
+// Once a tensor is ready to be reduced, the coordinator sends an MPIResponse
+// to all ranks instructing them to start the reduction. The MPIResponse
+// also contains error messages in case the submitted MPIRequests were not
+// valid (for example, contained mismatched shapes or types).
+//
+// Constructing the MPIResponse, thus, requires a whole lot of error checking.
+MPIResponse ConstructMPIResponse(std::unique_ptr<MessageTable>& message_table,
+ std::string name) {
+ bool error = false;
+ auto it = message_table->find(name);
+ assert(it != message_table->end());
+
+ std::vector<MPIRequest> requests = it->second;
+ assert(requests.size() > 0);
+
+ std::ostringstream error_message_stream;
+
+ // Check that all data types being reduced or gathered are identical
+ auto data_type = requests[0].tensor_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ auto request_type = requests[i].tensor_type();
+ if (data_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched data types: One rank had type "
+ << DataType_Name(data_type)
+ << ", but another rank had type "
+ << DataType_Name(request_type) << ".";
+ break;
+ }
+ }
+
+ // Check that all requested operations are the same
+ auto message_type = requests[0].request_type();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ auto request_type = requests[i].request_type();
+ if (message_type != request_type) {
+ error = true;
+ error_message_stream << "Mismatched MPI operations: One rank did an "
+ << message_type << ", but another rank did an "
+ << request_type << ".";
+ break;
+ }
+ }
+
+ // If we are doing an allreduce, check that all tensor shapes
+ // are identical
+ if (message_type == MPIRequest::ALLREDUCE) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape != request_shape) {
+ error = true;
+ error_message_stream << "Mismatched allreduce tensor shapes: "
+ << "One rank reduced a tensor of shape "
+ << tensor_shape.DebugString()
+ << ", but another rank sent a tensor of shape "
+ << request_shape.DebugString() << ".";
+ break;
+ }
+ }
+ }
+
+ // If we are doing an allgather, make sure all but the first dimension are
+ // the same. The first dimension may be different and the output tensor is
+ // the sum of the first dimension. Collect the sizes by rank.
+ if (message_type == MPIRequest::ALLGATHER) {
+ TensorShape tensor_shape = requests[0].tensor_shape();
+
+ if (tensor_shape.dims() == 0) {
+ error = true;
+ error_message_stream << "Rank zero tried to gather a rank-zero tensor.";
+ }
+
+ for (unsigned int i = 1; i < requests.size(); i++) {
+ if (error) {
+ break;
+ }
+
+ TensorShape request_shape = requests[i].tensor_shape();
+ if (tensor_shape.dims() != request_shape.dims()) {
+ error = true;
+ error_message_stream << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor of rank "
+ << tensor_shape.dims()
+ << ", but another rank sent a tensor of rank "
+ << request_shape.dims() << ".";
+ break;
+ }
+
+ for (unsigned int dim = 1; dim < tensor_shape.dims(); dim++) {
+ if (tensor_shape.dim_size(dim) != request_shape.dim_size(dim)) {
+ error = true;
+ error_message_stream
+ << "Mismatched allgather tensor shapes: "
+ << "One rank gathered a tensor with dimension " << dim
+ << " equal to " << tensor_shape.dim_size(dim)
+ << ", but another rank sent a tensor with dimension " << dim
+ << " equal to " << request_shape.dim_size(dim) << ".";
+ break;
+ }
+ }
+ }
+ }
+
+ MPIResponse response;
+ response.set_tensor_name(name);
+ if (error) {
+ std::string error_message = error_message_stream.str();
+ response.set_response_type(MPIResponse::ERROR);
+ response.set_error_message(error_message);
+ } else {
+ auto response_type = MPIResponse::ERROR;
+ if (message_type == MPIRequest::ALLREDUCE) {
+ response_type = MPIResponse::ALLREDUCE;
+ } else {
+ response_type = MPIResponse::ALLGATHER;
+ }
+ response.set_response_type(response_type);
+ }
+
+ // Clear all queued up requests for this name. They are now taken care of
+ // by the constructed MPI response.
+ message_table->erase(it);
+
+ return response;
+}
+
+// Process an MPIResponse by doing a reduction, a gather, or raising an error.
+void PerformCollectiveOp(TensorTable& tensor_table, MPIResponse response) {
+ OpKernelContext* context;
+ const Tensor* input_tensor;
+ std::vector<size_t> sizes_vec;
+ Tensor temp_tensor;
+ Tensor* output_tensor;
+ CommunicationDoneCallback callback;
+ bool on_gpu;
+ {
+ // Lock on the tensor table.
+ mutex_lock guard(mpi_global.mu);
+
+ // We should never fail at finding this key in the tensor table.
+ auto name = response.tensor_name();
+ auto iter = tensor_table.find(name);
+ assert(iter != tensor_table.end());
+
+ assert(response.response_type() == MPIResponse::ALLREDUCE ||
+ response.response_type() == MPIResponse::ALLGATHER ||
+ response.response_type() == MPIResponse::ERROR);
+
+ CollectiveOpRecord record = iter->second;
+ context = record.context;
+ input_tensor = record.in_t;
+ sizes_vec = record.sizes_vec;
+ temp_tensor = record.temp_t;
+ output_tensor = record.out_t;
+ on_gpu = record.on_gpu;
+ callback = record.callback;
+
+ // Clear the tensor table of this tensor and its callbacks; the rest of
+ // this function takes care of it.
+ tensor_table.erase(iter);
+ }
+
+ // Use CPUDevice instead of GPUDevice if no CUDA, to ensure we don't
+ // link to non-existent symbols.
+#if GOOGLE_CUDA
+#define GPU_DEVICE_IF_CUDA GPUDevice
+#else
+#define GPU_DEVICE_IF_CUDA CPUDevice
+#endif
+
+ Status status;
+ auto dtype = input_tensor->dtype();
+ if (response.response_type() == MPIResponse::ALLGATHER) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, float>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, int>(context, input_tensor,
+ sizes_vec, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllgather<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, sizes_vec, output_tensor)
+ : RingAllgather<CPUDevice, long long>(
+ context, input_tensor, sizes_vec, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allgather.");
+ }
+ } else if (response.response_type() == MPIResponse::ALLREDUCE) {
+ if (dtype == DT_FLOAT) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, float>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, float>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT32) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, int>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, int>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else if (dtype == DT_INT64) {
+ status = on_gpu ? RingAllreduce<GPU_DEVICE_IF_CUDA, long long>(
+ context, input_tensor, &temp_tensor, output_tensor)
+ : RingAllreduce<CPUDevice, long long>(
+ context, input_tensor, &temp_tensor, output_tensor);
+ } else {
+ status = errors::Unknown("Invalid tensor type for MPI allreduce.");
+ }
+ } else if (response.response_type() == MPIResponse::ERROR) {
+ status = errors::FailedPrecondition(response.error_message());
+ }
+
+ if (status.ok()) {
+ callback(StatusOr<Tensor>(*output_tensor));
+ } else {
+ callback(StatusOr<Tensor>(status));
+ }
+}
+
+// The MPI background thread loop coordinates all the MPI processes and the
+// tensor reductions. The design of the communicator mechanism is limited by a
+// few considerations:
+//
+// 1. Some MPI implementations require all MPI calls to happen from a
+// single thread. Since TensorFlow may use several threads for graph
+// processing, this means we must have our own dedicated thread for
+// dealing with MPI.
+// 2. We want to gracefully handle errors, when MPI processes do not
+// properly agree upon what should happen (such as mismatched types or
+// shapes). To do so requires the MPI processes to know about the shapes
+// and types of the relevant tensors on the other processes.
+// 3. The MPI reductions and gathers should be able to happen in parallel
+// with other ongoing operations. Since MPI uses an internal
+// (inaccessible) GPU stream separate from the TF GPUDevice streams, we
+// cannot explicitly synchronize memcpys or kernels with it. As a result,
+// MPIAllreduce and MPIAllgather must be AsyncOpKernels to ensure proper
+// ordering of memcpys and kernels with respect to TF streams.
+// 4. NOTE: We cannot guarantee that all the MPI processes reduce their
+// tensors in the same order. Thus, there must be a way to ensure the
+// reduction memcpys and kernels occur for correct tensors across all
+// ranks at the same time. We choose to use a coordinator (rank ID 0) to
+// gather and trigger the reduction operations that are ready to execute.
+//
+// The coordinator currently follows a master-worker paradigm. Rank zero acts
+// as the master (the "coordinator"), whereas all other ranks are simply
+// workers. Each rank runs its own background thread which progresses in ticks.
+// In each tick, the following actions happen:
+//
+// a) The workers send any available MPIRequests to the coordinator. These
+// MPIRequests indicate what the worker would like to do (i.e. which
+// tensor they would like to gather or reduce, as well as their shape and
+// type). They repeat this for every tensor that they would like to
+// operate on after that tensor's collective op has executed ComputeAsync.
+//
+// b) The workers send an empty "DONE" message to the coordinator to
+// indicate that there are no more tensors they wish to operate on.
+//
+// c) The coordinator receives the MPIRequests from the workers, as well
+// as from its own TensorFlow ops, and stores them in a request table. The
+// coordinator continues to receive MPIRequest messages until it has
+// received MPI_SIZE number of empty "DONE" messages.
+//
+// d) The coordinator finds all tensors that are ready to be reduced,
+// gathered, or all operations that result in an error. For each of those,
+// it sends an MPIResponse to all the workers. When no more MPIResponses
+// are available, it sends a "DONE" response to the workers. If the
+//     process is being shut down, it instead sends a "SHUTDOWN" response.
+//
+// e) The workers listen for MPIResponse messages, processing each one by
+// doing the required reduce or gather, until they receive a "DONE"
+// response from the coordinator. At that point, the tick ends.
+// If instead of "DONE" they receive "SHUTDOWN", they exit their
+// background loop.
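
(A minimal, self-contained sketch of the per-tensor bookkeeping behind steps (c)
and (d) above, written in plain Python with illustrative names only; the real
implementation below does this with MPIRequest protos, IncrementTensorCount,
and actual MPI messaging.)

    # Each rank announces the tensors it wants to reduce; a tensor becomes
    # "ready" once every one of `mpi_size` ranks has announced it.
    def ready_tensors(message_table, announcements, mpi_size):
        ready = []
        for rank, tensor_name in announcements:
            ranks = message_table.setdefault(tensor_name, set())
            ranks.add(rank)
            if len(ranks) == mpi_size:
                ready.append(tensor_name)
                del message_table[tensor_name]
        return ready

    table = {}
    print(ready_tensors(table, [(0, "grad_w"), (1, "grad_b")], mpi_size=2))  # []
    print(ready_tensors(table, [(1, "grad_w"), (0, "grad_b")], mpi_size=2))  # ['grad_w', 'grad_b']
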
+// TODO: Use the global mpi_global state variable instead of a local one
+void BackgroundThreadLoop() {
+#if GOOGLE_CUDA
+ // Set the device, so that this thread uses the same GPU context as the
+ // calling thread.
+ // TODO: Ensure that this is operating correctly. The background thread
+ // needs to be able to control all GPUs that the rank has access to, and
+ // might be more than 1 GPU. Tensors could be resident in any of the
+ // GPUs, so the background thread's accumulate and copy kernels might need
+ // to correctly set the device and it might be necessary for the background
+ // thread to manage multiple streams.
+ cudaSetDevice(mpi_global.device);
+ cudaStreamCreate(&mpi_global.stream);
+#endif
+
+ // Initialize MPI. This must happen on the background thread, since not all
+ // MPI implementations support being called from multiple threads.
+ auto init_result = MPI_Init(NULL, NULL);
+ if (init_result != MPI_SUCCESS) {
+ mpi_global.init_status =
+ errors::Unknown("Could not initialize MPI; MPI_Init() failed.");
+ mpi_global.initialization_done = true;
+ mpi_global.cv.notify_all();
+ return;
+ } else {
+ mpi_global.init_status = Status::OK();
+ }
+
+ // Get MPI rank to determine if we are rank zero.
+ int rank;
+ MPI_Comm_rank(MPI_COMM_WORLD, &rank);
+ bool is_coordinator = rank == 0;
+
+ // Get MPI size to determine how many tensors to wait for before reducing.
+ int size;
+ MPI_Comm_size(MPI_COMM_WORLD, &size);
+
+ // Determine local rank by querying the local communicator.
+ MPI_Comm local_comm;
+ MPI_Comm_split_type(MPI_COMM_WORLD, MPI_COMM_TYPE_SHARED, 0, MPI_INFO_NULL,
+ &local_comm);
+ int local_rank;
+ MPI_Comm_rank(local_comm, &local_rank);
+
+ mpi_global.rank = rank;
+ mpi_global.local_rank = local_rank;
+ mpi_global.size = size;
+ mpi_global.initialization_done = true;
+
+ // Notify calling thread that initialization is complete
+ mpi_global.cv.notify_all();
+
+ // TODO: MOVE MESSAGE TABLE INITIALIZATION TO LIBRARY LOAD!
+ // Initialize the tensor count table. No tensors are available yet.
+ if (is_coordinator) {
+ mpi_global.message_table =
+ std::unique_ptr<MessageTable>(new MessageTable());
+ }
+
+ // The coordinator sends a SHUTDOWN message to trigger shutdown.
+ bool should_shut_down = false;
+ do {
+ // TODO: Eliminate the need for thread sleep by making all activity
+ // depend on other activity (e.g. condition or MPI waits).
+ std::this_thread::sleep_for(std::chrono::milliseconds(1));
+
+ // Copy the data structures from global state under this lock.
+ // However, don't keep the lock for the rest of the loop, so that
+ // enqueued stream callbacks can continue.
+ std::queue<MPIRequest> message_queue;
+ {
+ mutex_lock guard(mpi_global.mu);
+ while (!mpi_global.message_queue.empty()) {
+ MPIRequest message = mpi_global.message_queue.front();
+ mpi_global.message_queue.pop();
+ message_queue.push(message);
+ }
+ }
+
+ // Collect all tensors that are ready to be reduced. Record them in the
+ // tensor count table (rank zero) or send them to rank zero to be
+ // recorded (everyone else).
+ std::vector<std::string> ready_to_reduce;
+ while (!message_queue.empty()) {
+      // Pop the first available message
+ MPIRequest message = message_queue.front();
+ message_queue.pop();
+
+ if (is_coordinator) {
+ bool reduce =
+ IncrementTensorCount(mpi_global.message_table, message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(message.tensor_name());
+ }
+ } else {
+ std::string encoded_message;
+ message.SerializeToString(&encoded_message);
+ MPI_Send(encoded_message.c_str(), encoded_message.length() + 1,
+ MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ }
+
+ // Rank zero has put all its own tensors in the tensor count table.
+ // Now, it should count all the tensors that are coming from other
+ // ranks at this tick. It should keep getting tensors until it gets a
+ // DONE message from all the other ranks.
+ if (is_coordinator) {
+ // Count of DONE messages. Keep receiving messages until the number
+ // of messages is equal to the number of processes. Initialize to
+ // one since the coordinator is effectively done.
+ int completed_ranks = 1;
+ while (completed_ranks != size) {
+ MPI_Status status;
+ MPI_Probe(MPI_ANY_SOURCE, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int source_rank = status.MPI_SOURCE;
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+ // If the length is zero, this is a DONE message.
+ if (msg_length == 0) {
+ completed_ranks++;
+ MPI_Recv(NULL, 0, MPI_BYTE, source_rank, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ continue;
+ }
+
+        // Receive the encoded MPIRequest from MPI into an std::string.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, source_rank, TAG_NOTIFY,
+ MPI_COMM_WORLD, &status);
+ std::string received_data(buffer);
+ delete[] buffer;
+
+ MPIRequest received_message;
+ received_message.ParseFromString(received_data);
+ auto received_name = received_message.tensor_name();
+
+ bool reduce = IncrementTensorCount(mpi_global.message_table,
+ received_message, size);
+ if (reduce) {
+ ready_to_reduce.push_back(received_name);
+ }
+ }
+
+ // At this point, rank zero should have a fully updated tensor
+ // count table and should know all the tensors that need to be
+ // reduced or gathered, and everyone else should have sent all
+ // their information to rank zero. We can now do reductions and
+ // gathers; rank zero will choose which ones and in what order,
+ // and will notify the other ranks before doing each reduction.
+ for (int i = 0; i < ready_to_reduce.size(); i++) {
+ // Notify all nodes which tensor we'd like to reduce now
+ auto name = ready_to_reduce[i];
+ MPIResponse response =
+ ConstructMPIResponse(mpi_global.message_table, name);
+
+ std::string encoded_response;
+ response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+
+ // Perform the reduction. All nodes should end up performing
+ // the same reduction.
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+
+ // Notify all nodes that we are done with the reductions for this
+ // tick.
+ MPIResponse done_response;
+ should_shut_down = mpi_global.shut_down;
+ done_response.set_response_type(
+ mpi_global.shut_down ? MPIResponse::SHUTDOWN : MPIResponse::DONE);
+ std::string encoded_response;
+ done_response.SerializeToString(&encoded_response);
+ for (int r = 1; r < size; r++) {
+ MPI_Send(encoded_response.c_str(), encoded_response.length() + 1,
+ MPI_BYTE, r, TAG_NOTIFY, MPI_COMM_WORLD);
+ }
+ } else {
+ // Notify the coordinator that this node is done sending messages.
+ // A DONE message is encoded as a zero-length message.
+ MPI_Send(NULL, 0, MPI_BYTE, RANK_ZERO, TAG_NOTIFY, MPI_COMM_WORLD);
+
+ // Receive names for tensors to reduce from rank zero. Once we
+      // receive an empty DONE message, stop waiting for more names.
+ while (true) {
+ MPI_Status status;
+ MPI_Probe(0, TAG_NOTIFY, MPI_COMM_WORLD, &status);
+
+ // Find number of characters in message (including zero byte).
+ int msg_length;
+ MPI_Get_count(&status, MPI_BYTE, &msg_length);
+
+        // Receive the encoded MPIResponse from MPI into an std::string.
+ char* buffer = new char[msg_length];
+ MPI_Recv(buffer, msg_length, MPI_BYTE, 0, TAG_NOTIFY, MPI_COMM_WORLD,
+ &status);
+ std::string received_message(buffer);
+ delete[] buffer;
+
+ MPIResponse response;
+ response.ParseFromString(received_message);
+ if (response.response_type() == MPIResponse::DONE) {
+ // No more messages this tick
+ break;
+ } else if (response.response_type() == MPIResponse::SHUTDOWN) {
+ // No more messages this tick, and the background thread
+ // should shut down
+ should_shut_down = true;
+ break;
+ } else {
+ // Process the current message
+ PerformCollectiveOp(mpi_global.tensor_table, response);
+ }
+ }
+ }
+ } while (!should_shut_down);
+
+ MPI_Finalize();
+}
+
+// Initialize MPI and start the MPI background thread. Ensure that this is
+// only done once no matter how many times this function is called.
+Status InitializeMPIOnce(bool gpu) {
+ // Ensure MPI is only initialized once.
+ if (mpi_global.initialized_flag.test_and_set()) return mpi_global.init_status;
+
+ mpi_global.device = -1;
+#if GOOGLE_CUDA
+ if (gpu) {
+ cudaGetDevice(&mpi_global.device);
+ }
+#endif
+
+ // Start the MPI background thread, which assumes MPI is initialized
+ // TODO: Change this to a Tensorflow thread
+ mpi_global.background_thread = std::thread(BackgroundThreadLoop);
+
+ // Wait to ensure that the background thread has finished initializing MPI
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.cv.wait(guard);
+ if (!mpi_global.initialization_done) {
+ mpi_global.init_status =
+ errors::Unknown("Failed to wait for MPI initialization.");
+ }
+
+ return mpi_global.init_status;
+}
+
+// Check that MPI is initialized.
+Status IsMPIInitialized() {
+ if (!mpi_global.initialization_done) {
+ return errors::FailedPrecondition(
+ "MPI has not been initialized; use tf.contrib.mpi.Session.");
+ }
+ return Status::OK();
+}
+
+// This function (called from the callback set up in MPIAll*Op::ComputeAsync)
+// only adds the op's record into the local op queue (to track the op's
+// progress), and sends a message to the coordinator indicating that this rank
+// is ready to begin. The MPI background thread will handle the MPI message.
+void EnqueueTensorCollective(CollectiveOpRecord record,
+ MPIRequest::RequestType rtype) {
+ const Tensor* input_tensor = record.in_t;
+ MPIRequest message;
+ message.set_request_rank(record.rank);
+ message.set_tensor_name(record.name);
+ message.set_tensor_type(record.dtype);
+ message.set_request_type(rtype);
+ input_tensor->shape().AsProto(message.mutable_tensor_shape());
+
+ mutex_lock guard(mpi_global.mu);
+ mpi_global.tensor_table.emplace(record.name, record);
+ mpi_global.message_queue.push(message);
+}
+
+} // namespace
+
+#if GOOGLE_CUDA
+cudaStream_t CudaStreamForMPI() { return mpi_global.stream; }
+#endif
+
+// Op to initialize MPI in the current process. The device this op is placed
+// on must also be used for all future MPI ops.
+template <typename Device>
+class MPIInitOp : public OpKernel {
+ public:
+ explicit MPIInitOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ bool on_gpu = IsGPUDevice<Device>();
+ OP_REQUIRES_OK(context, InitializeMPIOnce(on_gpu));
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_CPU),
+ MPIInitOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIInit").Device(DEVICE_GPU),
+ MPIInitOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIInit").Doc(R"doc(
+Initialize MPI for the current process.
+
+If this is run on a GPU, then that GPU must be used for all future MPI
+operations. If it is run on CPU, then all future MPI operations must also
+run on CPU.
+)doc");
+
+// Op to get the current MPI Size.
+template <typename Device>
+class MPISizeOp : public OpKernel {
+ public:
+ explicit MPISizeOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.size;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_CPU),
+ MPISizeOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPISize").Device(DEVICE_GPU).HostMemory("size"),
+ MPISizeOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPISize")
+ .Output("size: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the number of running MPI processes.
+
+More precisely, returns the number of MPI processes in the group associated
+with the MPI_COMM_WORLD communicator.
+
+size: Size of the MPI group.
+)doc");
+
+// Op to get the current MPI Rank.
+template <typename Device>
+class MPIRankOp : public OpKernel {
+ public:
+ explicit MPIRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_CPU),
+ MPIRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPIRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the MPI group.
+
+More precisely, returns the rank of the calling process in the MPI_COMM_WORLD
+communicator.
+
+rank: Rank of the calling process.
+)doc");
+
+// Op to get the current local MPI Rank.
+template <typename Device>
+class MPILocalRankOp : public OpKernel {
+ public:
+ explicit MPILocalRankOp(OpKernelConstruction* context) : OpKernel(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ OP_REQUIRES_OK(context, IsMPIInitialized());
+
+ // Write integer to output tensor
+ Tensor* output;
+ OP_REQUIRES_OK(context,
+ context->allocate_output(0, TensorShape({}), &output));
+
+ auto flat = output->flat<int>();
+ flat(0) = mpi_global.local_rank;
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPILocalRank").Device(DEVICE_CPU),
+ MPILocalRankOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPILocalRank").Device(DEVICE_GPU).HostMemory("rank"),
+ MPILocalRankOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPILocalRank")
+ .Output("rank: int32")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Returns the index of the current process in the node it is on.
+
+More precisely, returns the rank of the calling process in the communicator that
+only spans the MPI processes running on that node.
+
+rank: Rank of the calling process on the node it is on.
+)doc");
+
+template <typename Device>
+class MPIAllreduceOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllreduceOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context,
+ context->allocate_output(0, input_tensor->shape(), &output_tensor),
+ done);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.out_t = output_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+ record.dtype = input_tensor->dtype();
+
+ const size_t temp_size =
+ (input_tensor->NumElements() + mpi_global.size - 1) / mpi_global.size;
+ TensorShape temp_shape;
+ temp_shape.AddDim(temp_size);
+ OP_REQUIRES_OK_ASYNC(context,
+ context->allocate_temp(input_tensor->dtype(),
+ temp_shape, &record.temp_t),
+ done);
+
+ auto allreduce_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allreduce_done_callback;
+
+ auto allreduce_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLREDUCE);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allreduce; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allreduce_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allreduce_launch_callback);
+ }
+#else
+ allreduce_launch_callback();
+#endif
+ }
+};
+
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_CPU),
+ MPIAllreduceOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(Name("MPIAllreduce").Device(DEVICE_GPU),
+ MPIAllreduceOp<GPUDevice>);
+#endif
+
+REGISTER_OP("MPIAllreduce")
+ .Attr("T: {int32, int64, float32}")
+ .Input("tensor: T")
+ .Output("sum: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ c->set_output(0, c->input(0));
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allreduce on a tensor. All other processes that do a reduction
+on a tensor with the same name must have the same shape for that tensor.
+Tensors are reduced with other tensors that have the same node name for the
+allreduce.
+
+Arguments
+ tensor: A tensor to reduce.
+
+Output
+ sum: A tensor with the same shape as `tensor`, summed across all
+ MPI processes.
+)doc");
+
+template <typename Device>
+class MPIAllgatherOp : public AsyncOpKernel {
+ public:
+ explicit MPIAllgatherOp(OpKernelConstruction* context)
+ : AsyncOpKernel(context) {}
+
+ // Although this op is handled asynchronously, the ComputeAsync call is
+ // very inexpensive. It only sets up a CollectiveOpRecord and places it
+ // in the table for the background thread to handle. Thus, we do not need
+ // a TF pool thread to perform the op.
+ bool IsExpensive() override { return false; }
+
+ void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+ OP_REQUIRES_OK_ASYNC(context, IsMPIInitialized(), done);
+ const Tensor* input_tensor = &context->input(0);
+ const Tensor* sizing_tensor = &context->input(1);
+
+ // Record allocated on stack so op can fail without memory leak
+ CollectiveOpRecord record;
+ record.name = name();
+ record.context = context;
+ record.in_t = input_tensor;
+ record.on_gpu = IsGPUDevice<Device>();
+
+ // Construct the output size from the sizing tensor
+ size_t output_first_dim = 0;
+ if (sizing_tensor->shape().dims() == 0) {
+ // 0-dim sizing_tensor implies that the op is just gathering
+ // a single element from each rank
+ output_first_dim = mpi_global.size;
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(1);
+ }
+ } else {
+ // Collect the total output tensor sizing from the sizing tensor
+ // NOTE: The sizing tensor is forced to be placed on the CPU by
+ // declaring the input as HostMemory, so it is valid to read it here.
+ const int64* sizing_array =
+ (const int64*)sizing_tensor->tensor_data().data();
+ for (int i = 0; i < mpi_global.size; i++) {
+ record.sizes_vec.push_back(sizing_array[i]);
+ output_first_dim += sizing_array[i];
+ }
+ }
+
+ TensorShape output_shape;
+ output_shape.AddDim(output_first_dim);
+ for (int i = 1; i < input_tensor->shape().dims(); i++) {
+ output_shape.AddDim(input_tensor->shape().dim_size(i));
+ }
+
+ Tensor* output_tensor;
+ OP_REQUIRES_OK_ASYNC(
+ context, context->allocate_output(0, output_shape, &output_tensor),
+ done);
+
+ record.out_t = output_tensor;
+ record.dtype = input_tensor->dtype();
+
+ auto allgather_done_callback = [done, context](StatusOr<Tensor> status) {
+ context->SetStatus(status.status());
+ done();
+ };
+ record.callback = allgather_done_callback;
+
+ auto allgather_launch_callback = [record] {
+ EnqueueTensorCollective(record, MPIRequest::ALLGATHER);
+ };
+
+ // If we are on a CPU, our device context will be null and we can't
+ // get a stream to enqueue this on. On a CPU this op is called when the
+ // data is already available, so we can just immediately do the
+ // allgather; we don't have to wait for the data to get populated.
+#if GOOGLE_CUDA
+ auto device_context = context->op_device_context();
+ if (device_context == nullptr) {
+ allgather_launch_callback();
+ } else {
+ auto stream = device_context->stream();
+ stream->ThenDoHostCallback(allgather_launch_callback);
+ }
+#else
+ allgather_launch_callback();
+#endif
+ }
+};
+
+REGISTER_OP("MPIAllgather")
+ .Attr("T: {int32, int64, float32}")
+ .Attr("S: {int64}")
+ .Input("tensor: T")
+ .Input("sizes: S")
+ .Output("gathered: T")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle output;
+ TF_RETURN_IF_ERROR(
+ c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &output));
+ c->set_output(0, output);
+ return Status::OK();
+ })
+ .Doc(R"doc(
+Perform an MPI Allgather on a tensor. All other processes that do a gather on a
+tensor with the same name must have the same rank for that tensor, and must
+match its size in every dimension except the first.
+
+Arguments
+ tensor: A tensor to gather.
+ sizes: A tensor containing the first-dimension sizes of tensors to be
+ gathered from other ranks
+
+Output
+ gathered: A tensor with the same shape as `tensor` except for the first
+ dimension, which is the sum of dimensions in `sizes`.
+)doc");
+
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_CPU).HostMemory("sizes"),
+ MPIAllgatherOp<CPUDevice>);
+#if GOOGLE_CUDA
+REGISTER_KERNEL_BUILDER(
+ Name("MPIAllgather").Device(DEVICE_GPU).HostMemory("sizes"),
+ MPIAllgatherOp<GPUDevice>);
+#endif
+
+} // namespace mpi
+} // namespace contrib
+} // namespace tensorflow
+
+#endif // TENSORFLOW_USE_MPI
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index b9b482a698..b1cb89391c 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -24,7 +24,7 @@ limitations under the License.
namespace tensorflow {
-using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+using se::cuda::ScopedActivateExecutorContext;
// Contains data for a single stream used for nccl communication; this includes
// a background thread that calls NcclManager::LoopKernelLaunches.
@@ -37,11 +37,11 @@ struct NcclManager::NcclStream {
cv.notify_all();
}
- perftools::gputools::StreamExecutor* executor = nullptr;
+ se::StreamExecutor* executor = nullptr;
// The stream on which to run the nccl collective.
// This is a different stream than the tensorflow compute stream.
- std::unique_ptr<perftools::gputools::Stream> stream;
+ std::unique_ptr<se::Stream> stream;
// See NcclManager::LoopKernelLaunches for information on these.
std::unique_ptr<Thread> thread;
@@ -95,9 +95,8 @@ ncclDataType_t ToNcclType(DataType t) {
// A participant in a Collective. See <Collective> below.
struct NcclManager::Participant {
Participant(const Tensor* in_t, Tensor* out_t, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
- perftools::gputools::StreamExecutor* executor, int gpu_device_id,
- NcclManager::DoneCallback done_callback)
+ se::Stream* tensor_stream, se::StreamExecutor* executor,
+ int gpu_device_id, NcclManager::DoneCallback done_callback)
: in_t(in_t),
out_t(out_t),
event_mgr(event_mgr),
@@ -121,11 +120,11 @@ struct NcclManager::Participant {
EventMgr* const event_mgr;
// Owned by the caller, who must keep it live until <done_callback> is called.
- perftools::gputools::Stream* const tensor_stream;
+ se::Stream* const tensor_stream;
// Matches the executor in CommunicatorMember::stream. Expected to be live for
// process lifetime.
- perftools::gputools::StreamExecutor* const executor = nullptr;
+ se::StreamExecutor* const executor = nullptr;
const int gpu_device_id;
@@ -245,7 +244,7 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
if (nccl_stream == nullptr) {
nccl_stream = new NcclStream();
nccl_stream->executor = executor;
- nccl_stream->stream.reset(new perftools::gputools::Stream(executor));
+ nccl_stream->stream.reset(new se::Stream(executor));
nccl_stream->stream->Init();
streams.emplace_back(nccl_stream);
@@ -300,10 +299,10 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
void NcclManager::AddToAllReduce(int num_devices, const string& key,
ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
+ se::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
- const Tensor* in_t, Tensor* out_t,
+ se::Stream* tensor_stream, const Tensor* in_t,
+ Tensor* out_t,
const DoneCallback& done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, out_t, event_mgr, tensor_stream, executor,
@@ -312,11 +311,12 @@ void NcclManager::AddToAllReduce(int num_devices, const string& key,
kAllReduce, reduction_op);
}
-void NcclManager::AddBroadcastSend(
- int num_devices, const string& key,
- perftools::gputools::StreamExecutor* executor, int gpu_device_id,
- EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
- const Tensor* in_t, DoneCallback done_callback) {
+void NcclManager::AddBroadcastSend(int num_devices, const string& key,
+ se::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ se::Stream* tensor_stream,
+ const Tensor* in_t,
+ DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(in_t, nullptr /* out_t */, event_mgr, tensor_stream,
executor, gpu_device_id, std::move(done_callback)));
@@ -325,11 +325,11 @@ void NcclManager::AddBroadcastSend(
kBroadcast, ncclSum /* unused */);
}
-void NcclManager::AddBroadcastRecv(
- int num_devices, const string& key,
- perftools::gputools::StreamExecutor* executor, int gpu_device_id,
- EventMgr* event_mgr, perftools::gputools::Stream* tensor_stream,
- Tensor* out_t, DoneCallback done_callback) {
+void NcclManager::AddBroadcastRecv(int num_devices, const string& key,
+ se::StreamExecutor* executor,
+ int gpu_device_id, EventMgr* event_mgr,
+ se::Stream* tensor_stream, Tensor* out_t,
+ DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
new Participant(nullptr /* in_t */, out_t, event_mgr, tensor_stream,
executor, gpu_device_id, std::move(done_callback)));
@@ -339,9 +339,8 @@ void NcclManager::AddBroadcastRecv(
void NcclManager::AddReduceSend(int num_devices, const string& key,
ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
- int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
+ se::StreamExecutor* executor, int gpu_device_id,
+ EventMgr* event_mgr, se::Stream* tensor_stream,
const Tensor* in_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
@@ -353,9 +352,8 @@ void NcclManager::AddReduceSend(int num_devices, const string& key,
void NcclManager::AddReduceRecv(int num_devices, const string& key,
ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
- int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
+ se::StreamExecutor* executor, int gpu_device_id,
+ EventMgr* event_mgr, se::Stream* tensor_stream,
const Tensor* in_t, Tensor* out_t,
DoneCallback done_callback) {
std::unique_ptr<Participant> participant(
@@ -444,7 +442,7 @@ void NcclManager::RunCollective(const string& key, Collective* collective) {
}
void NcclManager::LoopKernelLaunches(NcclStream* nccl_stream) {
- perftools::gputools::Stream* comm_stream = nccl_stream->stream.get();
+ se::Stream* comm_stream = nccl_stream->stream.get();
ScopedActivateExecutorContext scoped_context(nccl_stream->executor);
const cudaStream_t* cu_stream = reinterpret_cast<const cudaStream_t*>(
comm_stream->implementation()->CudaStreamMemberHack());
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.h b/tensorflow/contrib/nccl/kernels/nccl_manager.h
index 6ff8cea84e..57a96c5d33 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.h
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.h
@@ -55,41 +55,34 @@ class NcclManager {
// is also the stream that will use the produced data; <done_callback> is
// not called until the next kernel launched on <stream> would see the data.
void AddToAllReduce(int num_devices, const string& key,
- ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
+ ncclRedOp_t reduction_op, se::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
- const Tensor* in_t, Tensor* out_t,
- const DoneCallback& done_callback);
+ se::Stream* tensor_stream, const Tensor* in_t,
+ Tensor* out_t, const DoneCallback& done_callback);
  // AddBroadcastSend and AddBroadcastRecv combine to send data from one sender
// to all receivers.
void AddBroadcastSend(int num_devices, const string& key,
- perftools::gputools::StreamExecutor* executor,
- int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
+ se::StreamExecutor* executor, int gpu_device_id,
+ EventMgr* event_mgr, se::Stream* tensor_stream,
const Tensor* in_t, DoneCallback done_callback);
void AddBroadcastRecv(int num_devices, const string& key,
- perftools::gputools::StreamExecutor* executor,
- int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
+ se::StreamExecutor* executor, int gpu_device_id,
+ EventMgr* event_mgr, se::Stream* tensor_stream,
Tensor* out_t, DoneCallback done_callback);
  // AddReduceSend and AddReduceRecv combine to send data from all senders
// to one receiver.
void AddReduceSend(int num_devices, const string& key,
- ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
+ ncclRedOp_t reduction_op, se::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
- const Tensor* in_t, DoneCallback done_callback);
+ se::Stream* tensor_stream, const Tensor* in_t,
+ DoneCallback done_callback);
void AddReduceRecv(int num_devices, const string& key,
- ncclRedOp_t reduction_op,
- perftools::gputools::StreamExecutor* executor,
+ ncclRedOp_t reduction_op, se::StreamExecutor* executor,
int gpu_device_id, EventMgr* event_mgr,
- perftools::gputools::Stream* tensor_stream,
- const Tensor* in_t, Tensor* out_t,
- DoneCallback done_callback);
+ se::Stream* tensor_stream, const Tensor* in_t,
+ Tensor* out_t, DoneCallback done_callback);
private:
enum CollectiveType {
@@ -123,8 +116,7 @@ class NcclManager {
// Maps a device to the communication streams that make up its collective.
// This is used to share the stream across different communicators that
// include the same device.
- std::map<perftools::gputools::StreamExecutor*,
- std::vector<std::unique_ptr<NcclStream>>>
+ std::map<se::StreamExecutor*, std::vector<std::unique_ptr<NcclStream>>>
device_to_comm_streams_ GUARDED_BY(mu_);
std::vector<std::unique_ptr<Communicator>> communicators_;
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index 06ca65e33a..4d8d922cb4 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -175,11 +175,9 @@ class NcclManagerTest : public ::testing::Test {
nullptr /* step_resource_manager */);
}
- static perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
- const Scalar* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(
- const_cast<Scalar*>(cuda_memory));
- perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+ static se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
+ se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
diff --git a/tensorflow/contrib/opt/BUILD b/tensorflow/contrib/opt/BUILD
index 612ecc3e63..13aa1d7e7a 100644
--- a/tensorflow/contrib/opt/BUILD
+++ b/tensorflow/contrib/opt/BUILD
@@ -25,6 +25,7 @@ py_library(
"python/training/multitask_optimizer_wrapper.py",
"python/training/nadam_optimizer.py",
"python/training/powersign.py",
+ "python/training/reg_adagrad_optimizer.py",
"python/training/sign_decay.py",
"python/training/variable_clipping_optimizer.py",
],
@@ -156,6 +157,25 @@ py_test(
)
py_test(
+ name = "reg_adagrad_optimizer_test",
+ srcs = ["python/training/reg_adagrad_optimizer_test.py"],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":opt_py",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:embedding_ops",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:resource_variable_ops",
+ "//tensorflow/python:variable_scope",
+ "//tensorflow/python:variables",
+ "//third_party/py/numpy",
+ ],
+)
+
+py_test(
name = "nadam_optimizer_test",
srcs = ["python/training/nadam_optimizer_test.py"],
srcs_version = "PY2AND3",
diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer.py
new file mode 100644
index 0000000000..d0e0405a2c
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer.py
@@ -0,0 +1,107 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""RegAdagrad for TensorFlow."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.ops import math_ops
+from tensorflow.python.training import adagrad
+from tensorflow.python.training import training_ops
+from tensorflow.python.util import tf_contextlib
+
+
+class RegAdagradOptimizer(adagrad.AdagradOptimizer):
+ """RegAdagrad: Adagrad with updates that optionally skip updating the slots.
+
+ This is meant to address the problem of additional regularization terms in the
+ loss function affecting learning rate decay and causing hyper-param
+ entanglement. Example usage:
+
+    loss = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=x)
+ reg_loss = reg_strength * tf.reduce_sum(x * x)
+ opt = tf.contrib.opt.RegAdagradOptimizer(learning_rate)
+ loss_update = opt.minimize(loss)
+ with opt.avoid_updating_slots():
+ reg_update = opt.minimize(reg_loss)
+ total_update = tf.group([loss_update, reg_update])
+
+ # ...
+
+ sess.run(total_update, ...)
+ """
+
+ def __init__(self,
+ learning_rate,
+ initial_accumulator_value=0.1,
+ use_locking=False,
+ name="RegAdagrad"):
+ super(RegAdagradOptimizer, self).__init__(
+ learning_rate,
+ initial_accumulator_value=initial_accumulator_value,
+ use_locking=use_locking,
+ name=name)
+ self._should_update_slots = True
+
+ @tf_contextlib.contextmanager
+ def avoid_updating_slots(self):
+ old = self._should_update_slots
+ self._should_update_slots = False
+ try:
+ yield
+ finally:
+ self._should_update_slots = old
+
+ def _apply_dense(self, grad, var):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.apply_adagrad(
+ var,
+ acc,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking,
+ update_slots=self._should_update_slots)
+
+ def _resource_apply_dense(self, grad, var, update_slots=True):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.resource_apply_adagrad(
+ var.handle,
+ acc.handle,
+ math_ops.cast(self._learning_rate_tensor, grad.dtype.base_dtype),
+ grad,
+ use_locking=self._use_locking,
+ update_slots=self._should_update_slots)
+
+ def _apply_sparse(self, grad, var, update_slots=True):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.sparse_apply_adagrad(
+ var,
+ acc,
+ math_ops.cast(self._learning_rate_tensor, var.dtype.base_dtype),
+ grad.values,
+ grad.indices,
+ use_locking=self._use_locking,
+ update_slots=self._should_update_slots)
+
+ def _resource_apply_sparse(self, grad, var, indices, update_slots=True):
+ acc = self.get_slot(var, "accumulator")
+ return training_ops.resource_sparse_apply_adagrad(
+ var.handle,
+ acc.handle,
+ math_ops.cast(self._learning_rate_tensor, grad.dtype),
+ grad,
+ indices,
+ use_locking=self._use_locking,
+ update_slots=self._should_update_slots)
diff --git a/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
new file mode 100644
index 0000000000..ea56e1646a
--- /dev/null
+++ b/tensorflow/contrib/opt/python/training/reg_adagrad_optimizer_test.py
@@ -0,0 +1,343 @@
+# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for Regreg_adagrad_optimizer."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.opt.python.training import reg_adagrad_optimizer
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import embedding_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.platform import test
+
+
+class RegAdagradOptimizerTest(test.TestCase):
+
+ def doTestBasic(self, use_locking=False, use_resource=False):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ if use_resource:
+ var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
+ var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
+ else:
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0, initial_accumulator_value=0.1, use_locking=use_locking)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testBasic(self):
+ self.doTestBasic(use_locking=False)
+
+ def testBasicResource(self):
+ self.doTestBasic(use_locking=False, use_resource=True)
+
+ def testBasicLocked(self):
+ self.doTestBasic(use_locking=True)
+
+ def testMinimizeSparseResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = resource_variable_ops.ResourceVariable(
+ [[1.0, 2.0], [3.0, 4.0]], dtype=dtype)
+ x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
+ pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
+ loss = pred * pred
+ sgd_op = reg_adagrad_optimizer.RegAdagradOptimizer(1.0).minimize(loss)
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllCloseAccordingToType([[1.0, 2.0], [3.0, 4.0]],
+ var0.eval())
+ # Run 1 step of sgd
+ sgd_op.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ [[0, 1], [3, 4]], var0.eval(), atol=0.01)
+
+ def testTensorLearningRate(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ constant_op.constant(3.0), initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testSparseBasic(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+        # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate updated params
+ self.assertAllCloseAccordingToType(
+ np.array([[-1.6026098728179932], [2.0]]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[3.0], [3.715679168701172]]), var1.eval())
+
+ def testSparseRepeatedIndices(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ repeated_index_update_var = variables.Variable(
+ [[1.0], [2.0]], dtype=dtype)
+ aggregated_update_var = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ grad_repeated_index = ops.IndexedSlices(
+ constant_op.constant([0.1, 0.1], shape=[2, 1], dtype=dtype),
+ constant_op.constant([1, 1]), constant_op.constant([2, 1]))
+ grad_aggregated = ops.IndexedSlices(
+ constant_op.constant([0.2], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ repeated_update = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0).apply_gradients([(grad_repeated_index,
+ repeated_index_update_var)])
+ aggregated_update = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0).apply_gradients([(grad_aggregated, aggregated_update_var)])
+ variables.global_variables_initializer().run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+ for _ in range(3):
+ repeated_update.run()
+ aggregated_update.run()
+ self.assertAllClose(aggregated_update_var.eval(),
+ repeated_index_update_var.eval())
+
+ def testSparseRepeatedIndicesResourceVariable(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var_repeated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_repeated = math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_repeated, [0, 0]))
+ var_aggregated = resource_variable_ops.ResourceVariable(
+ [1.0, 2.0], dtype=dtype)
+ loss_aggregated = 2 * math_ops.reduce_sum(
+ embedding_ops.embedding_lookup(var_aggregated, [0]))
+ update_op_repeated = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 2.0).minimize(loss_repeated)
+ update_op_aggregated = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 2.0).minimize(loss_aggregated)
+ variables.global_variables_initializer().run()
+ self.assertAllCloseAccordingToType(var_repeated.eval(),
+ var_aggregated.eval())
+ for _ in range(3):
+ update_op_repeated.run()
+ update_op_aggregated.run()
+ self.assertAllCloseAccordingToType(var_repeated.eval(),
+ var_aggregated.eval())
+
+ def testSparseStability(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ shape = [1, 6]
+ var0 = variables.Variable(
+ [[
+ 0.00872496, -0.106952, 0.110467, 0.226505, -0.0147257,
+ -0.0105945
+ ]],
+ dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant(
+ [[
+ -5.91278e-05, 5.31673e-05, -2.5779e-06, 4.29153e-05,
+ -8.4877e-05, -9.48906e-05
+ ]],
+ shape=shape,
+ dtype=dtype), constant_op.constant([0]),
+ constant_op.constant(shape))
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 1.0, initial_accumulator_value=0.1)
+ ada_update = ada_opt.apply_gradients(zip([grads0], [var0]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ init = variables.global_variables_initializer()
+ for _ in range(100):
+ init.run()
+ ada_update.run()
+ self.assertAllCloseAccordingToType(
+ np.array([[0.1, 0.1, 0.1, 0.1, 0.1, 0.1]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[
+ 0.00891194, -0.10712013, 0.11047515, 0.22636929, -0.0144573,
+ -0.01029443
+ ]]), var0.eval())
+
+ def testSharing(self):
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(3.0)
+ # Apply the optimizer twice. Both applications will use
+ # the same accums.
+ ada_update1 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ ada_update2 = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+ # Mix the first and the second adagrad for 3 steps.
+ ada_update1.run()
+ ada_update2.run()
+ ada_update1.run()
+ # Validate updated params (the same as with only 1 RegAdagrad).
+ self.assertAllCloseAccordingToType(
+ np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([2.715679168701172, 3.715679168701172]), var1.eval())
+
+ def testDynamicShapeVariable_Ok(self):
+ with self.test_session():
+ v = variable_scope.get_variable(
+ "v", initializer=constant_op.constant(1.), validate_shape=False)
+ self.assertFalse(v.shape.is_fully_defined())
+ # Creating optimizer should cause no exception.
+ reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0, initial_accumulator_value=0.1)
+
+ def testSkipUpdatingSlots(self):
+ iav = 0.130005 # A value that works with float16
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = variables.Variable([1.0, 2.0], dtype=dtype)
+ var1 = variables.Variable([3.0, 4.0], dtype=dtype)
+ grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
+ grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0, initial_accumulator_value=iav)
+        # Build the update inside avoid_updating_slots() so the Adagrad
+        # accumulators are left untouched when it runs.
+ with ada_opt.avoid_updating_slots():
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ self.assertEqual(["accumulator"], ada_opt.get_slot_names())
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+ variables.global_variables_initializer().run()
+
+ # Fetch params to validate initial values.
+ self.assertAllClose([1.0, 2.0], var0.eval())
+ self.assertAllClose([3.0, 4.0], var1.eval())
+        # Run 3 steps of adagrad.
+ for _ in range(3):
+ ada_update.run()
+ # Validate that ada_opt's slots are not updated.
+ self.assertAllCloseAccordingToType(np.array([iav, iav]), slot0.eval())
+ self.assertAllCloseAccordingToType(np.array([iav, iav]), slot1.eval())
+
+ def testSparseSkipUpdatingSlots(self):
+ iav = 0.130005 # A value that works with float16
+ for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
+ with self.test_session():
+ var0 = variables.Variable([[1.0], [2.0]], dtype=dtype)
+ var1 = variables.Variable([[3.0], [4.0]], dtype=dtype)
+ grads0 = ops.IndexedSlices(
+ constant_op.constant([0.1], shape=[1, 1], dtype=dtype),
+ constant_op.constant([0]), constant_op.constant([2, 1]))
+ grads1 = ops.IndexedSlices(
+ constant_op.constant([0.01], shape=[1, 1], dtype=dtype),
+ constant_op.constant([1]), constant_op.constant([2, 1]))
+ ada_opt = reg_adagrad_optimizer.RegAdagradOptimizer(
+ 3.0, initial_accumulator_value=iav)
+ with ada_opt.avoid_updating_slots():
+ ada_update = ada_opt.apply_gradients(
+ zip([grads0, grads1], [var0, var1]))
+ slot0 = ada_opt.get_slot(var0, "accumulator")
+ self.assertEquals(slot0.get_shape(), var0.get_shape())
+ slot1 = ada_opt.get_slot(var1, "accumulator")
+ self.assertEquals(slot1.get_shape(), var1.get_shape())
+
+ variables.global_variables_initializer().run()
+ # Fetch params to validate initial values
+ self.assertAllClose([[1.0], [2.0]], var0.eval())
+ self.assertAllClose([[3.0], [4.0]], var1.eval())
+        # Run 3 steps of adagrad
+ for _ in range(3):
+ ada_update.run()
+ # Validate that ada_opt's slots are not updated.
+ self.assertAllCloseAccordingToType(
+ np.array([[iav], [iav]]), slot0.eval())
+ self.assertAllCloseAccordingToType(
+ np.array([[iav], [iav]]), slot1.eval())
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
index 8ac9b58145..9e2858d00f 100644
--- a/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
+++ b/tensorflow/contrib/optimizer_v2/checkpointable_utils_test.py
@@ -702,8 +702,7 @@ class CheckpointCompatibilityTests(test.TestCase):
with save_graph.as_default(), self.test_session(
graph=save_graph) as session:
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(
+ save_path = root.save(
session=session, file_prefix=checkpoint_prefix)
with context.eager_mode():
root = self._initialized_model()
@@ -716,8 +715,7 @@ class CheckpointCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with context.eager_mode():
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(file_prefix=checkpoint_prefix)
+ save_path = root.save(file_prefix=checkpoint_prefix)
with context.graph_mode():
save_graph = ops.Graph()
with save_graph.as_default(), self.test_session(
diff --git a/tensorflow/contrib/optimizer_v2/optimizer_v2.py b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
index ce15db6f1e..46bfbb729f 100644
--- a/tensorflow/contrib/optimizer_v2/optimizer_v2.py
+++ b/tensorflow/contrib/optimizer_v2/optimizer_v2.py
@@ -125,19 +125,6 @@ class _DenseResourceVariableProcessor(_OptimizableVariable):
return update_op
-class _StreamingModelPortProcessor(_OptimizableVariable):
- """Processor for streaming ModelPorts."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g, *args):
- return g
-
-
class _TensorProcessor(_OptimizableVariable):
"""Processor for ordinary Tensors.
@@ -167,8 +154,6 @@ def _get_processor(v):
return _DenseResourceVariableProcessor(v)
if isinstance(v, variables.Variable):
return _RefVariableProcessor(v)
- if v.op.type == "SubmodelPort":
- return _StreamingModelPortProcessor(v)
if isinstance(v, ops.Tensor):
return _TensorProcessor(v)
raise NotImplementedError("Trying to optimize unsupported type ", v)
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms.py b/tensorflow/contrib/quantize/python/fold_batch_norms.py
index aa0ef64308..6f41722748 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms.py
@@ -501,8 +501,27 @@ def _GetBatchNormParams(graph, context, has_scaling):
bn_decay_var_tensor = None
split_context = context.split('/')
- base_context = split_context[-1]
-
+ # Matching variable names is brittle and relies on scoping
+ # conventions. Fused batch norm folding is more robust. Support for unfused
+ # batch norms will be deprecated as we move forward. Fused batch norms allow
+ # for faster training and should be used whenever possible.
+ # context contains part of the names of the tensors we are interested in:
+ # For MobilenetV1, the context has repetitions:
+ # MobilenetV1/MobilenetV1/Conv2d_3_depthwise
+ # when the moving_mean tensor has the name:
+ # MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read
+ # To pick the correct variable name, it is necessary to ignore the repeating
+ # header.
+
+ # For MobilenetV2, this problem does not exist:
+ # The context is: MobilenetV2/expanded_conv_3/depthwise
+ # and the names of the tensors start with a single MobilenetV2
+ # The moving mean for example, has the name:
+ # MobilenetV2/expanded_conv_3/depthwise/BatchNorm/moving_mean/read
+  # We ignore the first string (MobilenetV1 or MobilenetV2) in the context
+  # so that the match works correctly in both cases.
+
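+  # As a purely illustrative example: for the MobilenetV1 context above,
+  # '/'.join(split_context[1:]) yields 'MobilenetV1/Conv2d_3_depthwise', and
+  # 'MobilenetV1/Conv2d_3_depthwise/BatchNorm/moving_mean/read' is then found
+  # by the suffix matching below.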
+ base_context = '/'.join(split_context[1:])
oplist = graph.get_operations()
op_suffix_mean = base_context + '/BatchNorm/moments/Squeeze'
op_suffix_variance = base_context + '/BatchNorm/moments/Squeeze_1'
@@ -520,7 +539,6 @@ def _GetBatchNormParams(graph, context, has_scaling):
op_suffix_gamma = base_context + '/BatchNorm/gamma'
op_suffix_moving_variance = base_context + '/BatchNorm/moving_variance/read'
op_suffix_moving_mean = base_context + '/BatchNorm/moving_mean/read'
-
# Parse through list of ops to find relevant ops
for op in oplist:
if op.name.endswith(op_suffix_mean):
diff --git a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
index af31467476..64e8142e7c 100644
--- a/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
+++ b/tensorflow/contrib/quantize/python/fold_batch_norms_test.py
@@ -134,6 +134,85 @@ class FoldBatchNormsTest(test_util.TensorFlowTestCase):
def testFoldConv2d(self):
self._RunTestOverParameters(self._TestFoldConv2d)
+ def testMultipleLayerConv2d(self,
+ relu=nn_ops.relu,
+ relu_op_name='Relu',
+ has_scaling=True,
+ fused_batch_norm=False,
+ freeze_batch_norm_delay=None):
+ """Tests folding cases for a network with multiple layers.
+
+ Args:
+ relu: Callable that returns an Operation, a factory method for the Relu*.
+ relu_op_name: String, name of the Relu* operation.
+ has_scaling: Bool, when true the batch norm has scaling.
+ fused_batch_norm: Bool, when true the batch norm is fused.
+ freeze_batch_norm_delay: None or the number of steps after which training
+        switches to using frozen mean and variance.
+ """
+ g = ops.Graph()
+ with g.as_default():
+ batch_size, height, width = 5, 128, 128
+ inputs = array_ops.zeros((batch_size, height, width, 3))
+ out_depth = 3
+ stride = 1
+ activation_fn = relu
+ scope = 'network/expanded_conv_1/conv'
+ layer1 = conv2d(
+ inputs,
+ out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(
+ scale=has_scaling, fused=fused_batch_norm),
+ scope=scope)
+ # Add another layer
+ scope = 'network/expanded_conv_2/conv'
+
+ _ = conv2d(
+ layer1,
+ 2 * out_depth, [5, 5],
+ stride=stride,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=activation_fn,
+ normalizer_fn=batch_norm,
+ normalizer_params=self._BatchNormParams(
+ scale=has_scaling, fused=fused_batch_norm),
+ scope=scope)
+
+ fold_batch_norms.FoldBatchNorms(
+ g, is_training=True, freeze_batch_norm_delay=freeze_batch_norm_delay)
+ folded_mul = g.get_operation_by_name(scope + '/mul_fold')
+ self.assertEqual(folded_mul.type, 'Mul')
+ self._AssertInputOpsAre(folded_mul, [
+ scope + '/correction_mult',
+ self._BatchNormMultiplierName(scope, has_scaling, fused_batch_norm)
+ ])
+ self._AssertOutputGoesToOps(folded_mul, g, [scope + '/Conv2D_Fold'])
+
+ folded_conv = g.get_operation_by_name(scope + '/Conv2D_Fold')
+ self.assertEqual(folded_conv.type, 'Conv2D')
+    # Strip the trailing ':0' from the tensor name prior to comparison.
+ self._AssertInputOpsAre(folded_conv,
+ [scope + '/mul_fold', layer1.name[:-2]])
+ self._AssertOutputGoesToOps(folded_conv, g, [scope + '/post_conv_mul'])
+
+ folded_add = g.get_operation_by_name(scope + '/add_fold')
+ self.assertEqual(folded_add.type, 'Add')
+ self._AssertInputOpsAre(folded_add, [
+ scope + '/correction_add',
+ self._BathNormBiasName(scope, fused_batch_norm)
+ ])
+ output_op_names = [scope + '/' + relu_op_name]
+ self._AssertOutputGoesToOps(folded_add, g, output_op_names)
+
+ for op in g.get_operations():
+ self.assertFalse('//' in op.name, 'Double slash in op %s' % op.name)
+
def _TestFoldConv2dUnknownShape(self, relu, relu_op_name, with_bypass,
has_scaling, fused_batch_norm,
freeze_batch_norm_delay):
diff --git a/tensorflow/contrib/quantize/python/quantize.py b/tensorflow/contrib/quantize/python/quantize.py
index d2d0426d23..efc1a94b3c 100644
--- a/tensorflow/contrib/quantize/python/quantize.py
+++ b/tensorflow/contrib/quantize/python/quantize.py
@@ -133,19 +133,27 @@ def Quantize(graph,
bits=activation_bits,
producer_scope=scope,
consumer_scope=scope)
- _InsertQuantOp(
- add_context,
- 'add_quant',
- layer_match.bypass_op,
- input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
- is_training,
- moving_avg=True,
- ema_decay=ema_decay,
- quant_delay=quant_delay,
- vars_collection=vars_collection,
- bits=activation_bits,
- producer_scope=scope,
- consumer_scope=scope)
+      # Make sure the op following this isn't an activation. If it is, we
+      # shouldn't quantize it, since the activation will be fused into the
+      # Add at inference time.
+ consumers = input_to_ops_map.ConsumerOperations(layer_match.bypass_op)
+ if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
+        logging.info('Skipping %s, because it is followed by an activation.',
+ layer_match.bypass_op.name)
+ else:
+ _InsertQuantOp(
+ add_context,
+ 'add_quant',
+ layer_match.bypass_op,
+ input_to_ops_map.ConsumerOperations(layer_match.bypass_op),
+ is_training,
+ moving_avg=True,
+ ema_decay=ema_decay,
+ quant_delay=quant_delay,
+ vars_collection=vars_collection,
+ bits=activation_bits,
+ producer_scope=scope,
+ consumer_scope=scope)
# Quantize bypass ops that occur after the activation.
if layer_match.post_activation_bypass_op is not None:
@@ -153,19 +161,27 @@ def Quantize(graph,
r'^(.*)/([^/]+)', layer_match.post_activation_bypass_op.name).group(1)
# If `scope` is given, only quantize it if the producer is in the right
# scope.
- _InsertQuantOp(
- post_activation_bypass_context,
- 'post_activation_bypass_quant',
- layer_match.post_activation_bypass_op,
- input_to_ops_map.ConsumerOperations(
- layer_match.post_activation_bypass_op),
- is_training,
- moving_avg=True,
- ema_decay=ema_decay,
- quant_delay=quant_delay,
- vars_collection=vars_collection,
- bits=activation_bits,
- producer_scope=scope)
+      # Make sure the op following this isn't an activation. If it is, we
+      # shouldn't quantize it, since the activation will be fused into the
+      # Add at inference time.
+ consumers = input_to_ops_map.ConsumerOperations(
+ layer_match.post_activation_bypass_op)
+ if any([consumer.type in _ACTIVATION_TYPES for consumer in consumers]):
+        logging.info('Skipping %s, because it is followed by an activation.',
+ layer_match.post_activation_bypass_op.name)
+ else:
+ _InsertQuantOp(
+ post_activation_bypass_context,
+ 'post_activation_bypass_quant',
+ layer_match.post_activation_bypass_op,
+ consumers,
+ is_training,
+ moving_avg=True,
+ ema_decay=ema_decay,
+ quant_delay=quant_delay,
+ vars_collection=vars_collection,
+ bits=activation_bits,
+ producer_scope=scope)
def _FindLayersToQuantize(graph):
diff --git a/tensorflow/contrib/quantize/python/quantize_graph_test.py b/tensorflow/contrib/quantize/python/quantize_graph_test.py
index caf8ff28d5..54faf582f1 100644
--- a/tensorflow/contrib/quantize/python/quantize_graph_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_graph_test.py
@@ -113,20 +113,6 @@ class QuantizeGraphTest(test_util.TensorFlowTestCase):
# Ensure that variables were added.
self.assertTrue(len(orig_variable_names) < len(q_variables))
- def testWithPreActivationBypass(self):
- self._RunTestOverAllRewrites(self._TestWithPreActivationBypass)
-
- def _TestWithPreActivationBypass(self, rewrite_fn):
- # Tests that the default graph is correctly used when no args are provided
- # to rewrite_fn.
- with ops.Graph().as_default() as g:
- self._ConvLayer(pre_activation_bypass=True, scope='scope1')
- rewrite_fn()
-
- op_names = [op.name for op in g.get_operations()]
- self.assertTrue(
- any('scope1/add_quant/' in name for name in op_names))
-
def testWithPostActivationBypass(self):
self._RunTestOverAllRewrites(self._TestWithPostActivationBypass)
diff --git a/tensorflow/contrib/quantize/python/quantize_test.py b/tensorflow/contrib/quantize/python/quantize_test.py
index d37c83d683..5e479f3946 100644
--- a/tensorflow/contrib/quantize/python/quantize_test.py
+++ b/tensorflow/contrib/quantize/python/quantize_test.py
@@ -82,9 +82,22 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
- add_quant = graph.get_operation_by_name('test/add_quant/' +
- quantization_node_name)
- self.assertEqual(add_quant.type, quantization_node_name)
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # isn't in the consumers of the operation. Since activations are folded
+    # into the preceding operation during inference, the FakeQuant operation
+    # after the activation is all that is needed.
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+
+ self.assertNotIn('test/identity', [c.name for c in consumers])
def testInsertQuantOpForAddAfterSeparableConv2d(self):
self._RunTestOverParameters(
@@ -109,9 +122,20 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
quantization_node_name = 'FakeQuantWithMinMaxVars'
- add_quant = graph.get_operation_by_name('test/add_quant/' +
- quantization_node_name)
- self.assertEqual(add_quant.type, quantization_node_name)
+ conv_quant = graph.get_operation_by_name('test/test/conv_quant/' +
+ quantization_node_name)
+ self.assertEqual(conv_quant.type, quantization_node_name)
+
+ for op in graph.get_operations():
+ if op.type == quantization_node_name:
+ quant_op = graph.get_operation_by_name(op.name)
+ # Scan through all FakeQuant operations, ensuring that the activation
+ # identity op isn't in the consumers of the operation.
+ consumers = []
+ for output in quant_op.outputs:
+ consumers.extend(output.consumers())
+
+ self.assertNotIn('test/identity', [c.name for c in consumers])
def testFinalLayerQuantized(self):
self._RunTestOverParameters(self._TestFinalLayerQuantized)
@@ -153,12 +177,21 @@ class QuantizeTest(test_util.TensorFlowTestCase):
activation_fn=array_ops.identity,
scope='test/test')
bypass_tensor = math_ops.add(conv, input2, name='test/add')
- _ = array_ops.identity(bypass_tensor, name='test/output')
+ # The output of the post_activation bypass will be another layer.
+ _ = conv2d(
+ bypass_tensor,
+ 32, [5, 5],
+ stride=2,
+ padding='SAME',
+ weights_initializer=self._WeightInit(0.09),
+ activation_fn=array_ops.identity,
+ scope='test/unused')
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
- # Ensure that the bypass node is preceded and followed by
- # FakeQuantWithMinMaxVars operations.
+ # Ensure that the bypass node is preceded by and followed by a
+    # FakeQuantWithMinMaxVars operation, since the output of the Add isn't an
+ # activation.
self.assertTrue('FakeQuantWithMinMaxVars' in
[c.type for c in bypass_tensor.consumers()])
self.assertTrue('FakeQuantWithMinMaxVars' in
@@ -198,9 +231,9 @@ class QuantizeTest(test_util.TensorFlowTestCase):
quantize.Quantize(graph, is_training, weight_bits=8, activation_bits=8)
- # Ensure that the bypass node is preceded and followed by
- # FakeQuantWithMinMaxVars operations.
- self.assertTrue('FakeQuantWithMinMaxVars' in
+    # Ensure that the bypass node is preceded by a FakeQuantWithMinMaxVars
+ # operation, and NOT followed by one.
+ self.assertTrue('FakeQuantWithMinMaxVars' not in
[c.type for c in bypass_tensor.consumers()])
self.assertTrue('FakeQuantWithMinMaxVars' in
[i.op.type for i in bypass_tensor.op.inputs])
diff --git a/tensorflow/contrib/rnn/kernels/blas_gemm.cc b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
index 03006dab32..45d22b739b 100644
--- a/tensorflow/contrib/rnn/kernels/blas_gemm.cc
+++ b/tensorflow/contrib/rnn/kernels/blas_gemm.cc
@@ -26,9 +26,9 @@ namespace tensorflow {
#if GOOGLE_CUDA
namespace {
template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace
@@ -41,9 +41,8 @@ void TensorCuBlasGemm<T>::operator()(OpKernelContext* ctx, bool transa,
T alpha, const T* a, int lda, const T* b,
int ldb, T beta, T* c, int ldc) {
#if GOOGLE_CUDA
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose};
+ se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
+ se::blas::Transpose::kTranspose};
auto a_ptr = AsDeviceMemory(a);
auto b_ptr = AsDeviceMemory(b);
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/BUILD b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
index f3e6731213..2311c15a68 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/rpc/python/kernel_tests/BUILD
@@ -28,7 +28,6 @@ py_library(
py_library(
name = "rpc_op_test_base",
srcs = ["rpc_op_test_base.py"],
- tags = ["notsan"],
deps = [
":test_example_proto_py",
"//tensorflow/contrib/proto",
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
index e2e0dbc7a2..3fc6bfbb4d 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test.py
@@ -35,6 +35,7 @@ class RpcOpTest(test.TestCase, rpc_op_test_base.RpcOpTestBase):
_protocol = 'grpc'
invalid_method_string = 'Method not found'
+ connect_failed_string = 'Connect Failed'
def __init__(self, methodName='runTest'): # pylint: disable=invalid-name
super(RpcOpTest, self).__init__(methodName)
diff --git a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
index 89f3ee1a1c..27273d16b1 100644
--- a/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
+++ b/tensorflow/contrib/rpc/python/kernel_tests/rpc_op_test_base.py
@@ -93,40 +93,39 @@ class RpcOpTestBase(object):
response_values = sess.run(response_tensors)
self.assertAllEqual(response_values.shape, [0])
- def testInvalidAddresses(self):
- with self.test_session() as sess:
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method='/InvalidService.IncrementTestShapes',
- address=self._address,
- request=''))
+ def testInvalidMethod(self):
+ for method in [
+ '/InvalidService.IncrementTestShapes',
+ self.get_method_name('InvalidMethodName')
+ ]:
+ with self.test_session() as sess:
+ with self.assertRaisesOpError(self.invalid_method_string):
+ sess.run(self.rpc(method=method, address=self._address, request=''))
- with self.assertRaisesOpError(self.invalid_method_string):
- sess.run(
- self.rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
- request=''))
+ _, status_code_value, status_message_value = sess.run(
+ self.try_rpc(method=method, address=self._address, request=''))
+ self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertTrue(
+ self.invalid_method_string in status_message_value.decode('ascii'))
- # This also covers the case of address=''
- # and address='localhost:293874293874'
+ def testInvalidAddress(self):
+ # This covers the case of address='' and address='localhost:293874293874'
+ address = 'unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@'
+ with self.test_session() as sess:
with self.assertRaises(errors.UnavailableError):
sess.run(
self.rpc(
method=self.get_method_name('IncrementTestShapes'),
- address='unix:/tmp/this_unix_socket_doesnt_exist_97820348!!@',
+ address=address,
request=''))
-
- # Test invalid method with the TryRpc op
_, status_code_value, status_message_value = sess.run(
self.try_rpc(
- method=self.get_method_name('InvalidMethodName'),
- address=self._address,
+ method=self.get_method_name('IncrementTestShapes'),
+ address=address,
request=''))
- self.assertEqual(errors.UNIMPLEMENTED, status_code_value)
+ self.assertEqual(errors.UNAVAILABLE, status_code_value)
self.assertTrue(
- self.invalid_method_string in status_message_value.decode('ascii'))
+ self.connect_failed_string in status_message_value.decode('ascii'))
def testAlwaysFailingMethod(self):
with self.test_session() as sess:
@@ -138,6 +137,18 @@ class RpcOpTestBase(object):
with self.assertRaisesOpError(I_WARNED_YOU):
sess.run(response_tensors)
+ response_tensors, status_code, status_message = self.try_rpc(
+ method=self.get_method_name('AlwaysFailWithInvalidArgument'),
+ address=self._address,
+ request='')
+ self.assertEqual(response_tensors.shape, ())
+ self.assertEqual(status_code.shape, ())
+ self.assertEqual(status_message.shape, ())
+ status_code_value, status_message_value = sess.run((status_code,
+ status_message))
+ self.assertEqual(errors.INVALID_ARGUMENT, status_code_value)
+ self.assertTrue(I_WARNED_YOU in status_message_value.decode('ascii'))
+
def testSometimesFailingMethodWithManyRequests(self):
with self.test_session() as sess:
# Fail hard by default.
@@ -197,8 +208,7 @@ class RpcOpTestBase(object):
address=self._address,
request=request_tensors) for _ in range(10)
]
- # Launch parallel 10 calls to the RpcOp, each containing
- # 20 rpc requests.
+      # Launch 10 parallel calls to the RpcOp, each containing 20 rpc requests.
many_response_values = sess.run(many_response_tensors)
self.assertEqual(10, len(many_response_values))
for response_values in many_response_values:
diff --git a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
index b32371b642..53ba7badca 100644
--- a/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
+++ b/tensorflow/contrib/tensorrt/kernels/trt_engine_op.cc
@@ -25,7 +25,6 @@ limitations under the License.
namespace tensorflow {
static ::tensorflow::tensorrt::Logger logger;
-namespace gpu = ::perftools::gputools;
using IRuntime = nvinfer1::IRuntime;
using Dims = nvinfer1::Dims;
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly.py b/tensorflow/contrib/timeseries/examples/known_anomaly.py
index e77628ddd3..71621abc71 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly.py
@@ -41,17 +41,8 @@ _MODULE_PATH = path.dirname(__file__)
_DATA_FILE = path.join(_MODULE_PATH, "data/changepoints.csv")
-def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
- """Training, evaluating, and predicting on a series with changepoints."""
-
- # Indicate the format of our exogenous feature, in this case a string
- # representing a boolean value.
- string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
- key="is_changepoint", vocabulary_list=["no", "yes"])
- # Specify the way this feature is presented to the model, here using a one-hot
- # encoding.
- one_hot_feature = tf.feature_column.indicator_column(
- categorical_column=string_feature)
+def state_space_esitmator(exogenous_feature_columns):
+ """Constructs a StructuralEnsembleRegressor."""
def _exogenous_update_condition(times, features):
del times # unused
@@ -62,14 +53,48 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# no changepoint.
return tf.equal(tf.squeeze(features["is_changepoint"], axis=-1), "yes")
- estimator = tf.contrib.timeseries.StructuralEnsembleRegressor(
- periodicities=12,
- # Extract a smooth period by constraining the number of latent values
- # being cycled between.
- cycle_num_latent_values=3,
- num_features=1,
- exogenous_feature_columns=[one_hot_feature],
- exogenous_update_condition=_exogenous_update_condition)
+ return (
+ tf.contrib.timeseries.StructuralEnsembleRegressor(
+ periodicities=12,
+ # Extract a smooth period by constraining the number of latent values
+ # being cycled between.
+ cycle_num_latent_values=3,
+ num_features=1,
+ exogenous_feature_columns=exogenous_feature_columns,
+ exogenous_update_condition=_exogenous_update_condition),
+ # Use truncated backpropagation with a window size of 64, batching
+ # together 4 of these windows (random offsets) per training step. Training
+ # with exogenous features often requires somewhat larger windows.
+ 4, 64)
+
+
+def autoregressive_esitmator(exogenous_feature_columns):
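+  """Constructs an ARRegressor."""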
+ input_window_size = 8
+ output_window_size = 2
+ return (
+ tf.contrib.timeseries.ARRegressor(
+ periodicities=12,
+ num_features=1,
+ input_window_size=input_window_size,
+ output_window_size=output_window_size,
+ exogenous_feature_columns=exogenous_feature_columns),
+ 64, input_window_size + output_window_size)
+
+
+def train_and_evaluate_exogenous(
+ estimator_fn, csv_file_name=_DATA_FILE, train_steps=300):
+ """Training, evaluating, and predicting on a series with changepoints."""
+ # Indicate the format of our exogenous feature, in this case a string
+ # representing a boolean value.
+ string_feature = tf.feature_column.categorical_column_with_vocabulary_list(
+ key="is_changepoint", vocabulary_list=["no", "yes"])
+ # Specify the way this feature is presented to the model, here using a one-hot
+ # encoding.
+ one_hot_feature = tf.feature_column.indicator_column(
+ categorical_column=string_feature)
+
+ estimator, batch_size, window_size = estimator_fn(
+ exogenous_feature_columns=[one_hot_feature])
reader = tf.contrib.timeseries.CSVReader(
csv_file_name,
# Indicate the format of our CSV file. First we have two standard columns,
@@ -85,10 +110,7 @@ def train_and_evaluate_exogenous(csv_file_name=_DATA_FILE, train_steps=300):
# This CSV has a header line; here we just ignore it.
skip_header_lines=1)
train_input_fn = tf.contrib.timeseries.RandomWindowInputFn(
- # Use truncated backpropagation with a window size of 64, batching
- # together 4 of these windows (random offsets) per training step. Training
- # with exogenous features often requires somewhat larger windows.
- reader, batch_size=4, window_size=64)
+ reader, batch_size=batch_size, window_size=window_size)
estimator.train(input_fn=train_input_fn, steps=train_steps)
evaluation_input_fn = tf.contrib.timeseries.WholeDatasetInputFn(reader)
evaluation = estimator.evaluate(input_fn=evaluation_input_fn, steps=1)
@@ -145,7 +167,12 @@ def main(unused_argv):
if not HAS_MATPLOTLIB:
raise ImportError(
"Please install matplotlib to generate a plot from this example.")
- make_plot("Ignoring a known anomaly", *train_and_evaluate_exogenous())
+ make_plot("Ignoring a known anomaly (state space)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=state_space_esitmator))
+ make_plot("Ignoring a known anomaly (autoregressive)",
+ *train_and_evaluate_exogenous(
+ estimator_fn=autoregressive_esitmator, train_steps=3000))
pyplot.show()
diff --git a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
index c3e307cad8..8c64f2e186 100644
--- a/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
+++ b/tensorflow/contrib/timeseries/examples/known_anomaly_test.py
@@ -23,12 +23,24 @@ from tensorflow.contrib.timeseries.examples import known_anomaly
from tensorflow.python.platform import test
-class KnownAnaomalyExampleTest(test.TestCase):
+class KnownAnomalyExampleTest(test.TestCase):
- def test_shapes_and_variance_structural(self):
+ def test_shapes_and_variance_structural_ar(self):
(times, observed, all_times, mean, upper_limit, lower_limit,
anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
- train_steps=50)
+ train_steps=1, estimator_fn=known_anomaly.autoregressive_esitmator)
+ self.assertAllEqual(
+ anomaly_locations,
+ [25, 50, 75, 100, 125, 150, 175, 249])
+ self.assertAllEqual(all_times.shape, mean.shape)
+ self.assertAllEqual(all_times.shape, upper_limit.shape)
+ self.assertAllEqual(all_times.shape, lower_limit.shape)
+ self.assertAllEqual(times.shape, observed.shape)
+
+ def test_shapes_and_variance_structural_ssm(self):
+ (times, observed, all_times, mean, upper_limit, lower_limit,
+ anomaly_locations) = known_anomaly.train_and_evaluate_exogenous(
+ train_steps=50, estimator_fn=known_anomaly.state_space_esitmator)
self.assertAllEqual(
anomaly_locations,
[25, 50, 75, 100, 125, 150, 175, 249])
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
index 4f6527a546..558d9480b4 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model.py
@@ -60,7 +60,8 @@ class ARModel(model.TimeSeriesModel):
num_features,
num_time_buckets=10,
loss=NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=None):
+ hidden_layer_sizes=None,
+ exogenous_feature_columns=None):
"""Constructs an auto-regressive model.
Args:
@@ -81,6 +82,11 @@ class ARModel(model.TimeSeriesModel):
observations and predictions, while the training loss is computed on
normalized data (if input statistics are available).
hidden_layer_sizes: list of sizes of hidden layers.
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
"""
self.input_window_size = input_window_size
self.output_window_size = output_window_size
@@ -90,7 +96,12 @@ class ARModel(model.TimeSeriesModel):
self.window_size = self.input_window_size + self.output_window_size
self.loss = loss
super(ARModel, self).__init__(
- num_features=num_features)
+ num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns)
+ if exogenous_feature_columns is not None:
+ self.exogenous_size = self._get_exogenous_embedding_shape()[-1]
+ else:
+ self.exogenous_size = 0
assert num_time_buckets > 0
self._buckets = int(num_time_buckets)
if periodicities is None or not periodicities:
@@ -110,7 +121,10 @@ class ARModel(model.TimeSeriesModel):
# that the serving input_receiver_fn gets placeholder shapes correct.
return (array_ops.zeros([self.input_window_size], dtype=dtypes.int64),
array_ops.zeros(
- [self.input_window_size, self.num_features], dtype=self.dtype))
+ [self.input_window_size, self.num_features], dtype=self.dtype),
+ array_ops.zeros(
+ [self.input_window_size, self.exogenous_size],
+ dtype=self.dtype))
# TODO(allenl,agarwal): Support sampling for AR.
def random_model_parameters(self, seed=None):
@@ -163,7 +177,7 @@ class ARModel(model.TimeSeriesModel):
activations.append((activation, activation_size))
return activations
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
"""Compute model predictions given input data.
Args:
@@ -173,6 +187,8 @@ class ARModel(model.TimeSeriesModel):
prediction times.
values: A [batch size, self.input_window_size, self.num_features] Tensor
with input features.
+ exogenous_regressors: A [batch size, self.window_size,
+ self.exogenous_size] Tensor with exogenous features.
Returns:
Tuple (predicted_mean, predicted_covariance), where each element is a
Tensor with shape [batch size, self.output_window_size,
@@ -183,25 +199,33 @@ class ARModel(model.TimeSeriesModel):
if self.input_window_size:
values.get_shape().assert_is_compatible_with(
[None, self.input_window_size, self.num_features])
+ if exogenous_regressors is not None:
+ exogenous_regressors.get_shape().assert_is_compatible_with(
+ [None, self.window_size, self.exogenous_size])
# Create input features.
+ activation_components = []
if self._periods:
_, time_features = self._compute_time_features(times)
activation_size = self.window_size * self._buckets * len(self._periods)
- activation = array_ops.reshape(time_features, [-1, activation_size])
+ activation_components.append(
+ array_ops.reshape(time_features, [-1, activation_size]))
else:
activation_size = 0
- activation = None
-
if self.input_window_size:
inp = array_ops.slice(values, [0, 0, 0], [-1, self.input_window_size, -1])
inp_size = self.input_window_size * self.num_features
inp = array_ops.reshape(inp, [-1, inp_size])
- if activation is not None:
- activation = array_ops.concat([inp, activation], 1)
- else:
- activation = inp
+ activation_components.append(inp)
activation_size += inp_size
+ if self.exogenous_size:
+ exogenous_size = self.window_size * self.exogenous_size
+ activation_size += exogenous_size
+ exogenous_flattened = array_ops.reshape(
+ exogenous_regressors, [-1, exogenous_size])
+ activation_components.append(exogenous_flattened)
assert activation_size
+ assert activation_components
+ activation = array_ops.concat(activation_components, axis=1)
activations.append((activation, activation_size))
# Create hidden layers.
activations += self._create_hidden_stack(activation, activation_size)
@@ -228,6 +252,19 @@ class ARModel(model.TimeSeriesModel):
math_ops.reduce_prod(array_ops.shape(targets)), loss_op.dtype)
return loss_op
+ def _process_exogenous_features(self, times, features):
+ embedded = super(ARModel, self)._process_exogenous_features(
+ times=times, features=features)
+ if embedded is None:
+ assert self.exogenous_size == 0
+ # No embeddings. Return a zero-size [batch, times, 0] array so we don't
+ # have to special case it downstream.
+ return array_ops.zeros(
+ array_ops.concat([array_ops.shape(times), constant_op.constant([0])],
+ axis=0))
+ else:
+ return embedded
+
# TODO(allenl, agarwal): Consider better ways of warm-starting predictions.
def predict(self, features):
"""Computes predictions multiple steps into the future.
@@ -243,6 +280,7 @@ class ARModel(model.TimeSeriesModel):
segment of the time series before `TIMES`. This data is used
to start of the autoregressive computation. This should have data for
at least self.input_window_size timesteps.
+        Any exogenous features should also be included, with shapes prefixed
+        by the shape of `TIMES`.
Returns:
A dictionary with keys, "mean", "covariance". The
values are Tensors of shape [batch_size, predict window size,
@@ -250,25 +288,39 @@ class ARModel(model.TimeSeriesModel):
"""
predict_times = math_ops.cast(
ops.convert_to_tensor(features[PredictionFeatures.TIMES]), dtypes.int32)
+ exogenous_regressors = self._process_exogenous_features(
+ times=predict_times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
+ with ops.control_dependencies(
+ [check_ops.assert_equal(array_ops.shape(predict_times)[1],
+ array_ops.shape(exogenous_regressors)[1])]):
+ exogenous_regressors = array_ops.identity(exogenous_regressors)
batch_size = array_ops.shape(predict_times)[0]
num_predict_values = array_ops.shape(predict_times)[1]
prediction_iterations = ((num_predict_values + self.output_window_size - 1)
// self.output_window_size)
- # Pad predict_times so as to have exact multiple of self.output_window_size
- # values per example.
+ # Pad predict_times and exogenous regressors so as to have exact multiple of
+ # self.output_window_size values per example.
padding_size = (prediction_iterations * self.output_window_size -
num_predict_values)
- padding = array_ops.zeros([batch_size, padding_size], predict_times.dtype)
- predict_times = control_flow_ops.cond(
- padding_size > 0, lambda: array_ops.concat([predict_times, padding], 1),
- lambda: predict_times)
+ predict_times = array_ops.pad(
+ predict_times, [[0, 0], [0, padding_size]])
+ exogenous_regressors = array_ops.pad(
+ exogenous_regressors, [[0, 0], [0, padding_size], [0, 0]])
state = features[PredictionFeatures.STATE_TUPLE]
- (state_times, state_values) = state
+ (state_times, state_values, state_exogenous_regressors) = state
state_times = math_ops.cast(
ops.convert_to_tensor(state_times), dtypes.int32)
state_values = ops.convert_to_tensor(state_values, dtype=self.dtype)
+ state_exogenous_regressors = ops.convert_to_tensor(
+ state_exogenous_regressors, dtype=self.dtype)
initial_input_times = predict_times[:, :self.output_window_size]
+ initial_input_exogenous_regressors = (
+ exogenous_regressors[:, :self.output_window_size, :])
if self.input_window_size > 0:
initial_input_times = array_ops.concat(
[state_times[:, -self.input_window_size:], initial_input_times], 1)
@@ -279,6 +331,11 @@ class ARModel(model.TimeSeriesModel):
check_ops.assert_equal(values_size, times_size)
]):
initial_input_values = state_values[:, -self.input_window_size:, :]
+ initial_input_exogenous_regressors = array_ops.concat(
+ [state_exogenous_regressors[:, -self.input_window_size:, :],
+ initial_input_exogenous_regressors[
+ :, :self.output_window_size, :]],
+ axis=1)
else:
initial_input_values = 0
@@ -288,9 +345,10 @@ class ARModel(model.TimeSeriesModel):
return math_ops.less(iteration_number, prediction_iterations)
def _while_body(iteration_number, input_times, input_values,
- mean_ta, covariance_ta):
+ input_exogenous_regressors, mean_ta, covariance_ta):
"""Predict self.output_window_size values."""
- prediction_ops = self.prediction_ops(input_times, input_values)
+ prediction_ops = self.prediction_ops(
+ input_times, input_values, input_exogenous_regressors)
predicted_mean = prediction_ops["mean"]
predicted_covariance = prediction_ops["covariance"]
offset = self.output_window_size * gen_math_ops.minimum(
@@ -299,20 +357,33 @@ class ARModel(model.TimeSeriesModel):
if self.output_window_size < self.input_window_size:
new_input_values = array_ops.concat(
[input_values[:, self.output_window_size:, :], predicted_mean], 1)
+ new_input_exogenous_regressors = array_ops.concat(
+ [input_exogenous_regressors[:, -self.input_window_size:, :],
+ exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]],
+ axis=1)
new_input_times = array_ops.concat([
- input_times[:, self.output_window_size:],
+ input_times[:, -self.input_window_size:],
predict_times[:, offset:offset + self.output_window_size]
], 1)
else:
new_input_values = predicted_mean[:, -self.input_window_size:, :]
+ new_input_exogenous_regressors = exogenous_regressors[
+ :,
+ offset - self.input_window_size:offset + self.output_window_size,
+ :]
new_input_times = predict_times[
:,
offset - self.input_window_size:offset + self.output_window_size]
else:
new_input_values = input_values
+ new_input_exogenous_regressors = exogenous_regressors[
+ :, offset:offset + self.output_window_size, :]
new_input_times = predict_times[:,
offset:offset + self.output_window_size]
new_input_times.set_shape(initial_input_times.get_shape())
+ new_input_exogenous_regressors.set_shape(
+ initial_input_exogenous_regressors.get_shape())
new_mean_ta = mean_ta.write(iteration_number, predicted_mean)
if isinstance(covariance_ta, tensor_array_ops.TensorArray):
new_covariance_ta = covariance_ta.write(iteration_number,
@@ -322,6 +393,7 @@ class ARModel(model.TimeSeriesModel):
return (iteration_number + 1,
new_input_times,
new_input_values,
+ new_input_exogenous_regressors,
new_mean_ta,
new_covariance_ta)
@@ -332,9 +404,13 @@ class ARModel(model.TimeSeriesModel):
if self.loss != ARModel.SQUARED_LOSS else 0.)
mean_ta_init = tensor_array_ops.TensorArray(
dtype=self.dtype, size=prediction_iterations)
- _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
+ _, _, _, _, mean_ta, covariance_ta = control_flow_ops.while_loop(
_while_condition, _while_body, [
- 0, initial_input_times, initial_input_values, mean_ta_init,
+ 0,
+ initial_input_times,
+ initial_input_values,
+ initial_input_exogenous_regressors,
+ mean_ta_init,
covariance_ta_init
])
@@ -366,11 +442,11 @@ class ARModel(model.TimeSeriesModel):
return {"mean": predicted_mean,
"covariance": predicted_covariance}
- def _process_window(self, features, mode):
+ def _process_window(self, features, mode, exogenous_regressors):
"""Compute model outputs on a single window of data."""
- # TODO(agarwal): Use exogenous features
times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtypes.int64)
values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+ exogenous_regressors = math_ops.cast(exogenous_regressors, dtype=self.dtype)
original_values = values
# Extra shape checking for the window size (above that in
@@ -395,7 +471,8 @@ class ARModel(model.TimeSeriesModel):
input_values = values[:, :self.input_window_size, :]
else:
input_values = None
- prediction_ops = self.prediction_ops(times, input_values)
+ prediction_ops = self.prediction_ops(
+ times, input_values, exogenous_regressors)
prediction = prediction_ops["mean"]
covariance = prediction_ops["covariance"]
targets = array_ops.slice(values, [0, self.input_window_size, 0],
@@ -419,7 +496,8 @@ class ARModel(model.TimeSeriesModel):
return model.ModelOutputs(
loss=loss,
end_state=(times[:, -self.input_window_size:],
- values[:, -self.input_window_size:, :]),
+ values[:, -self.input_window_size:, :],
+ exogenous_regressors[:, -self.input_window_size:, :]),
predictions={"mean": prediction, "covariance": covariance,
"observed": original_values[:, -self.output_window_size:]},
prediction_times=times[:, -self.output_window_size:])
@@ -454,17 +532,24 @@ class ARModel(model.TimeSeriesModel):
"""
features = {feature_name: ops.convert_to_tensor(feature_value)
for feature_name, feature_value in features.items()}
+ times = features[TrainEvalFeatures.TIMES]
+ exogenous_regressors = self._process_exogenous_features(
+ times=times,
+ features={key: value for key, value in features.items()
+ if key not in [TrainEvalFeatures.TIMES,
+ TrainEvalFeatures.VALUES,
+ PredictionFeatures.STATE_TUPLE]})
if mode == estimator_lib.ModeKeys.TRAIN:
# For training, we require the window size to be self.window_size as
# iterating sequentially on larger windows could introduce a bias.
- return self._process_window(features, mode=mode)
+ return self._process_window(
+ features, mode=mode, exogenous_regressors=exogenous_regressors)
elif mode == estimator_lib.ModeKeys.EVAL:
# For evaluation, we allow the user to pass in a larger window, in which
# case we try to cover as much of the window as possible without
# overlap. Quantitative evaluation is more efficient/correct with fixed
# windows matching self.window_size (as with training), but this looping
# allows easy plotting of "in-sample" predictions.
- times = features[TrainEvalFeatures.TIMES]
times.get_shape().assert_has_rank(2)
static_window_size = times.get_shape()[1].value
if (static_window_size is not None
@@ -500,7 +585,9 @@ class ARModel(model.TimeSeriesModel):
feature_name:
feature_value[:, base_offset:base_offset + self.window_size]
for feature_name, feature_value in features.items()},
- mode=mode)
+ mode=mode,
+ exogenous_regressors=exogenous_regressors[
+ :, base_offset:base_offset + self.window_size])
# This code needs to be updated if new predictions are added in
# self._process_window
assert len(model_outputs.predictions) == 3
@@ -525,7 +612,9 @@ class ARModel(model.TimeSeriesModel):
batch_size = array_ops.shape(times)[0]
prediction_shape = [batch_size, self.output_window_size * num_iterations,
self.num_features]
- previous_state_times, previous_state_values = state
+ (previous_state_times,
+ previous_state_values,
+ previous_state_exogenous_regressors) = state
# Make sure returned state always has windows of self.input_window_size,
# even if we were passed fewer than self.input_window_size points this
# time.
@@ -540,14 +629,24 @@ class ARModel(model.TimeSeriesModel):
self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
new_state_values.set_shape((None, self.input_window_size,
self.num_features))
+ new_exogenous_regressors = array_ops.concat(
+ [previous_state_exogenous_regressors,
+ exogenous_regressors], axis=1)[:, -self.input_window_size:, :]
+ new_exogenous_regressors.set_shape(
+ (None,
+ self.input_window_size,
+ self.exogenous_size))
else:
# There is no state to keep, and the strided slices above do not handle
# input_window_size=0.
new_state_times = previous_state_times
new_state_values = previous_state_values
+ new_exogenous_regressors = previous_state_exogenous_regressors
return model.ModelOutputs(
loss=math_ops.reduce_mean(loss_ta.stack(), axis=0),
- end_state=(new_state_times, new_state_values),
+ end_state=(new_state_times,
+ new_state_values,
+ new_exogenous_regressors),
predictions={
"mean": array_ops.reshape(
array_ops.transpose(mean_ta.stack(), [1, 0, 2, 3]),
@@ -604,7 +703,8 @@ class AnomalyMixtureARModel(ARModel):
num_features,
anomaly_distribution=GAUSSIAN_ANOMALY,
num_time_buckets=10,
- hidden_layer_sizes=None):
+ hidden_layer_sizes=None,
+ exogenous_feature_columns=None):
assert (anomaly_prior_probability < 1.0 and
anomaly_prior_probability > 0.0)
self._anomaly_prior_probability = anomaly_prior_probability
@@ -619,7 +719,8 @@ class AnomalyMixtureARModel(ARModel):
input_window_size=input_window_size,
output_window_size=output_window_size,
loss=ARModel.NORMAL_LIKELIHOOD_LOSS,
- hidden_layer_sizes=hidden_layer_sizes)
+ hidden_layer_sizes=hidden_layer_sizes,
+ exogenous_feature_columns=exogenous_feature_columns)
def _create_anomaly_ops(self, times, values, prediction_ops_dict):
anomaly_log_param = variable_scope.get_variable(
@@ -631,9 +732,9 @@ class AnomalyMixtureARModel(ARModel):
# distribution.
prediction_ops_dict["anomaly_params"] = gen_math_ops.exp(anomaly_log_param)
- def prediction_ops(self, times, values):
+ def prediction_ops(self, times, values, exogenous_regressors):
prediction_ops_dict = super(AnomalyMixtureARModel, self).prediction_ops(
- times, values)
+ times, values, exogenous_regressors)
self._create_anomaly_ops(times, values, prediction_ops_dict)
return prediction_ops_dict
diff --git a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
index 1e1ca4e77f..d078ac8d46 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/ar_model_test.py
@@ -155,12 +155,15 @@ class ARModelTest(test.TestCase):
state_times = np.expand_dims(train_data_times[:input_window_size], 0)
state_values = np.expand_dims(
train_data_values[:input_window_size, :], 0)
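+    # A width-zero [1, input_window_size, 0] block of exogenous features,
+    # since this test's model has no exogenous feature columns configured.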
+ state_exogenous = state_times[:, :, None][:, :, :0]
def prediction_input_fn():
return ({
PredictionFeatures.TIMES: training.limit_epochs(
predict_times, num_epochs=1),
- PredictionFeatures.STATE_TUPLE: (state_times, state_values)
+ PredictionFeatures.STATE_TUPLE: (state_times,
+ state_values,
+ state_exogenous)
}, {})
(predictions,) = tuple(estimator.predict(input_fn=prediction_input_fn))
predicted_mean = predictions["mean"][:, 0]
@@ -246,7 +249,8 @@ class ARModelTest(test.TestCase):
with session.Session():
predicted_values = model.predict({
PredictionFeatures.TIMES: [[4, 6, 10]],
- PredictionFeatures.STATE_TUPLE: ([[1, 2]], [[[1.], [2.]]])
+ PredictionFeatures.STATE_TUPLE: (
+ [[1, 2]], [[[1.], [2.]]], [[[], []]])
})
variables.global_variables_initializer().run()
self.assertAllEqual(predicted_values["mean"].eval().shape,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators.py b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
index 886e1846e2..f4608ca2d1 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators.py
@@ -190,7 +190,7 @@ class ARRegressor(TimeSeriesRegressor):
def __init__(
self, periodicities, input_window_size, output_window_size,
- num_features, num_time_buckets=10,
+ num_features, exogenous_feature_columns=None, num_time_buckets=10,
loss=ar_model.ARModel.NORMAL_LIKELIHOOD_LOSS, hidden_layer_sizes=None,
anomaly_prior_probability=None, anomaly_distribution=None,
optimizer=None, model_dir=None, config=None):
@@ -205,7 +205,12 @@ class ARRegressor(TimeSeriesRegressor):
output_window_size: Number of future time steps to predict. Note that
setting it to > 1 empirically seems to give a better fit.
num_features: The dimensionality of the time series (one for univariate,
- more than one for multivariate).
+ more than one for multivariate).
+ exogenous_feature_columns: A list of `tf.feature_column`s (for example
+ `tf.feature_column.embedding_column`) corresponding to exogenous
+ features which provide extra information to the model but are not part
+ of the series to be predicted. Passed to
+ `tf.feature_column.input_layer`.
num_time_buckets: Number of buckets into which to divide (time %
periodicity) for generating time based features.
loss: Loss function to use for training. Currently supported values are
@@ -241,6 +246,7 @@ class ARRegressor(TimeSeriesRegressor):
anomaly_distribution = ar_model.AnomalyMixtureARModel.GAUSSIAN_ANOMALY
model = ar_model.ARModel(
periodicities=periodicities, num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
input_window_size=input_window_size,
output_window_size=output_window_size, loss=loss,
@@ -255,6 +261,7 @@ class ARRegressor(TimeSeriesRegressor):
input_window_size=input_window_size,
output_window_size=output_window_size,
num_features=num_features,
+ exogenous_feature_columns=exogenous_feature_columns,
num_time_buckets=num_time_buckets,
hidden_layer_sizes=hidden_layer_sizes,
anomaly_prior_probability=anomaly_prior_probability,
diff --git a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
index 9f161c1695..eebee053f8 100644
--- a/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
+++ b/tensorflow/contrib/timeseries/python/timeseries/estimators_test.py
@@ -29,6 +29,7 @@ from tensorflow.contrib.timeseries.python.timeseries import saved_model_utils
from tensorflow.python.client import session
from tensorflow.python.estimator import estimator_lib
+from tensorflow.python.feature_column import feature_column
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
@@ -48,12 +49,17 @@ class TimeSeriesRegressorTest(test.TestCase):
def _fit_restore_fit_test_template(self, estimator_fn, dtype):
"""Tests restoring previously fit models."""
model_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
- first_estimator = estimator_fn(model_dir)
+ exogenous_feature_columns = (
+ feature_column.numeric_column("exogenous"),
+ )
+ first_estimator = estimator_fn(model_dir, exogenous_feature_columns)
times = numpy.arange(20, dtype=numpy.int64)
values = numpy.arange(20, dtype=dtype.as_numpy_dtype)
+ exogenous = numpy.arange(20, dtype=dtype.as_numpy_dtype)
features = {
feature_keys.TrainEvalFeatures.TIMES: times,
- feature_keys.TrainEvalFeatures.VALUES: values
+ feature_keys.TrainEvalFeatures.VALUES: values,
+ "exogenous": exogenous
}
train_input_fn = input_pipeline.RandomWindowInputFn(
input_pipeline.NumpyReader(features), shuffle_seed=2, num_threads=1,
@@ -68,14 +74,19 @@ class TimeSeriesRegressorTest(test.TestCase):
first_loss_after_fit = first_estimator.evaluate(
input_fn=eval_input_fn, steps=1)["loss"]
self.assertLess(first_loss_after_fit, first_loss_before_fit)
- second_estimator = estimator_fn(model_dir)
+ second_estimator = estimator_fn(model_dir, exogenous_feature_columns)
second_estimator.train(input_fn=train_input_fn, steps=2)
whole_dataset_input_fn = input_pipeline.WholeDatasetInputFn(
input_pipeline.NumpyReader(features))
whole_dataset_evaluation = second_estimator.evaluate(
input_fn=whole_dataset_input_fn, steps=1)
+ exogenous_values_ten_steps = {
+ "exogenous": numpy.arange(
+ 10, dtype=dtype.as_numpy_dtype)[None, :, None]
+ }
predict_input_fn = input_pipeline.predict_continuation_input_fn(
evaluation=whole_dataset_evaluation,
+ exogenous_features=exogenous_values_ten_steps,
steps=10)
# Also tests that limit_epochs in predict_continuation_input_fn prevents
# infinite iteration
@@ -92,6 +103,7 @@ class TimeSeriesRegressorTest(test.TestCase):
saved_prediction = saved_model_utils.predict_continuation(
continue_from=whole_dataset_evaluation,
steps=10,
+ exogenous_features=exogenous_values_ten_steps,
signatures=signatures,
session=sess)
# Saved model predictions should be the same as Estimator predictions
@@ -104,7 +116,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=whole_dataset_evaluation,
features={
feature_keys.FilteringFeatures.TIMES: times[None, -1] + 2,
- feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.
+ feature_keys.FilteringFeatures.VALUES: values[None, -1] + 2.,
+ "exogenous": values[None, -1, None] + 12.
},
signatures=signatures,
session=sess)
@@ -112,6 +125,10 @@ class TimeSeriesRegressorTest(test.TestCase):
second_saved_prediction = saved_model_utils.predict_continuation(
continue_from=first_filtering,
steps=1,
+ exogenous_features={
+ "exogenous": numpy.arange(
+ 1, dtype=dtype.as_numpy_dtype)[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertEqual(
@@ -122,7 +139,8 @@ class TimeSeriesRegressorTest(test.TestCase):
continue_from=first_filtering,
features={
feature_keys.FilteringFeatures.TIMES: times[-1] + 3,
- feature_keys.FilteringFeatures.VALUES: values[-1] + 3.
+ feature_keys.FilteringFeatures.VALUES: values[-1] + 3.,
+ "exogenous": values[-1, None] + 13.
},
signatures=signatures,
session=sess)
@@ -131,7 +149,8 @@ class TimeSeriesRegressorTest(test.TestCase):
six.assertCountEqual(
self,
[feature_keys.FilteringFeatures.TIMES,
- feature_keys.FilteringFeatures.VALUES],
+ feature_keys.FilteringFeatures.VALUES,
+ "exogenous"],
signatures.signature_def[
feature_keys.SavedModelLabels.COLD_START_FILTER].inputs.keys())
batch_numpy_times = numpy.tile(
@@ -142,7 +161,8 @@ class TimeSeriesRegressorTest(test.TestCase):
session=sess,
features={
feature_keys.FilteringFeatures.TIMES: batch_numpy_times,
- feature_keys.FilteringFeatures.VALUES: batch_numpy_values
+ feature_keys.FilteringFeatures.VALUES: batch_numpy_values,
+ "exogenous": 10. + batch_numpy_values
}
)
predict_times = numpy.tile(
@@ -150,26 +170,32 @@ class TimeSeriesRegressorTest(test.TestCase):
predictions = saved_model_utils.predict_continuation(
continue_from=state,
times=predict_times,
+ exogenous_features={
+ "exogenous": numpy.tile(numpy.arange(
+ 15, dtype=dtype.as_numpy_dtype), (10,))[None, :, None]
+ },
signatures=signatures,
session=sess)
self.assertAllEqual([10, 15, 1], predictions["mean"].shape)
def test_fit_restore_fit_ar_regressor(self):
- def _estimator_fn(model_dir):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.ARRegressor(
periodicities=10, input_window_size=10, output_window_size=6,
num_features=1, model_dir=model_dir, config=_SeedRunConfig(),
# This test is flaky with normal likelihood loss (could add more
# training iterations instead).
- loss=ar_model.ARModel.SQUARED_LOSS)
+ loss=ar_model.ARModel.SQUARED_LOSS,
+ exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtypes.float32)
def test_fit_restore_fit_structural_ensemble_regressor(self):
dtype = dtypes.float32
- def _estimator_fn(model_dir):
+ def _estimator_fn(model_dir, exogenous_feature_columns):
return estimators.StructuralEnsembleRegressor(
num_features=1, periodicities=10, model_dir=model_dir, dtype=dtype,
- config=_SeedRunConfig())
+ config=_SeedRunConfig(),
+ exogenous_feature_columns=exogenous_feature_columns)
self._fit_restore_fit_test_template(_estimator_fn, dtype=dtype)
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index 9646d15486..eac210418b 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -162,6 +162,7 @@ py_library(
"python/tpu/__init__.py",
"python/tpu/bfloat16.py",
"python/tpu/device_assignment.py",
+ "python/tpu/keras_support.py",
"python/tpu/topology.py",
"python/tpu/tpu.py",
"python/tpu/tpu_feed.py",
diff --git a/tensorflow/contrib/tpu/python/tpu/keras_support.py b/tensorflow/contrib/tpu/python/tpu/keras_support.py
new file mode 100644
index 0000000000..e86ca0a1d8
--- /dev/null
+++ b/tensorflow/contrib/tpu/python/tpu/keras_support.py
@@ -0,0 +1,391 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""*Experimental* support for running Keras models on the TPU.
+
+To use, wrap your model with the `keras_support.tpu_model` function.
+
+Example usage:
+
+```
+# Must activate before building TPU models
+keras_support.setup_tpu_session(master_address)
+
+image = tf.keras.layers.Input(shape=(28, 28, 3), name='image')
+c1 = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))(image)
+flattened = tf.keras.layers.Flatten()(c1)
+logits = tf.keras.layers.Dense(10, activation='softmax')(flattened)
+model = tf.keras.Model(inputs=[image], outputs=[logits])
+model = keras_support.tpu_model(model)
+
+# Only TF optimizers are currently supported.
+model.compile(optimizer=tf.train.AdamOptimizer(), ...)
+
+# `images` and `labels` should be Numpy arrays. Support for tensor input
+# (e.g. datasets) is planned.
+model.fit(images, labels)
+
+# Invoke before shutting down
+keras_support.shutdown_tpu_session()
+```
+"""
+
+# pylint: disable=protected-access
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+import re
+
+from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.tpu.python.ops import tpu_ops
+from tensorflow.contrib.tpu.python.tpu import tpu
+from tensorflow.core.protobuf import config_pb2
+from tensorflow.python.client import session as tf_session
+from tensorflow.python.estimator import model_fn as model_fn_lib
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_spec
+from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import layers
+from tensorflow.python.keras._impl.keras import models
+from tensorflow.python.keras._impl.keras import optimizers as keras_optimizers
+from tensorflow.python.keras._impl.keras.layers import embeddings
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import tf_logging as logging
+from tensorflow.python.training import training_util
+
+
+class TPUEmbedding(embeddings.Embedding):
+ """TPU compatible embedding layer.
+
+ The default Keras layer is not TPU compatible. This layer is a drop-in
+ replacement: it has the same behavior and will work on CPU and GPU devices.
+ """
+
+ def __init__(self, *args, **kw):
+ super(TPUEmbedding, self).__init__(*args, **kw)
+
+ def build(self, input_shape):
+ if input_shape[0] is None:
+ raise ValueError(
+ 'TPUEmbeddings must have a fixed input_length or input shape.')
+ return super(TPUEmbedding, self).build(input_shape)
+
+ def call(self, inputs):
+ if K.dtype(inputs) != 'int32':
+ inputs = math_ops.cast(inputs, 'int32')
+
+ inputs = array_ops.one_hot(inputs, self.input_dim)
+ return math_ops.tensordot(inputs, self.embeddings, 1)
+
+
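The `TPUEmbedding` layer above replaces the usual gather-based lookup with a one-hot matmul, which lowers more cleanly to XLA. A minimal NumPy sketch (editorial illustration, not part of this change) of why the two formulations produce the same result:

```
import numpy as np

vocab_size, embedding_dim = 5, 3
table = np.arange(vocab_size * embedding_dim, dtype=np.float32).reshape(
    vocab_size, embedding_dim)
ids = np.array([[1, 3], [0, 2]], dtype=np.int32)

# Gather-based lookup (what embeddings.Embedding normally does).
gathered = table[ids]                                # [2, 2, embedding_dim]

# One-hot + tensordot, as in TPUEmbedding.call above.
one_hot = np.eye(vocab_size, dtype=np.float32)[ids]  # [2, 2, vocab_size]
via_matmul = np.tensordot(one_hot, table, axes=1)    # [2, 2, embedding_dim]

assert np.array_equal(gathered, via_matmul)
```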
+class CompiledTPUOp(
+ collections.namedtuple(
+ 'CompiledTPUOp',
+ ['tpu_execute_op', 'infeed_tensors', 'infeed_op', 'outfeed_op'])):
+ pass
+
+
+def _valid_name(tensor_name):
+ """Return a valid tensor name (strips '/', ':', etc)."""
+ return re.sub('[^a-zA-Z0-9_-]+', '', tensor_name)
+
+
+class TPUFunction(object):
+ """K.function compatible interface for invoking a TPU compiled function.
+
+ Recompilation is triggered on-demand for each set of new input shapes: the
+ results are cached for future execution. We expect most computations will
+ be dominated by a standard batch size, followed by a straggler batch for
+ the end of training or evaluation.
+
+ All `inputs` and `outputs` will be loaded via the infeed and outfeed queues
+ instead of being injected as `feed_dict` items or fetches.
+ """
+
+ def __init__(self, model, execution_mode):
+ self.model = model
+ self.execution_mode = execution_mode
+ self._compilation_cache = {}
+
+ def _specialize_model(self, input_specs):
+ """Specialize `self.model` (a Keras model) for the given input shapes."""
+ # Re-create our input and output layers inside our subgraph. They will be
+ # attached to the true computation when we clone our model in `tpu_fn`.
+ K.set_learning_phase(self.execution_mode == model_fn_lib.ModeKeys.TRAIN)
+
+ # functools.partial and callable objects are not supported by tpu.rewrite
+ def _model_fn():
+ """Compute fit/eval/predict for the TPU."""
+ is_training = self.execution_mode == model_fn_lib.ModeKeys.TRAIN
+ is_test = self.execution_mode == model_fn_lib.ModeKeys.EVAL
+ is_predict = self.execution_mode == model_fn_lib.ModeKeys.PREDICT
+
+ # During train/eval, we infeed our features as well as labels.
+ if is_training or is_test:
+ infeed_layers = self.model._input_layers + self.model._output_layers
+ else:
+ infeed_layers = self.model._input_layers
+
+ # Generate our infeed operation to read features & labels.
+ infeed_tensors = tpu_ops.infeed_dequeue_tuple(
+ dtypes=[spec.dtype for spec in input_specs],
+ shapes=[spec.shape for spec in input_specs],
+ name='infeed-%s' % self.execution_mode)
+
+ assert len(infeed_tensors) == len(infeed_layers), (
+ 'Infeed inputs did not match model: %s vs %s' % (infeed_layers,
+ infeed_tensors))
+
+ tpu_targets = []
+ tpu_inputs = []
+
+ # Sort infeed outputs into inputs and labels for calling our Keras model.
+ for tensor, layer in zip(infeed_tensors, infeed_layers):
+ if layer in self.model._input_layers:
+ tpu_inputs.append(layers.Input(name=layer.name, tensor=tensor))
+ if layer in self.model._output_layers:
+ tpu_targets.append(tensor)
+
+ optimizer = self.model.optimizer
+ optimizer.iterations = training_util.get_or_create_global_step()
+
+ # Call our model with our infeed inputs (re-using the weights).
+ model_outputs = self.model(tpu_inputs)
+ child_model = models.Model(inputs=tpu_inputs, outputs=model_outputs)
+ if is_training or is_test:
+ child_model.compile(
+ optimizer=self.model.optimizer,
+ loss=self.model.loss,
+ loss_weights=self.model.loss_weights,
+ metrics=self.model.metrics,
+ weighted_metrics=self.model.weighted_metrics,
+ target_tensors=tpu_targets,
+ )
+
+ # Compute our outfeed depending on the execution mode
+ if is_training:
+ child_model._make_train_function()
+ self._outfeed_spec = [
+ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
+ for tensor in child_model.train_function.outputs
+ ]
+ return [
+ child_model.train_function.updates_op,
+ tpu_ops.outfeed_enqueue_tuple(
+ child_model.train_function.outputs, name='outfeed-enqueue-train')
+ ]
+ elif is_test:
+ child_model._make_test_function()
+ self._outfeed_spec = [
+ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
+ for tensor in child_model.test_function.outputs
+ ]
+ return [
+ tpu_ops.outfeed_enqueue_tuple(
+ child_model.test_function.outputs, name='outfeed-enqueue-test')
+ ]
+ elif is_predict:
+ child_model._make_predict_function()
+ self._outfeed_spec = [
+ tensor_spec.TensorSpec(tensor.shape, tensor.dtype, tensor.name)
+ for tensor in child_model.predict_function.outputs
+ ]
+ return [
+ tpu_ops.outfeed_enqueue_tuple(
+ child_model.predict_function.outputs,
+ name='outfeed-enqueue-predict',
+ )
+ ]
+ else:
+ assert False, 'Unexpected execution mode: %s' % self.execution_mode
+
+ # Capture outfeed metadata computed during the rewrite.
+ self._outfeed_spec = None
+
+ tpu_execute_op = tpu.rewrite(_model_fn)
+
+ K._initialize_variables(K.get_session())  # pylint: disable=protected-access
+
+ # Generate CPU side operations to enqueue features/labels and dequeue
+ # outputs from the model call.
+ with ops.device('/device:TPU:0'):
+ infeed_tensors = []
+ for spec in input_specs:
+ infeed_tensors.append(
+ array_ops.placeholder(
+ dtype=spec.dtype,
+ shape=spec.shape,
+ name='infeed-enqueue-%s' % spec.name))
+
+ infeed_op = tpu_ops.infeed_enqueue_tuple(
+ infeed_tensors, [spec.shape for spec in input_specs],
+ name='infeed-enqueue-%s' % self.execution_mode)
+
+ outfeed_op = tpu_ops.outfeed_dequeue_tuple(
+ dtypes=[spec.dtype for spec in self._outfeed_spec],
+ shapes=[spec.shape for spec in self._outfeed_spec],
+ name='outfeed-dequeue-%s' % self.execution_mode)
+
+ return CompiledTPUOp(tpu_execute_op, infeed_tensors, infeed_op, outfeed_op)
+
+ def __call__(self, inputs):
+ assert isinstance(inputs, list)
+
+ # Strip sample weight from inputs
+ if (self.execution_mode == model_fn_lib.ModeKeys.TRAIN or
+ self.execution_mode == model_fn_lib.ModeKeys.EVAL):
+ input_tensors = self.model._feed_inputs + self.model._feed_targets
+ inputs = inputs[:len(input_tensors)]
+ else:
+ input_tensors = self.model._feed_inputs
+
+ # Compute an input specification (used to generate infeed enqueue and
+ # dequeue operations). We use the shape from our input array and the
+ # dtype from our model. A user may pass in a float64 for a float32
+ # input: for model compatibility we still must generate a float32 infeed.
+ input_specs = []
+ for tensor, ary in zip(input_tensors, inputs):
+ input_specs.append(
+ tensor_spec.TensorSpec(ary.shape, tensor.dtype,
+ _valid_name(tensor.name)))
+
+ # XLA requires every operation in the graph has a fixed shape. To
+ # handle varying batch sizes we recompile a new sub-graph for each
+ # unique input shape.
+ shape_key = tuple([tuple(spec.shape.as_list()) for spec in input_specs])
+
+ if shape_key not in self._compilation_cache:
+ logging.info('New input shapes; (re-)compiling: mode=%s, %s',
+ self.execution_mode, input_specs)
+ self._compilation_cache[shape_key] = self._specialize_model(input_specs)
+
+ compiled_model = self._compilation_cache[shape_key]
+
+ infeed_dict = {}
+ for tensor, value in zip(compiled_model.infeed_tensors, inputs):
+ infeed_dict[tensor] = value
+
+ session = K.get_session()
+ _, _, outfeed_outputs = session.run([
+ compiled_model.infeed_op, compiled_model.tpu_execute_op,
+ compiled_model.outfeed_op
+ ], infeed_dict)
+
+ return outfeed_outputs
+
+
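`TPUFunction.__call__` keys its compilation cache on the tuple of input shapes, so only genuinely new shapes (typically the smaller straggler batch at the end of an epoch) trigger a recompile. A stripped-down sketch of that caching pattern, with illustrative names rather than the class above:

```
class ShapeKeyedCache(object):
  """Caches one compiled artifact per unique tuple of input shapes."""

  def __init__(self, compile_fn):
    self._compile_fn = compile_fn
    self._cache = {}

  def __call__(self, arrays):
    key = tuple(tuple(a.shape) for a in arrays)
    if key not in self._cache:
      # Only pay the compilation cost for shapes we have not seen before.
      self._cache[key] = self._compile_fn(arrays)
    return self._cache[key]
```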
+@experimental
+def setup_tpu_session(master):
+ """Initializes and returns a Keras/TF session connected the TPU `master`."""
+ session = tf_session.Session(
+ target=master, config=config_pb2.ConfigProto(isolate_session_state=True))
+ K.set_session(session)
+ K.get_session().run(tpu.initialize_system())
+ K.manual_variable_initialization(True)
+ return session
+
+
+@experimental
+def shutdown_tpu_session(session=None):
+ """Shutdown the TPU attached to session.
+
+ This should be called to cleanly shut down the TPU system before the client
+ exits.
+
+ Args:
+ session: Session to shutdown, or None to use the default session.
+
+ Returns:
+
+ """
+ if session is None:
+ session = K.get_session()
+
+ session.run(tpu.shutdown_system())
+
+
+class KerasTPUModel(models.Model):
+ """TPU compatible Keras model wrapper."""
+
+ def __init__(self, inputs, outputs, name=None):
+ super(models.Model, self).__init__(
+ inputs=inputs,
+ outputs=outputs,
+ name=name,
+ )
+ self.predict_function = None
+ self.test_function = None
+ self.train_function = None
+
+ def compile(self,
+ optimizer,
+ loss=None,
+ metrics=None,
+ loss_weights=None,
+ sample_weight_mode=None,
+ weighted_metrics=None,
+ target_tensors=None,
+ **kwargs):
+ if sample_weight_mode:
+ raise ValueError('sample_weight_mode not supported for TPU execution.')
+ if weighted_metrics:
+ raise ValueError('weighted_metrics not supported for TPU execution.')
+ if target_tensors:
+ raise ValueError('target_tensors is not supported for TPU execution.')
+
+ super(KerasTPUModel, self).compile(optimizer, loss, metrics, loss_weights,
+ sample_weight_mode, weighted_metrics,
+ target_tensors, **kwargs)
+
+ # Keras optimizers are not compatible with TPU rewrite
+ if not isinstance(self.optimizer, keras_optimizers.TFOptimizer):
+ raise ValueError(
+ 'Optimizer must be a TFOptimizer, got: %s' % self.optimizer)
+
+ def train_on_batch(self, x, y, sample_weight=None, class_weight=None):
+ return super(KerasTPUModel, self).train_on_batch(x, y, sample_weight,
+ class_weight)
+
+ def _make_train_function(self):
+ if not self.train_function:
+ self.train_function = TPUFunction(self, model_fn_lib.ModeKeys.TRAIN)
+
+ return self.train_function
+
+ def _make_test_function(self):
+ if not self.test_function:
+ self.test_function = TPUFunction(self, model_fn_lib.ModeKeys.EVAL)
+ return self.test_function
+
+ def _make_predict_function(self):
+ if not self.predict_function:
+ self.predict_function = TPUFunction(self, model_fn_lib.ModeKeys.PREDICT)
+ return self.predict_function
+
+ def cpu_model(self):
+ return models.Model(
+ inputs=self.inputs,
+ outputs=self.outputs,
+ name=self.name,
+ )
+
+
+@experimental
+def tpu_model(model):
+ return KerasTPUModel(
+ inputs=model.inputs, outputs=model.outputs, name=model.name)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu.py b/tensorflow/contrib/tpu/python/tpu/tpu.py
index a1690dadff..7b8786304c 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu.py
@@ -173,36 +173,18 @@ class TPUReplicateContext(control_flow_ops.XLAControlFlowContext):
# gradients, and put the gradient of X in cluster
# 'root_cluster.gradient_uid'.
#
- # When the gradient code adds multiple Ops, it asks them to
- # be colocated either with the original Op X, or with one of
- # the preceding Ops that was added to the gradient. In other
- # words, we want to detect the case where we are colocating
- # with an Op that is in cluster root_cluster.gradient_uid
- # and put the new Op in that same cluster if the
- # gradient_uid is the same (the case that we are in the same
- # invocation of gradients, and just adding new Ops to the
- # cluster); and in a different cluster if the gradient_uids
- # are different (the case that we are in a new invocation of
- # gradients, taking the gradient of a previously-computed
- # gradient).
+ # When taking a gradient of a gradient, some ops will be
+ # colocated with an Op in the forward pass (e.g., cluster
+ # root_cluster) and some in the backward pass (e.g., cluster
+ # root_cluster.initial_gradient_uid). We need all of the
+ # grad-of-grad ops to be in the same cluster to avoid cyclic
+ # dependencies between clusters. We adopt a heuristic that
+ # puts any op clustered with root_cluster.<xxx> in
+ # root_cluster.gradient_uid, even if xxx was
+ # initial_gradient_uid.
self._in_gradient_colocation = op
parts = outside_attr.split(".")
- if len(parts) > 1:
- uid = parts[-1]
- if uid == gradient_uid:
- # Keep using the same cluster
- cluster = outside_attr
- else:
- # We're taking the gradient of a gradient so make a new
- # cluster attr, adding a new '.uid' on the end to
- # preserve the invariant that the gradient_uid is the
- # suffix after the last '.' in the attr.
- cluster = outside_attr + "." + gradient_uid
- else:
- # We're taking the gradient of an Op in the forward pass, so
- # make a new cluster combining the Op's cluster and the
- # gradient id.
- cluster = outside_attr + "." + gradient_uid
+ cluster = parts[0] + "." + gradient_uid
self._EnterOutsideCompilationScope(cluster=cluster)
except ValueError:
# The attr was not present: do nothing.
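The replacement code above collapses the old branching into a single rule: whatever suffix the colocated op's cluster attribute carried, the new op lands in `<root>.<gradient_uid>`. A small illustrative sketch of that string rule (the helper name is hypothetical, not TensorFlow API):

```
def _gradient_cluster(outside_attr, gradient_uid):
  """Cluster name used for ops added while taking a gradient."""
  root = outside_attr.split(".")[0]
  return root + "." + gradient_uid

# Forward-pass and grad-of-grad colocations map to the same cluster.
assert _gradient_cluster("root_cluster", "grad_1") == "root_cluster.grad_1"
assert _gradient_cluster("root_cluster.grad_0", "grad_1") == "root_cluster.grad_1"
```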
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index 24ea732877..acca47e9a3 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -145,6 +145,7 @@ load(
"if_static",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
+load("@io_bazel_rules_closure//closure:defs.bzl", "closure_proto_library")
load(
"//third_party/mkl:build_defs.bzl",
"if_mkl",
@@ -161,7 +162,7 @@ exports_files(["ops/ops.pbtxt"])
# Note that some protos are in neither additional_core_proto_srcs nor this
# filegroup; e.g. ones with individual proto_library targets.
# LINT.IfChange
-CORE_PROTO_SRCS = [
+COMMON_PROTO_SRCS = [
"example/example.proto",
"example/feature.proto",
"framework/allocation_description.proto",
@@ -189,7 +190,6 @@ CORE_PROTO_SRCS = [
"framework/types.proto",
"framework/variable.proto",
"framework/versions.proto",
- "lib/core/error_codes.proto",
"protobuf/config.proto",
"protobuf/cluster.proto",
"protobuf/debug.proto",
@@ -202,8 +202,14 @@ CORE_PROTO_SRCS = [
"util/memmapped_file_system.proto",
"util/saved_tensor_slice.proto",
]
+
+ERROR_CODES_PROTO_SRCS = [
+ "lib/core/error_codes.proto",
+]
# LINT.ThenChange(//tensorflow/core/android_proto_config.asciipb)
+CORE_PROTO_SRCS = COMMON_PROTO_SRCS + ERROR_CODES_PROTO_SRCS
+
# Protos which are not needed on mobile builds, but should be included in
# protos_all.
#
@@ -224,12 +230,16 @@ ADDITIONAL_CORE_PROTO_SRCS = [
tf_proto_library(
name = "protos_all",
- srcs = CORE_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
+ srcs = [],
cc_api_version = 2,
default_header = True,
j2objc_api_version = 1,
java_api_version = 2,
js_api_version = 2,
+ protodeps = [
+ ":protos_all_proto",
+ ":error_codes_proto",
+ ],
visibility = ["//visibility:public"],
)
@@ -257,12 +267,6 @@ proto_library(
visibility = ["//visibility:public"],
)
-closure_proto_library(
- name = "example_protos_closure",
- visibility = ["//visibility:public"],
- deps = [":example_protos"],
-)
-
exports_files([
"framework/types.proto",
])
@@ -287,7 +291,7 @@ PLATFORM_BASE_HDRS = [
"platform/logging.h",
"platform/macros.h",
"platform/types.h",
- "platform/cpu_info.h",
+ "platform/byte_order.h",
]
PLATFORM_OTHER_HDRS = [
@@ -295,6 +299,7 @@ PLATFORM_OTHER_HDRS = [
"platform/stacktrace.h",
"platform/stacktrace_handler.h",
"platform/context.h",
+ "platform/cpu_info.h",
"platform/cpu_feature_guard.h",
"platform/dynamic_annotations.h",
"platform/env.h",
@@ -323,7 +328,6 @@ cc_library(
srcs = glob([
"platform/*/integral_types.h",
"platform/*/logging.h",
- "platform/*/cpu_info.h",
]),
hdrs = PLATFORM_BASE_HDRS,
deps = [
@@ -542,6 +546,7 @@ tf_cuda_library(
"framework/device_base.h",
"framework/function.h",
"framework/graph_def_util.h",
+ "framework/graph_to_functiondef.h",
"framework/kernel_def_builder.h",
"framework/log_memory.h",
"framework/lookup_interface.h",
@@ -674,6 +679,7 @@ cc_library(
"framework/tensor_types.h",
"framework/type_traits.h",
"lib/bfloat16/bfloat16.h",
+ "platform/byte_order.h",
"platform/default/dynamic_annotations.h",
"platform/default/integral_types.h",
"platform/default/logging.h",
@@ -995,6 +1001,7 @@ cc_library(
"//tensorflow/core/kernels:nn",
"//tensorflow/core/kernels:parameterized_truncated_normal_op",
"//tensorflow/core/kernels:parsing",
+ "//tensorflow/core/kernels:partitioned_function_ops",
"//tensorflow/core/kernels:random_ops",
"//tensorflow/core/kernels:random_poisson_op",
"//tensorflow/core/kernels:remote_fused_graph_ops",
@@ -1139,7 +1146,8 @@ filegroup(
filegroup(
name = "mobile_srcs_no_runtime",
srcs = [
- ":proto_text_srcs_all",
+ ":protos_all_proto_text_srcs",
+ ":error_codes_proto_text_srcs",
"//tensorflow/core/platform/default/build_config:android_srcs",
] + glob(
[
@@ -1624,6 +1632,18 @@ tf_proto_library_cc(
],
)
+tf_proto_library_cc(
+ name = "eager_service_proto",
+ srcs = ["protobuf/eager_service.proto"],
+ has_services = 1,
+ cc_api_version = 2,
+ cc_stubby_versions = ["2"],
+ protodeps = tf_additional_all_protos(),
+ visibility = [
+ "//tensorflow:internal",
+ ],
+)
+
LIB_INTERNAL_PRIVATE_HEADERS = ["framework/resource_handle.h"] + glob(
[
"lib/**/*.h",
@@ -1919,6 +1939,7 @@ cc_library(
"lib/core/casts.h",
"lib/core/stringpiece.h",
"lib/png/png_io.h",
+ "platform/byte_order.h",
"platform/cpu_info.h",
"platform/default/integral_types.h",
"platform/default/logging.h",
@@ -1934,15 +1955,58 @@ cc_library(
],
)
-proto_text_hdrs_and_srcs = tf_generate_proto_text_sources(
- name = "proto_text_srcs_all",
- srcs = CORE_PROTO_SRCS,
+tf_proto_library(
+ name = "error_codes_proto",
+ srcs = ERROR_CODES_PROTO_SRCS,
+ cc_api_version = 2,
+ default_header = True,
+ j2objc_api_version = 1,
+ java_api_version = 2,
+ js_api_version = 2,
+)
+
+tf_generate_proto_text_sources(
+ name = "error_codes_proto_text",
+ srcs = ERROR_CODES_PROTO_SRCS,
+ protodeps = [],
srcs_relative_dir = "tensorflow/core/",
+ deps = [
+ ":error_codes_proto_cc",
+ ":lib_internal",
+ ],
+)
+
+tf_proto_library(
+ name = "protos_all_proto",
+ srcs = COMMON_PROTO_SRCS + ADDITIONAL_CORE_PROTO_SRCS,
+ cc_api_version = 2,
+ default_header = True,
+ j2objc_api_version = 1,
+ java_api_version = 2,
+ js_api_version = 2,
+ protodeps = [
+ ":error_codes_proto",
+ ],
+)
+
+tf_generate_proto_text_sources(
+ name = "protos_all_proto_text",
+ srcs = COMMON_PROTO_SRCS,
+ protodeps = ERROR_CODES_PROTO_SRCS,
+ srcs_relative_dir = "tensorflow/core/",
+ deps = [
+ ":error_codes_proto_text",
+ ":lib_internal",
+ ":protos_all_proto_cc",
+ ],
)
cc_library(
name = "proto_text",
- hdrs = proto_text_hdrs_and_srcs.hdrs,
+ hdrs = [
+ ":error_codes_proto_text_hdrs",
+ ":protos_all_proto_text_hdrs",
+ ],
deps = [
":lib",
":lib_internal",
@@ -2087,7 +2151,7 @@ tf_cuda_library(
"util/memmapped_file_system.cc",
"util/memmapped_file_system_writer.cc",
],
- }) + proto_text_hdrs_and_srcs.srcs + tf_additional_framework_srcs(),
+ }) + tf_additional_framework_srcs(),
hdrs = FRAMEWORK_INTERNAL_PUBLIC_HEADERS,
copts = tf_copts(),
linkopts = select({
@@ -2101,7 +2165,8 @@ tf_cuda_library(
deps = [
":lib",
":lib_internal",
- ":proto_text",
+ ":protos_all_proto_text",
+ ":error_codes_proto_text",
":protos_all_cc",
":version_lib",
"//tensorflow/core/platform/default/build_config:platformlib",
@@ -2513,6 +2578,19 @@ tf_cuda_library(
cc_library(
name = "gpu_id",
+ hdrs = [
+ "common_runtime/gpu/gpu_id.h",
+ "common_runtime/gpu/gpu_id_manager.h",
+ ],
+ deps = [
+ ":lib",
+ ] + if_static([
+ ":gpu_id_impl",
+ ]),
+)
+
+cc_library(
+ name = "gpu_id_impl",
srcs = ["common_runtime/gpu/gpu_id_manager.cc"],
hdrs = [
"common_runtime/gpu/gpu_id.h",
@@ -2562,7 +2640,7 @@ tf_cuda_library(
":core_cpu_lib",
":framework",
":framework_internal",
- ":gpu_id",
+ ":gpu_id_impl",
":gpu_init_impl",
":gpu_lib",
":graph",
@@ -2998,6 +3076,7 @@ tf_cc_tests(
"framework/common_shape_fns_test.cc",
"framework/function_test.cc",
"framework/graph_def_util_test.cc",
+ "framework/graph_to_functiondef_test.cc",
"framework/kernel_def_builder_test.cc",
"framework/memory_types_test.cc",
"framework/node_def_builder_test.cc",
@@ -3076,6 +3155,8 @@ tf_cc_tests(
":testlib",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/cc:function_ops",
+ "//tensorflow/cc:ops",
"//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops",
"//tensorflow/cc:while_loop",
@@ -4080,3 +4161,9 @@ alias(
actual = ":mobile_srcs",
visibility = ["//visibility:public"],
)
+
+closure_proto_library(
+ name = "example_protos_closure",
+ visibility = ["//visibility:public"],
+ deps = [":example_protos"],
+)
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
index daeb5fe9a2..461b498662 100644
--- a/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNN.pbtxt
@@ -7,30 +7,30 @@ buffer.
rnn_mode: Indicates the type of the RNN model.
input_mode: Indicate whether there is a linear projection between the input and
- The actual computation before the first layer. 'skip_input' is only allowed
+ the actual computation before the first layer. 'skip_input' is only allowed
when input_size == num_units; 'auto_select' implies 'skip_input' when
input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
The size must be created through CudnnRNNParamsSize, and initialized
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
-output: a 3-D tensor with the shape of [seq_length, batch_size,
+output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
-output_h: the same shape has input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+output_h: The same shape as input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
is_training: Indicates whether this operation is used for inferenece or
training.
-reserve_space: an opaque tensor that can be used in backprop calculation. It
+reserve_space: An opaque tensor that can be used in backprop calculation. It
is only produced if is_training is false.
END
}
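The shape conventions documented in this api_def (and repeated in the V2 variants below) can be summarized concretely. A short illustrative helper, not part of the TensorFlow API:

```
def cudnn_rnn_shapes(seq_length, batch_size, input_size, num_units,
                     num_layers, direction="unidirectional"):
  """Shapes of the main CudnnRNN tensors as documented above."""
  dir_count = 2 if direction == "bidirectional" else 1
  return {
      "input": (seq_length, batch_size, input_size),
      "input_h": (num_layers * dir_count, batch_size, num_units),
      "output": (seq_length, batch_size, dir_count * num_units),
  }

shapes = cudnn_rnn_shapes(20, 8, 128, 256, num_layers=2,
                          direction="bidirectional")
assert shapes["input_h"] == (4, 8, 256)   # num_layer * dir = 2 * 2
assert shapes["output"] == (20, 8, 512)   # dir * num_units = 2 * 256
```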
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
index 075ec52648..7cd5ae637b 100644
--- a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackprop.pbtxt
@@ -6,27 +6,27 @@ Compute the backprop of both data and weights in a RNN.
rnn_mode: Indicates the type of the RNN model.
input_mode: Indicate whether there is a linear projection between the input and
- The actual computation before the first layer. 'skip_input' is only allowed
+ the actual computation before the first layer. 'skip_input' is only allowed
when input_size == num_units; 'auto_select' implies 'skip_input' when
input_size == num_units; otherwise, it implies 'linear_input'.
-direction: Indicates whether a bidirectional model will be used.
- dir = (direction == bidirectional) ? 2 : 1
-dropout: dropout probability. When set to 0., dropout is disabled.
-seed: the 1st part of a seed to initialize dropout.
-seed2: the 2nd part of a seed to initialize dropout.
-input: a 3-D tensor with the shape of [seq_length, batch_size, input_size].
-input_h: a 3-D tensor with the shape of [num_layer * dir, batch_size,
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
num_units].
input_c: For LSTM, a 3-D tensor with the shape of
[num_layer * dir, batch, num_units]. For other models, it is ignored.
-params: a 1-D tensor that contains the weights and biases in an opaque layout.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
The size must be created through CudnnRNNParamsSize, and initialized
separately. Note that they might not be compatible across different
generations. So it is a good idea to save and restore
-output: a 3-D tensor with the shape of [seq_length, batch_size,
+output: A 3-D tensor with the shape of [seq_length, batch_size,
dir * num_units].
-output_h: the same shape has input_h.
-output_c: the same shape as input_c for LSTM. An empty tensor for other models.
+output_h: The same shape as input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
output_backprop: A 3-D tensor with the same shape as output in the forward pass.
output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
pass.
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt
new file mode 100644
index 0000000000..03aa9cc250
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNBackpropV2.pbtxt
@@ -0,0 +1,49 @@
+op {
+ graph_op_name: "CudnnRNNBackpropV2"
+ visibility: HIDDEN
+ summary: "Backprop step of CudnnRNN."
+ description: <<END
+Compute the backprop of both data and weights in an RNN. Takes an extra
+ "host_reserved" input compared to CudnnRNNBackprop, which is used to determine
+ the RNN cudnnRNNAlgo_t and cudnnMathType_t.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+ the actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: A 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: The same shape as input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
+output_backprop: A 3-D tensor with the same shape as output in the forward pass.
+output_h_backprop: A 3-D tensor with the same shape as output_h in the forward
+ pass.
+output_c_backprop: A 3-D tensor with the same shape as output_c in the forward
+ pass.
+reserve_space: The same reserve_space produced in the forward operation.
+host_reserved: The same host_reserved produced in the forward operation.
+input_backprop: The backprop to input in the forward pass. Has the same shape
+ as input.
+input_h_backprop: The backprop to input_h in the forward pass. Has the same
+ shape as input_h.
+input_c_backprop: The backprop to input_c in the forward pass. Has the same
+ shape as input_c.
+params_backprop: The backprop to the params buffer in the forward pass. Has the
+ same shape as params.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt
new file mode 100644
index 0000000000..c8a39de68c
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_CudnnRNNV2.pbtxt
@@ -0,0 +1,40 @@
+op {
+ graph_op_name: "CudnnRNNV2"
+ visibility: HIDDEN
+ summary: "A RNN backed by cuDNN."
+ description: <<END
+Computes the RNN from the input and initial states, with respect to the params
+buffer. Produces one extra output, "host_reserved", compared to CudnnRNN.
+
+rnn_mode: Indicates the type of the RNN model.
+input_mode: Indicates whether there is a linear projection between the input and
+ the actual computation before the first layer. 'skip_input' is only allowed
+ when input_size == num_units; 'auto_select' implies 'skip_input' when
+ input_size == num_units; otherwise, it implies 'linear_input'.
+direction: Indicates whether a bidirectional model will be used. Should be
+ "unidirectional" or "bidirectional".
+dropout: Dropout probability. When set to 0., dropout is disabled.
+seed: The 1st part of a seed to initialize dropout.
+seed2: The 2nd part of a seed to initialize dropout.
+input: A 3-D tensor with the shape of [seq_length, batch_size, input_size].
+input_h: A 3-D tensor with the shape of [num_layer * dir, batch_size,
+ num_units].
+input_c: For LSTM, a 3-D tensor with the shape of
+ [num_layer * dir, batch, num_units]. For other models, it is ignored.
+params: A 1-D tensor that contains the weights and biases in an opaque layout.
+ The size must be created through CudnnRNNParamsSize, and initialized
+ separately. Note that they might not be compatible across different
+ generations. So it is a good idea to save and restore
+output: A 3-D tensor with the shape of [seq_length, batch_size,
+ dir * num_units].
+output_h: The same shape as input_h.
+output_c: The same shape as input_c for LSTM. An empty tensor for other models.
+is_training: Indicates whether this operation is used for inference or
+ training.
+reserve_space: An opaque tensor that can be used in backprop calculation. It
+ is only produced if is_training is true.
+host_reserved: An opaque tensor that can be used in backprop calculation. It is
+ only produced if is_training is true. It is output on host memory rather than
+ device memory.
+END
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_PartitionedCall.pbtxt b/tensorflow/core/api_def/base_api/api_def_PartitionedCall.pbtxt
new file mode 100644
index 0000000000..caf8172a52
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_PartitionedCall.pbtxt
@@ -0,0 +1,23 @@
+op {
+ graph_op_name: "PartitionedCall"
+ in_arg {
+ name: "args"
+ description: "A list of input tensors."
+ }
+ out_arg {
+ name: "output"
+ description: "A list of return values."
+ }
+ attr { name: "Tin" description: "A list of input types." }
+ attr { name: "Tout" description: "A list of output types." }
+ attr {
+ name: "f"
+ description: <<END
+ A function that takes 'args', a list of tensors, and returns 'output',
+ another list of tensors. Input and output types are specified by 'Tin'
+ and 'Tout'. The function body of f will be placed and partitioned across
+ devices, setting this op apart from the regular Call op.
+END
+ }
+ summary: "returns `f(inputs)`, where `f`'s body is placed and partitioned."
+}
diff --git a/tensorflow/core/api_def/python_api/api_def_PartitionedCall.pbtxt b/tensorflow/core/api_def/python_api/api_def_PartitionedCall.pbtxt
new file mode 100644
index 0000000000..c443acd2e9
--- /dev/null
+++ b/tensorflow/core/api_def/python_api/api_def_PartitionedCall.pbtxt
@@ -0,0 +1 @@
+op { graph_op_name: "PartitionedCall" visibility: HIDDEN }
diff --git a/tensorflow/core/common_runtime/direct_session.cc b/tensorflow/core/common_runtime/direct_session.cc
index 0479061daf..0afbd02e86 100644
--- a/tensorflow/core/common_runtime/direct_session.cc
+++ b/tensorflow/core/common_runtime/direct_session.cc
@@ -54,7 +54,7 @@ limitations under the License.
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/device_tracer.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
diff --git a/tensorflow/core/common_runtime/function.cc b/tensorflow/core/common_runtime/function.cc
index d310520ebd..a6f637b488 100644
--- a/tensorflow/core/common_runtime/function.cc
+++ b/tensorflow/core/common_runtime/function.cc
@@ -209,6 +209,7 @@ class FunctionLibraryRuntimeImpl : public FunctionLibraryRuntime {
// The instantiated and transformed function is encoded as a Graph
// object, and an executor is created for the graph.
struct Item : public core::RefCounted {
+ bool invalidated = false;
const Graph* graph = nullptr; // Owned by exec.
const FunctionLibraryDefinition* overlay_lib = nullptr; // Not owned.
FunctionBody* func_graph = nullptr;
@@ -284,15 +285,7 @@ FunctionLibraryRuntimeImpl::FunctionLibraryRuntimeImpl(
}
FunctionLibraryRuntimeImpl::~FunctionLibraryRuntimeImpl() {
- // The most common patterns of FLR usage don't require the caller to
- // explicitly release handles. As a result, we try to unref each item until
- // it's erased.
- for (auto item : items_) {
- if (item.second) {
- while (!item.second->Unref()) {
- }
- }
- }
+ for (auto p : items_) p.second->Unref();
}
// An asynchronous op kernel which executes an instantiated function
@@ -497,24 +490,30 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
options_copy.target = device_name_;
const string key = Canonicalize(function_name, attrs, options_copy);
+ Handle found_handle = kInvalidHandle;
{
mutex_lock l(mu_);
- *handle = parent_->GetHandle(key);
- if (*handle != kInvalidHandle) {
+ found_handle = parent_->GetHandle(key);
+ if (found_handle != kInvalidHandle) {
FunctionLibraryRuntime::LocalHandle handle_on_device =
- parent_->GetHandleOnDevice(device_name_, *handle);
+ parent_->GetHandleOnDevice(device_name_, found_handle);
if (handle_on_device == kInvalidLocalHandle) {
return errors::Internal("LocalHandle not found for handle ", *handle,
".");
}
- auto item_handle = items_.find(handle_on_device);
- if (item_handle == items_.end()) {
+ auto iter = items_.find(handle_on_device);
+ if (iter == items_.end()) {
return errors::Internal("LocalHandle ", handle_on_device,
- " for handle ", *handle,
+ " for handle ", found_handle,
" not found in items.");
}
- item_handle->second->Ref();
- return Status::OK();
+ Item* item = iter->second;
+ if (!item->invalidated) {
+ *handle = found_handle;
+ return Status::OK();
+ }
+ // *item is invalidated. Fall through and instantiate the given
+ // function_name/attrs/option again.
}
}
@@ -546,10 +545,10 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
{
mutex_lock l(mu_);
- *handle = parent_->GetHandle(key);
- if (*handle != kInvalidHandle) {
+ Handle found_handle_again = parent_->GetHandle(key);
+ if (found_handle_again != found_handle) {
delete fbody;
- items_[parent_->GetHandleOnDevice(device_name_, *handle)]->Ref();
+ *handle = found_handle_again;
} else {
*handle = parent_->AddHandle(key, device_name_, next_handle_);
Item* item = new Item;
@@ -566,16 +565,12 @@ Status FunctionLibraryRuntimeImpl::ReleaseHandle(Handle handle) {
if (!parent_->IsInstantiatedOnDevice(device_name_, handle)) {
return parent_->ReleaseHandle(handle);
}
-
LocalHandle h = parent_->GetHandleOnDevice(device_name_, handle);
CHECK_NE(h, kInvalidLocalHandle);
mutex_lock l(mu_);
CHECK_EQ(1, items_.count(h));
Item* item = items_[h];
- if (item->Unref()) {
- items_.erase(h);
- TF_RETURN_IF_ERROR(parent_->RemoveHandle(handle));
- }
+ item->invalidated = true; // Reinstantiate later.
return Status::OK();
}
@@ -736,6 +731,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
// computation is done and stored in *rets, we send the return values back
// to the source_device (caller) so that the ProcFLR can receive them later.
std::vector<Tensor>* remote_args = new std::vector<Tensor>;
+ item->Ref();
ProcessFunctionLibraryRuntime::ReceiveTensorsAsync(
source_device, target_device, "arg_", src_incarnation, args.size(),
device_context, {}, rendezvous, remote_args,
@@ -747,6 +743,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
s = frame->SetArgs(*remote_args);
}
if (!s.ok()) {
+ item->Unref();
delete frame;
delete remote_args;
delete exec_args;
@@ -757,6 +754,7 @@ void FunctionLibraryRuntimeImpl::RunRemote(const Options& opts, Handle handle,
*exec_args, [item, frame, rets, done, source_device, target_device,
target_incarnation, rendezvous, device_context,
remote_args, exec_args](const Status& status) {
+ core::ScopedUnref unref(item);
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -842,11 +840,13 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
return;
}
+ item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
// Done callback.
[item, frame, rets, done, exec_args](const Status& status) {
+ core::ScopedUnref unref(item);
Status s = status;
if (s.ok()) {
s = frame->ConsumeRetvals(rets);
@@ -906,6 +906,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
exec_args->runner = *run_opts.runner;
exec_args->call_frame = frame;
+ item->Ref();
item->exec->RunAsync(
// Executor args
*exec_args,
@@ -914,6 +915,7 @@ void FunctionLibraryRuntimeImpl::Run(const Options& opts, Handle handle,
[item, frame, exec_args](DoneCallback done,
// Start unbound arguments.
const Status& status) {
+ core::ScopedUnref unref(item);
delete exec_args;
done(status);
},
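Two related lifetime changes are made in function.cc above: `ReleaseHandle` now only marks the cached `Item` as invalidated (so the next `Instantiate` for the same key rebuilds it), and every async run takes an extra reference that is dropped via `ScopedUnref` in the done callback. A sketch of the "invalidate, don't erase" idea, written in Python with illustrative names rather than the C++ of function.cc:

```
class _Item(object):
  def __init__(self, executor):
    self.executor = executor
    self.invalidated = False


class _Runtime(object):
  def __init__(self):
    self._items = {}  # canonical key -> _Item

  def instantiate(self, key, build_fn):
    item = self._items.get(key)
    if item is not None and not item.invalidated:
      return item                    # Reuse the cached instantiation.
    item = _Item(build_fn())         # Missing or invalidated: rebuild.
    self._items[key] = item
    return item

  def release(self, key):
    # Do not erase: in-flight executions may still be using the item.
    self._items[key].invalidated = True
```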
diff --git a/tensorflow/core/common_runtime/function_test.cc b/tensorflow/core/common_runtime/function_test.cc
index 61b2f0e60f..373fc64007 100644
--- a/tensorflow/core/common_runtime/function_test.cc
+++ b/tensorflow/core/common_runtime/function_test.cc
@@ -231,19 +231,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- status = Run(flr, handle, opts, args, rets, add_runner);
- if (!status.ok()) return status;
-
- // Release the handle and try running again. It should not succeed.
- status = flr->ReleaseHandle(handle);
- if (!status.ok()) return status;
-
- Status status2 = Run(flr, handle, opts, args, std::move(rets));
- EXPECT_TRUE(errors::IsInvalidArgument(status2));
- EXPECT_TRUE(
- str_util::StrContains(status2.error_message(), "remote execution."));
-
- return status;
+ TF_RETURN_IF_ERROR(Run(flr, handle, opts, args, rets, add_runner));
+ return flr->ReleaseHandle(handle);
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
@@ -304,16 +293,8 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
*rets[i] = retvals[i];
}
- // Release the handle and try running again. It should not succeed.
- status = flr->ReleaseHandle(handle);
- if (!status.ok()) return status;
-
- Status status2 = Run(flr, handle, opts, args, std::move(rets));
- EXPECT_TRUE(errors::IsInvalidArgument(status2));
- EXPECT_TRUE(
- str_util::StrContains(status2.error_message(), "remote execution."));
-
- return status;
+ // Release the handle.
+ return flr->ReleaseHandle(handle);
}
std::unique_ptr<Graph> GetFuncBody(FunctionLibraryRuntime* flr,
diff --git a/tensorflow/core/common_runtime/function_threadpool_test.cc b/tensorflow/core/common_runtime/function_threadpool_test.cc
index 2d09e83d01..98dac38a8c 100644
--- a/tensorflow/core/common_runtime/function_threadpool_test.cc
+++ b/tensorflow/core/common_runtime/function_threadpool_test.cc
@@ -144,19 +144,7 @@ class FunctionLibraryRuntimeTest : public ::testing::Test {
return status;
}
FunctionLibraryRuntime::Options opts;
- status = Run(flr, handle, opts, args, rets, add_runner);
- if (!status.ok()) return status;
-
- // Release the handle and try running again. It should not succeed.
- status = flr->ReleaseHandle(handle);
- if (!status.ok()) return status;
-
- Status status2 = Run(flr, handle, opts, args, std::move(rets));
- EXPECT_TRUE(errors::IsInvalidArgument(status2));
- EXPECT_TRUE(
- str_util::StrContains(status2.error_message(), "remote execution."));
-
- return status;
+ return Run(flr, handle, opts, args, std::move(rets), add_runner);
}
Status Run(FunctionLibraryRuntime* flr, FunctionLibraryRuntime::Handle handle,
diff --git a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
index c2c0b020c7..ad142e9982 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.h
@@ -29,8 +29,6 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/config.pb.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
// A GPU memory allocator that implements a 'best-fit with coalescing'
@@ -52,7 +50,7 @@ class GPUBFCAllocator : public BFCAllocator {
class GPUMemAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit GPUMemAllocator(perftools::gputools::StreamExecutor* stream_exec)
+ explicit GPUMemAllocator(se::StreamExecutor* stream_exec)
: stream_exec_(stream_exec) {
CHECK(stream_exec_ != nullptr);
}
@@ -68,13 +66,13 @@ class GPUMemAllocator : public SubAllocator {
void Free(void* ptr, size_t num_bytes) override {
if (ptr != nullptr) {
- gpu::DeviceMemoryBase gpu_ptr(ptr);
+ se::DeviceMemoryBase gpu_ptr(ptr);
stream_exec_->Deallocate(&gpu_ptr);
}
}
private:
- perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null
+ se::StreamExecutor* stream_exec_; // not owned, non-null
TF_DISALLOW_COPY_AND_ASSIGN(GPUMemAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
index 08961fc105..934a57a5fb 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.cc
@@ -38,7 +38,7 @@ GPUcudaMallocAllocator::~GPUcudaMallocAllocator() { delete base_allocator_; }
void* GPUcudaMallocAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
#ifdef GOOGLE_CUDA
// allocate with cudaMalloc
- gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
+ se::cuda::ScopedActivateExecutorContext scoped_activation{stream_exec_};
CUdeviceptr rv = 0;
CUresult res = cuMemAlloc(&rv, num_bytes);
if (res != CUDA_SUCCESS) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
index 208697361d..5043fac797 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_cudamalloc_allocator.h
@@ -44,7 +44,7 @@ class GPUcudaMallocAllocator : public VisitableAllocator {
private:
VisitableAllocator* base_allocator_ = nullptr; // owned
- perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+ se::StreamExecutor* stream_exec_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(GPUcudaMallocAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
index b0ca7e3109..e4c834b30d 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.cc
@@ -40,9 +40,8 @@ int64* NewMask(int64 word) {
int64* before_mask = NewMask(0xabababababababab);
int64* after_mask = NewMask(0xcdcdcdcdcdcdcdcd);
-bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr,
- int64* mask) {
- gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+bool CheckMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
+ se::DeviceMemory<int64> gpu_ptr{se::DeviceMemoryBase{ptr, MASK_BYTES}};
int64 tmp[MASK_WORDS];
if (!exec->SynchronousMemcpy(&tmp, gpu_ptr, MASK_BYTES)) {
@@ -62,9 +61,8 @@ bool CheckMask(perftools::gputools::StreamExecutor* exec, void* ptr,
return ok;
}
-void InitMask(perftools::gputools::StreamExecutor* exec, void* ptr,
- int64* mask) {
- gpu::DeviceMemory<int64> gpu_ptr{gpu::DeviceMemoryBase{ptr, MASK_BYTES}};
+void InitMask(se::StreamExecutor* exec, void* ptr, int64* mask) {
+ se::DeviceMemory<int64> gpu_ptr{se::DeviceMemoryBase{ptr, MASK_BYTES}};
if (!exec->SynchronousMemcpy(&gpu_ptr, mask, MASK_BYTES)) {
LOG(FATAL) << "Could not copy debug mask";
}
@@ -176,8 +174,8 @@ void* GPUNanResetAllocator::AllocateRaw(size_t alignment, size_t num_bytes) {
size_t req_size = base_allocator_->RequestedSize(allocated_ptr);
std::vector<float> nans((req_size + sizeof(float) - 1) / sizeof(float),
std::nanf(""));
- gpu::DeviceMemory<float> nan_ptr{
- gpu::DeviceMemoryBase{static_cast<float*>(allocated_ptr), req_size}};
+ se::DeviceMemory<float> nan_ptr{
+ se::DeviceMemoryBase{static_cast<float*>(allocated_ptr), req_size}};
if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
LOG(ERROR) << "Could not initialize to NaNs";
@@ -191,8 +189,8 @@ void GPUNanResetAllocator::DeallocateRaw(void* ptr) {
size_t req_size = base_allocator_->RequestedSize(ptr);
std::vector<float> nans((req_size + sizeof(float) - 1) / sizeof(float),
std::nanf(""));
- gpu::DeviceMemory<float> nan_ptr{
- gpu::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
+ se::DeviceMemory<float> nan_ptr{
+ se::DeviceMemoryBase{static_cast<float*>(ptr), req_size}};
if (!stream_exec_->SynchronousMemcpy(&nan_ptr, &nans[0], req_size)) {
LOG(ERROR) << "Could not initialize to NaNs";
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
index adce3a8436..c49ec2a566 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator.h
@@ -55,7 +55,7 @@ class GPUDebugAllocator : public VisitableAllocator {
private:
VisitableAllocator* base_allocator_ = nullptr; // owned
- perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+ se::StreamExecutor* stream_exec_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(GPUDebugAllocator);
};
@@ -81,7 +81,7 @@ class GPUNanResetAllocator : public VisitableAllocator {
private:
VisitableAllocator* base_allocator_ = nullptr; // owned
- perftools::gputools::StreamExecutor* stream_exec_; // Not owned.
+ se::StreamExecutor* stream_exec_; // Not owned.
TF_DISALLOW_COPY_AND_ASSIGN(GPUNanResetAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
index d34f0cb3c2..236a0afa0b 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_debug_allocator_test.cc
@@ -43,7 +43,7 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_None) {
std::vector<int64> cpu_array(s);
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
int64* gpu_array = a.Allocate<int64>(cpu_array.size());
- gpu::DeviceMemory<int64> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ se::DeviceMemory<int64> gpu_array_ptr{se::DeviceMemoryBase{gpu_array}};
ASSERT_TRUE(stream_exec->SynchronousMemcpy(&gpu_array_ptr, &cpu_array[0],
s * sizeof(int64)));
EXPECT_TRUE(a.CheckHeader(gpu_array));
@@ -68,13 +68,13 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Header) {
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
int64* gpu_array = a.Allocate<int64>(cpu_array.size());
- gpu::DeviceMemory<int64> gpu_array_ptr{
- gpu::DeviceMemoryBase{gpu_array}};
+ se::DeviceMemory<int64> gpu_array_ptr{
+ se::DeviceMemoryBase{gpu_array}};
ASSERT_TRUE(stream_exec->SynchronousMemcpy(
&gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
- gpu::DeviceMemory<int64> gpu_hdr_ptr{
- gpu::DeviceMemoryBase{gpu_array - 1}};
+ se::DeviceMemory<int64> gpu_hdr_ptr{
+ se::DeviceMemoryBase{gpu_array - 1}};
// Clobber first word of the header.
float pi = 3.1417;
ASSERT_TRUE(
@@ -101,14 +101,14 @@ TEST(GPUDebugAllocatorTest, OverwriteDetection_Footer) {
memset(&cpu_array[0], 0, cpu_array.size() * sizeof(int64));
int64* gpu_array = a.Allocate<int64>(cpu_array.size());
- gpu::DeviceMemory<int64> gpu_array_ptr{
- gpu::DeviceMemoryBase{gpu_array}};
+ se::DeviceMemory<int64> gpu_array_ptr{
+ se::DeviceMemoryBase{gpu_array}};
ASSERT_TRUE(stream_exec->SynchronousMemcpy(
&gpu_array_ptr, &cpu_array[0], cpu_array.size() * sizeof(int64)));
// Clobber word of the footer.
- gpu::DeviceMemory<int64> gpu_ftr_ptr{
- gpu::DeviceMemoryBase{gpu_array + s}};
+ se::DeviceMemory<int64> gpu_ftr_ptr{
+ se::DeviceMemoryBase{gpu_array + s}};
float pi = 3.1417;
ASSERT_TRUE(
stream_exec->SynchronousMemcpy(&gpu_ftr_ptr, &pi, sizeof(float)));
@@ -131,7 +131,7 @@ TEST(GPUDebugAllocatorTest, ResetToNan) {
// Allocate 1024 floats
float* gpu_array = a.Allocate<float>(cpu_array.size());
- gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ se::DeviceMemory<float> gpu_array_ptr{se::DeviceMemoryBase{gpu_array}};
ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
cpu_array.size() * sizeof(float)));
for (float f : cpu_array) {
@@ -174,7 +174,7 @@ TEST(GPUDebugAllocatorTest, ResetToNanWithHeaderFooter) {
// Allocate 1024 floats
float* gpu_array = a.Allocate<float>(cpu_array.size());
- gpu::DeviceMemory<float> gpu_array_ptr{gpu::DeviceMemoryBase{gpu_array}};
+ se::DeviceMemory<float> gpu_array_ptr{se::DeviceMemoryBase{gpu_array}};
ASSERT_TRUE(stream_exec->SynchronousMemcpy(&cpu_array[0], gpu_array_ptr,
cpu_array.size() * sizeof(float)));
for (float f : cpu_array) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.cc b/tensorflow/core/common_runtime/gpu/gpu_device.cc
index 0b9e8f9cc2..1fa33991f7 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.cc
@@ -200,27 +200,27 @@ class BaseGPUDevice::StreamGroupFactory {
// This function is thread safe.
BaseGPUDevice::StreamGroup* GetOrCreate(TfGpuId tf_gpu_id,
int stream_group_within_gpu,
- gpu::StreamExecutor* executor) {
+ se::StreamExecutor* executor) {
mutex_lock guard(lock_);
StreamGroup* group =
&streams_[key_type(tf_gpu_id.value(), stream_group_within_gpu)];
if (!group->compute) {
- group->compute = new gpu::Stream(executor);
+ group->compute = new se::Stream(executor);
group->compute->Init();
VLOG(2) << "Created stream[" << stream_group_within_gpu
<< "] = " << group->compute;
- group->host_to_device = new gpu::Stream(executor);
+ group->host_to_device = new se::Stream(executor);
group->host_to_device->Init();
VLOG(2) << "Created host_to_device_stream[" << stream_group_within_gpu
<< "] = " << group->host_to_device;
- group->device_to_host = new gpu::Stream(executor);
+ group->device_to_host = new se::Stream(executor);
group->device_to_host->Init();
VLOG(2) << "Created device_to_host_stream[" << stream_group_within_gpu
<< "] = " << group->device_to_host;
- group->device_to_device = new gpu::Stream(executor);
+ group->device_to_device = new se::Stream(executor);
group->device_to_device->Init();
VLOG(2) << "Created device_to_device_stream[" << stream_group_within_gpu
<< "] = " << group->device_to_host;
@@ -297,9 +297,8 @@ Status BaseGPUDevice::Init(const SessionOptions& options) {
}
scratch_.push_back(static_cast<char*>(scratch_buffer));
- perftools::gputools::DeviceMemory<char> mem(
- perftools::gputools::DeviceMemoryBase(scratch_buffer,
- scratch_buffer_size));
+ se::DeviceMemory<char> mem(
+ se::DeviceMemoryBase(scratch_buffer, scratch_buffer_size));
bool ok = executor_->SynchronousMemZero(
&mem, Eigen::kCudaScratchSize + sizeof(unsigned int));
@@ -441,7 +440,7 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
gpu_device_context =
static_cast<GPUDeviceContext*>(context->op_device_context());
}
- gpu::Stream* stream = gpu_device_context->stream();
+ se::Stream* stream = gpu_device_context->stream();
const auto stream_id = gpu_device_context->stream_id();
const bool vlog_1 = VLOG_IS_ON(1);
@@ -485,7 +484,7 @@ void BaseGPUDevice::ComputeHelper(OpKernel* op_kernel,
if (idc->stream() != stream) stream->ThenWaitFor(idc->stream());
}
}
- gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
+ se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
op_kernel->Compute(context);
if (context->status().ok()) {
if (sync_every_op_) {
@@ -504,7 +503,7 @@ void BaseGPUDevice::ConsumeListOfAccessedTensors(
if (device_context != nullptr) {
gpu_device_context = static_cast<GPUDeviceContext*>(device_context);
}
- gpu::Stream* stream = gpu_device_context->stream();
+ se::Stream* stream = gpu_device_context->stream();
em_->ThenDeleteTensors(stream, tensor_refs);
}
@@ -520,7 +519,7 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
gpu_device_context =
static_cast<GPUDeviceContext*>(context->op_device_context());
}
- gpu::Stream* stream = gpu_device_context->stream();
+ se::Stream* stream = gpu_device_context->stream();
const auto stream_id = gpu_device_context->stream_id();
VLOG(1) << "GpuDevice::ComputeAsync " << op_kernel->name() << " op "
@@ -532,7 +531,7 @@ void BaseGPUDevice::ComputeAsync(AsyncOpKernel* op_kernel,
// false value. Measurements show that its overhead is negligible.
port::Tracing::TraceMe activity(op_kernel->name(), op_kernel->type_string(),
op_kernel->IsExpensive());
- gpu::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
+ se::cuda::ScopedActivateExecutorContext scoped_activation{stream->parent()};
op_kernel->ComputeAsync(context, done);
}
@@ -666,7 +665,7 @@ class ConcretePerOpGpuDevice : public PerOpGpuDevice {
Status ParseVisibleDeviceList(const string& visible_device_list,
std::vector<CudaGpuId>* visible_gpu_order) {
visible_gpu_order->clear();
- gpu::Platform* gpu_manager = GPUMachineManager();
+ se::Platform* gpu_manager = GPUMachineManager();
// If the user wants to remap the visible to virtual GPU mapping,
// check for that here.
@@ -785,7 +784,7 @@ Status SingleVirtualDeviceMemoryLimit(const GPUOptions& gpu_options,
int64* memory_limit) {
int64 total_memory = 0;
int64 available_memory = 0;
- gpu::StreamExecutor* se =
+ se::StreamExecutor* se =
GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
if (!se->DeviceMemoryUsage(&available_memory, &total_memory)) {
return errors::Unknown("Failed to query available memory for GPU ",
@@ -859,7 +858,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
TF_RETURN_IF_ERROR(ValidateGPUMachineManager());
- gpu::Platform* gpu_manager = GPUMachineManager();
+ se::Platform* gpu_manager = GPUMachineManager();
if (gpu_manager == nullptr) {
return Status::OK();
}
@@ -998,7 +997,7 @@ Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
}
static string GetShortDeviceDescription(CudaGpuId cuda_gpu_id,
- const gpu::DeviceDescription& desc) {
+ const se::DeviceDescription& desc) {
int cc_major;
int cc_minor;
if (!desc.cuda_compute_capability(&cc_major, &cc_minor)) {
@@ -1026,9 +1025,9 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
CudaGpuId cuda_gpu_id = GpuIdManager::TfToCudaGpuId(tf_gpu_id);
int numa_node = dev_locality.numa_node();
- gpu::StreamExecutor* se =
+ se::StreamExecutor* se =
GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
- const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ const se::DeviceDescription& desc = se->GetDeviceDescription();
ProcessState* process_state = ProcessState::singleton();
Allocator* gpu_allocator = process_state->GetGPUAllocator(
options.config.gpu_options(), tf_gpu_id, memory_limit);
@@ -1061,15 +1060,15 @@ Status BaseGPUDeviceFactory::CreateGPUDevice(const SessionOptions& options,
namespace {
std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>>
-GetPeerAccessMap(gpu::Platform* platform,
+GetPeerAccessMap(se::Platform* platform,
const std::vector<CudaGpuId>& visible_gpu_order) {
std::unique_ptr<std::map<std::pair<CudaGpuId, CudaGpuId>, bool>> map(
new std::map<std::pair<CudaGpuId, CudaGpuId>, bool>);
for (CudaGpuId cuda_gpu_i : visible_gpu_order) {
for (CudaGpuId cuda_gpu_j : visible_gpu_order) {
- gpu::StreamExecutor* from =
+ se::StreamExecutor* from =
GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
- gpu::StreamExecutor* to =
+ se::StreamExecutor* to =
GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
(*map)[{cuda_gpu_i, cuda_gpu_j}] = from->CanEnablePeerAccessTo(to);
}
@@ -1081,7 +1080,7 @@ GetPeerAccessMap(gpu::Platform* platform,
} // namespace
Status BaseGPUDeviceFactory::GetInterconnectMaps(
- const std::vector<CudaGpuId>& visible_gpu_order, gpu::Platform* gpu_manager,
+ const std::vector<CudaGpuId>& visible_gpu_order, se::Platform* gpu_manager,
std::vector<InterconnectMap>* maps) {
// The default interconnect map is obtained from the StreamExecutor.
auto access_map = GetPeerAccessMap(gpu_manager, visible_gpu_order);
@@ -1112,9 +1111,9 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
// Get GPU bus_id from its reported NUMA affinity. Because GPUs are
// virtualized in some environments, we can't just use the GPU id.
// NUMA locales are indexed from 0, buses are indexed from 1.
- gpu::StreamExecutor* se =
+ se::StreamExecutor* se =
GpuIdUtil::ExecutorForCudaGpuId(cuda_gpu_id).ValueOrDie();
- const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ const se::DeviceDescription& desc = se->GetDeviceDescription();
int numa_node = desc.numa_node();
if (numa_node < 0) {
// For some reason the StreamExecutor couldn't get the NUMA
@@ -1170,7 +1169,7 @@ Status BaseGPUDeviceFactory::GetDeviceLocalities(
}
static int GetDefaultMinGPUMultiprocessorCount(
- gpu::Platform* gpu_manager,
+ se::Platform* gpu_manager,
const std::vector<CudaGpuId>& visible_gpu_order) {
static const int kDefaultMinGPUMultiprocessorCount = 8;
@@ -1183,8 +1182,8 @@ static int GetDefaultMinGPUMultiprocessorCount(
continue;
}
- gpu::StreamExecutor* se = exec_status.ValueOrDie();
- const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ se::StreamExecutor* se = exec_status.ValueOrDie();
+ const se::DeviceDescription& desc = se->GetDeviceDescription();
max_count = std::max(max_count, desc.core_count());
}
@@ -1196,7 +1195,7 @@ static int GetDefaultMinGPUMultiprocessorCount(
}
static int GetMinGPUMultiprocessorCount(
- gpu::Platform* gpu_manager,
+ se::Platform* gpu_manager,
const std::vector<CudaGpuId>& visible_gpu_order) {
const char* tf_min_gpu_core_count = getenv("TF_MIN_GPU_MULTIPROCESSOR_COUNT");
@@ -1274,7 +1273,7 @@ std::vector<CudaVersion> GetSupportedCudaComputeCapabilities() {
return cuda_caps;
}
-Status EnablePeerAccess(gpu::Platform* platform,
+Status EnablePeerAccess(se::Platform* platform,
const std::vector<CudaGpuId>& visible_gpu_order) {
int possible_peer_count = 0;
int enabled_peer_count = 0;
@@ -1283,9 +1282,9 @@ Status EnablePeerAccess(gpu::Platform* platform,
for (int j = 0; j < visible_gpu_order.size(); ++j) {
const CudaGpuId cuda_gpu_j = visible_gpu_order[j];
// We have already validated that ExecutorForDevice() calls return OK.
- gpu::StreamExecutor* from =
+ se::StreamExecutor* from =
GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_i).ValueOrDie();
- gpu::StreamExecutor* to =
+ se::StreamExecutor* to =
GpuIdUtil::ExecutorForCudaGpuId(platform, cuda_gpu_j).ValueOrDie();
if (from->CanEnablePeerAccessTo(to)) {
@@ -1319,7 +1318,7 @@ Status EnablePeerAccess(gpu::Platform* platform,
Status BaseGPUDeviceFactory::GetValidDeviceIds(
const std::vector<CudaGpuId>& visible_gpu_order,
std::vector<CudaGpuId>* ids) {
- gpu::Platform* gpu_manager = GPUMachineManager();
+ se::Platform* gpu_manager = GPUMachineManager();
bool new_gpu_found = false;
for (int i = 0; i < visible_gpu_order.size(); ++i) {
const CudaGpuId cuda_gpu_id = visible_gpu_order[i];
@@ -1334,7 +1333,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
auto executor = GpuIdUtil::ExecutorForCudaGpuId(gpu_manager, cuda_gpu_id);
if (!executor.ok()) {
- return StreamExecutorUtil::ConvertStatus(executor.status());
+ return executor.status();
}
auto stream_exec = executor.ValueOrDie();
@@ -1389,8 +1388,8 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
<< exec_status.status().ToString();
continue;
}
- gpu::StreamExecutor* se = exec_status.ValueOrDie();
- const gpu::DeviceDescription& desc = se->GetDeviceDescription();
+ se::StreamExecutor* se = exec_status.ValueOrDie();
+ const se::DeviceDescription& desc = se->GetDeviceDescription();
CudaVersion device_capability;
if (!desc.cuda_compute_capability(&device_capability.major_part,
&device_capability.minor_part)) {
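The gpu_device.cc hunks above are a mechanical rename: the fully qualified perftools::gputools spelling and the per-file gpu:: aliases are replaced by the single se:: alias. A minimal sketch of the resulting pattern, assuming the tensorflow::se alias is supplied by the platform headers as the removed TODO(b/77980417) comments indicate; the function name is illustrative only:

    #include <memory>

    #include "tensorflow/core/platform/stream_executor.h"

    namespace tensorflow {

    // Before: perftools::gputools::Stream, or a local "namespace gpu = ::stream_executor;".
    // After:  everything goes through the project-wide se:: alias.
    void InitScratchStream(se::StreamExecutor* executor) {
      std::unique_ptr<se::Stream> stream(new se::Stream(executor));
      stream->Init();  // same construct-then-Init pattern as the StreamGroup code above
    }

    }  // namespace tensorflow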
diff --git a/tensorflow/core/common_runtime/gpu/gpu_device.h b/tensorflow/core/common_runtime/gpu/gpu_device.h
index cc5c3881dd..b754ffd2db 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_device.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_device.h
@@ -94,7 +94,7 @@ class BaseGPUDevice : public LocalDevice {
// The executor that provides control for the device; e.g., for CUDA this
// corresponds to the cuda context.
- gpu::StreamExecutor* executor() const { return executor_; }
+ se::StreamExecutor* executor() const { return executor_; }
Allocator* GetScopedAllocator(AllocatorAttributes attr,
int64 step_id) override;
@@ -107,15 +107,15 @@ class BaseGPUDevice : public LocalDevice {
Allocator* gpu_allocator_; // not owned
Allocator* cpu_allocator_; // not owned
- gpu::StreamExecutor* executor_; // not owned
+ se::StreamExecutor* executor_; // not owned
std::unique_ptr<ScopedAllocatorMgr> scoped_allocator_mgr_;
private:
struct StreamGroup {
- gpu::Stream* compute = nullptr;
- gpu::Stream* host_to_device = nullptr;
- gpu::Stream* device_to_host = nullptr;
- gpu::Stream* device_to_device = nullptr;
+ se::Stream* compute = nullptr;
+ se::Stream* host_to_device = nullptr;
+ se::Stream* device_to_host = nullptr;
+ se::Stream* device_to_device = nullptr;
};
class StreamGroupFactory;
@@ -168,7 +168,7 @@ class BaseGPUDeviceFactory : public DeviceFactory {
// pathways between GPUs.
virtual Status GetInterconnectMaps(
const std::vector<CudaGpuId>& visible_gpu_order,
- gpu::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
+ se::Platform* gpu_manager, std::vector<InterconnectMap>* maps);
struct TfGpuIdHash {
std::size_t operator()(const TfGpuId& id) const noexcept {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
index af6a59a85d..4898448476 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.cc
@@ -18,11 +18,9 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/protobuf/config.pb.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
-EventMgr::EventMgr(gpu::StreamExecutor* se, const GPUOptions& gpu_options)
+EventMgr::EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options)
: exec_(se),
deferred_bytes_threshold_(gpu_options.deferred_deletion_bytes()
? gpu_options.deferred_deletion_bytes()
@@ -94,7 +92,7 @@ void EventMgr::StopPollingLoop() {
}
}
-void EventMgr::ThenDeleteTensors(perftools::gputools::Stream* stream,
+void EventMgr::ThenDeleteTensors(se::Stream* stream,
const TensorReferenceVector& tensors) {
mutex_lock l(mu_);
// TODO(jeff): We currently keep one accumulated_tensors_ object.
@@ -152,16 +150,16 @@ void EventMgr::PollLoop() {
polling_stopped_->Notify();
}
-void EventMgr::QueueInUse(gpu::Stream* stream, InUse iu) {
+void EventMgr::QueueInUse(se::Stream* stream, InUse iu) {
VLOG(2) << "QueueInUse free_events_ " << free_events_.size()
<< " used_events_ " << used_events_.size();
// Events are created on demand, and repeatedly reused. There is no
// limit placed here on the number of allocated Events.
if (free_events_.empty()) {
- free_events_.push_back(new gpu::Event(exec_));
+ free_events_.push_back(new se::Event(exec_));
free_events_.back()->Init();
}
- gpu::Event* e = free_events_.back();
+ se::Event* e = free_events_.back();
free_events_.pop_back();
stream->ThenRecordEvent(e);
iu.event = e;
@@ -199,18 +197,18 @@ void EventMgr::PollEvents(bool is_dedicated_poller,
// the first non-complete record that is still pending.
for (auto& iu : used_events_) {
if (iu.event == nullptr) continue;
- gpu::Event::Status s = iu.event->PollForStatus();
+ se::Event::Status s = iu.event->PollForStatus();
switch (s) {
- case gpu::Event::Status::kUnknown:
- case gpu::Event::Status::kError:
+ case se::Event::Status::kUnknown:
+ case se::Event::Status::kError:
// We don't expect to see these. Someday maybe propagate
// a Status error, but for now fail hard.
LOG(FATAL) << "Unexpected Event status: " << static_cast<int>(s);
break;
- case gpu::Event::Status::kPending:
+ case se::Event::Status::kPending:
if (!is_dedicated_poller) return; // quit processing queue
break;
- case gpu::Event::Status::kComplete:
+ case se::Event::Status::kComplete:
// Make a copy of the InUse record so we can free it after releasing
// the lock
to_free->push_back(iu);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
index fd5f50ca4e..b26f88a201 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr.h
@@ -44,14 +44,13 @@ class GPUOptions;
// Events are recorded.
class EventMgr {
public:
- EventMgr(perftools::gputools::StreamExecutor* se,
- const GPUOptions& gpu_options);
+ EventMgr(se::StreamExecutor* se, const GPUOptions& gpu_options);
~EventMgr();
// Releases the references on the elements of "tensors" as soon as
// all events currently enqueued on "stream" have completed.
- void ThenDeleteTensors(perftools::gputools::Stream* stream,
+ void ThenDeleteTensors(se::Stream* stream,
const TensorReferenceVector& tensors);
struct BufRec {
@@ -65,8 +64,7 @@ class EventMgr {
// Takes ownership of *bufrec.buf and calls bufrec.alloc->DeallocateRaw()
// on it as soon as all events currently enqueued on *stream have completed.
- inline void ThenDeleteBuffer(perftools::gputools::Stream* stream,
- BufRec bufrec) {
+ inline void ThenDeleteBuffer(se::Stream* stream, BufRec bufrec) {
ToFreeVector to_free;
{
mutex_lock l(mu_);
@@ -76,8 +74,7 @@ class EventMgr {
FreeMemory(to_free);
}
- inline void ThenExecute(perftools::gputools::Stream* stream,
- std::function<void()> func) {
+ inline void ThenExecute(se::Stream* stream, std::function<void()> func) {
ToFreeVector to_free;
{
mutex_lock l(mu_);
@@ -89,7 +86,7 @@ class EventMgr {
private:
friend class TEST_EventMgrHelper;
- perftools::gputools::StreamExecutor* const exec_;
+ se::StreamExecutor* const exec_;
const int64 deferred_bytes_threshold_;
const int32 polling_active_delay_usecs_;
mutex mu_;
@@ -98,7 +95,7 @@ class EventMgr {
void FlushAccumulatedTensors() EXCLUSIVE_LOCKS_REQUIRED(mu_);
struct InUse {
- perftools::gputools::Event* event;
+ se::Event* event;
TensorReferenceVector* mem;
BufRec bufrec;
std::function<void()> func;
@@ -130,22 +127,21 @@ class EventMgr {
// Stream-enqueue an unused Event and save with it a collection of
// Tensors and/or a BufRec to be deleted only after the Event
// records.
- void QueueInUse(perftools::gputools::Stream* stream, InUse in_use)
+ void QueueInUse(se::Stream* stream, InUse in_use)
EXCLUSIVE_LOCKS_REQUIRED(mu_);
- void QueueTensors(perftools::gputools::Stream* stream,
- TensorReferenceVector* tensors)
+ void QueueTensors(se::Stream* stream, TensorReferenceVector* tensors)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, tensors, BufRec(), nullptr});
}
- void QueueBuffer(perftools::gputools::Stream* stream, BufRec bufrec)
+ void QueueBuffer(se::Stream* stream, BufRec bufrec)
EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, bufrec, nullptr});
}
- void QueueFunc(perftools::gputools::Stream* stream,
- std::function<void()> func) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ void QueueFunc(se::Stream* stream, std::function<void()> func)
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
QueueInUse(stream, {nullptr, nullptr, BufRec(), std::move(func)});
}
@@ -166,10 +162,10 @@ class EventMgr {
void StopPollingLoop();
// A stack of unused events
- std::vector<perftools::gputools::Event*> free_events_ GUARDED_BY(mu_);
+ std::vector<se::Event*> free_events_ GUARDED_BY(mu_);
// Buffered list of tensors waiting to have an event queued for deletion
- perftools::gputools::Stream* accumulated_stream_ GUARDED_BY(mu_);
+ se::Stream* accumulated_stream_ GUARDED_BY(mu_);
TensorReferenceVector* accumulated_tensors_ GUARDED_BY(mu_);
// Sum of the TotalBytes() of the tensors in "accumulated_tensors_"
int64 accumulated_tensor_bytes_ GUARDED_BY(mu_);
diff --git a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
index 3ad0b0eb85..1d4ad957b9 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_event_mgr_test.cc
@@ -23,8 +23,6 @@ limitations under the License.
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/protobuf/config.pb.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
class TEST_EventMgrHelper {
@@ -47,8 +45,7 @@ class TEST_EventMgrHelper {
return em_->free_events_.size();
}
- void QueueTensors(perftools::gputools::Stream* stream,
- TensorReferenceVector* tensors) {
+ void QueueTensors(se::Stream* stream, TensorReferenceVector* tensors) {
mutex_lock l(em_->mu_);
em_->QueueTensors(stream, tensors);
}
@@ -121,7 +118,7 @@ TEST(EventMgr, DelayedPolling) {
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
TensorReferenceVector* v = nullptr;
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
@@ -153,7 +150,7 @@ TEST(EventMgr, FlushLargeTensorImmediately) {
EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
@@ -170,7 +167,7 @@ TEST(EventMgr, ManySmallTensorsFlushedImmediately) {
EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
@@ -189,8 +186,8 @@ TEST(EventMgr, StreamSwitchingFlushesImmediately) {
EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
- std::unique_ptr<gpu::Stream> stream1(new gpu::Stream(stream_exec));
- std::unique_ptr<gpu::Stream> stream2(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream1(new se::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream2(new se::Stream(stream_exec));
stream1->Init();
stream2->Init();
TensorReferenceVector v1;
@@ -211,7 +208,7 @@ TEST(EventMgr, ManySmallTensorsSeparateCallsFlushed) {
EventMgr em(stream_exec, GPUOptions());
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, live_tensor_bytes);
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
@@ -234,7 +231,7 @@ TEST(EventMgr, NonEmptyShutdown) {
TEST_EventMgrHelper th(&em);
EXPECT_EQ(0, th.queue_size());
EXPECT_EQ(0, th.free_size());
- std::unique_ptr<gpu::Stream> stream(new gpu::Stream(stream_exec));
+ std::unique_ptr<se::Stream> stream(new se::Stream(stream_exec));
CHECK(stream.get());
stream->Init();
for (int i = 0; i < 5; ++i) {
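For reference, a short usage sketch of the EventMgr interface these tests exercise, limited to the signatures visible in the hunks above; the helper function name and the surrounding setup are assumed:

    // Defers releasing `refs` until all work already enqueued on `stream` has
    // completed; EventMgr records an se::Event on the stream and polls it.
    void ReleaseAfterStreamDrains(EventMgr* em, se::Stream* stream,
                                  const TensorReferenceVector& refs) {
      em->ThenDeleteTensors(stream, refs);
      // Arbitrary cleanup can be deferred the same way.
      em->ThenExecute(stream, []() { VLOG(1) << "enqueued work has completed"; });
    }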
diff --git a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
index 5c503d1261..42bf074e63 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_id_utils.h
@@ -24,24 +24,20 @@ limitations under the License.
namespace tensorflow {
-// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
-// that's available.
-namespace gpu = ::stream_executor;
-
// Utility methods for translation between Tensorflow GPU ids and CUDA GPU ids.
class GpuIdUtil {
public:
// Convenient methods for getting the associated executor given a TfGpuId or
// CudaGpuId.
- static gpu::port::StatusOr<gpu::StreamExecutor*> ExecutorForCudaGpuId(
- gpu::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
+ se::Platform* gpu_manager, CudaGpuId cuda_gpu_id) {
return gpu_manager->ExecutorForDevice(cuda_gpu_id.value());
}
- static gpu::port::StatusOr<gpu::StreamExecutor*> ExecutorForCudaGpuId(
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForCudaGpuId(
CudaGpuId cuda_gpu_id) {
return ExecutorForCudaGpuId(GPUMachineManager(), cuda_gpu_id);
}
- static gpu::port::StatusOr<gpu::StreamExecutor*> ExecutorForTfGpuId(
+ static se::port::StatusOr<se::StreamExecutor*> ExecutorForTfGpuId(
TfGpuId tf_gpu_id) {
return ExecutorForCudaGpuId(GpuIdManager::TfToCudaGpuId(tf_gpu_id));
}
diff --git a/tensorflow/core/common_runtime/gpu/gpu_init.cc b/tensorflow/core/common_runtime/gpu/gpu_init.cc
index aa23e3cc61..e0ec93a98e 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_init.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_init.cc
@@ -26,21 +26,14 @@ limitations under the License.
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/stream_executor_util.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
Status ValidateGPUMachineManager() {
- auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
- if (!result.ok()) {
- return StreamExecutorUtil::ConvertStatus(result.status());
- }
-
- return Status::OK();
+ return se::MultiPlatformManager::PlatformWithName("CUDA").status();
}
-gpu::Platform* GPUMachineManager() {
- auto result = gpu::MultiPlatformManager::PlatformWithName("CUDA");
+se::Platform* GPUMachineManager() {
+ auto result = se::MultiPlatformManager::PlatformWithName("CUDA");
if (!result.ok()) {
LOG(FATAL) << "Could not find Platform with name CUDA";
return nullptr;
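The ValidateGPUMachineManager() rewrite relies on a StatusOr property: status() is OK exactly when a value is present, so a caller that only needs success or failure can return it directly instead of the old ok()/ConvertStatus sequence (ConvertStatus is dropped elsewhere in this change as well, which implies the two Status types now unify). A hedged restatement of the idiom with a hypothetical helper name:

    // Assumed helper, mirroring the rewrite above: succeeds iff the CUDA
    // platform can be found, with no need to touch the wrapped value.
    Status CheckCudaPlatformPresent() {
      return se::MultiPlatformManager::PlatformWithName("CUDA").status();
    }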
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.cc b/tensorflow/core/common_runtime/gpu/gpu_util.cc
index 5214ceaae5..7ba853fa51 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.cc
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.cc
@@ -55,19 +55,15 @@ limitations under the License.
const tensorflow::int64 FLAGS_brain_gpu_util_debug_string_maxlen = 128;
extern bool FLAGS_brain_gpu_record_mem_types;
-using perftools::gputools::DeviceMemoryBase;
-using perftools::gputools::Stream;
-
namespace tensorflow {
-// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
-// that's available.
-namespace gpu = ::stream_executor;
+using se::DeviceMemoryBase;
+using se::Stream;
Status PrepareCopy(Device* device, const DeviceContext* ctx, const Tensor& src,
const Tensor* dst,
const DeviceBase::GpuDeviceInfo** dev_info,
- gpu::Stream** stream) {
+ se::Stream** stream) {
if (device == nullptr) {
return errors::Internal("Unexpected null device.");
}
@@ -122,7 +118,7 @@ void GPUUtil::SetProtoFromGPU(const Tensor& tensor, Device* dev,
StatusCallback done) {
VLOG(1) << "SetProtoFromGPU device_context " << device_context;
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
- gpu::Stream* send_stream = nullptr;
+ se::Stream* send_stream = nullptr;
Status s = PrepareCopy(dev, device_context, tensor, nullptr, &dev_info,
&send_stream);
if (!s.ok()) {
@@ -197,7 +193,7 @@ void GPUUtil::DeviceToDeviceCopy(DeviceContext* send_dev_context,
const Tensor* input, Tensor* output,
StatusCallback done) {
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
- gpu::Stream* send_stream = nullptr;
+ se::Stream* send_stream = nullptr;
Status s = PrepareCopy(src, send_dev_context, *input, output, &dev_info,
&send_stream);
if (!s.ok()) {
@@ -264,7 +260,7 @@ void GPUUtil::CopyGPUTensorToCPU(Device* gpu_device,
StatusCallback done) {
VLOG(1) << "CopyGPUTensorToCPU";
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
- gpu::Stream* send_stream = nullptr;
+ se::Stream* send_stream = nullptr;
Status s = PrepareCopy(gpu_device, device_context, *gpu_tensor, cpu_tensor,
&dev_info, &send_stream);
if (!s.ok()) {
@@ -309,7 +305,7 @@ void GPUUtil::CopyCPUTensorToGPU(const Tensor* cpu_tensor,
StatusCallback done) {
VLOG(1) << "CopyCPUTensorToGPU";
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
- gpu::Stream* recv_stream = nullptr;
+ se::Stream* recv_stream = nullptr;
Status s = PrepareCopy(gpu_device, device_context, *cpu_tensor, gpu_tensor,
&dev_info, &recv_stream);
if (!s.ok()) {
@@ -432,7 +428,7 @@ void GPUUtil::CopyGPUTensorToSameGPU(Device* gpu_device,
StatusCallback done) {
VLOG(1) << "CopyGPUTensorToSameGPU";
const DeviceBase::GpuDeviceInfo* dev_info = nullptr;
- gpu::Stream* send_stream = nullptr;
+ se::Stream* send_stream = nullptr;
Status s = PrepareCopy(gpu_device, device_context, *src_gpu_tensor,
dst_gpu_tensor, &dev_info, &send_stream);
if (!s.ok()) {
diff --git a/tensorflow/core/common_runtime/gpu/gpu_util.h b/tensorflow/core/common_runtime/gpu/gpu_util.h
index 337dc89895..237b0044da 100644
--- a/tensorflow/core/common_runtime/gpu/gpu_util.h
+++ b/tensorflow/core/common_runtime/gpu/gpu_util.h
@@ -27,10 +27,6 @@ namespace tensorflow {
class RecvTensorResponse;
class TensorProto;
-// TODO(b/77980417): Remove this and use the regular tensorflow::se alias once
-// that's available.
-namespace gpu = ::stream_executor;
-
class GPUUtil {
public:
// "tensor" is GPU-local. "dev" is the hosting GPU.
@@ -74,10 +70,9 @@ class GPUUtil {
// NOTE: will be removed soon, see StreamExecutorUtil::AsDeviceMemory
// instead.
template <typename T>
- static perftools::gputools::DeviceMemory<T> AsDeviceMemory(const Tensor& t) {
+ static se::DeviceMemory<T> AsDeviceMemory(const Tensor& t) {
T* ptr = reinterpret_cast<T*>(const_cast<void*>(DMAHelper::base(&t)));
- return perftools::gputools::DeviceMemory<T>(
- perftools::gputools::DeviceMemoryBase(ptr, t.TotalBytes()));
+ return se::DeviceMemory<T>(se::DeviceMemoryBase(ptr, t.TotalBytes()));
}
// Computes a checksum over the contents of "tensor", which is allocated
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator.h b/tensorflow/core/common_runtime/gpu/pool_allocator.h
index 91ce830df8..310158aba1 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator.h
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator.h
@@ -181,7 +181,7 @@ class BasicCPUAllocator : public SubAllocator {
class CUDAHostAllocator : public SubAllocator {
public:
// Note: stream_exec cannot be null.
- explicit CUDAHostAllocator(perftools::gputools::StreamExecutor* stream_exec)
+ explicit CUDAHostAllocator(se::StreamExecutor* stream_exec)
: stream_exec_(stream_exec) {
CHECK(stream_exec_ != nullptr);
}
@@ -206,7 +206,7 @@ class CUDAHostAllocator : public SubAllocator {
}
private:
- perftools::gputools::StreamExecutor* stream_exec_; // not owned, non-null
+ se::StreamExecutor* stream_exec_; // not owned, non-null
TF_DISALLOW_COPY_AND_ASSIGN(CUDAHostAllocator);
};
diff --git a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
index 85555955e3..a4c8d5fe86 100644
--- a/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
+++ b/tensorflow/core/common_runtime/gpu/pool_allocator_test.cc
@@ -20,18 +20,16 @@ limitations under the License.
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/test.h"
-namespace gpu = ::perftools::gputools;
-
namespace tensorflow {
namespace {
TEST(PoolAllocatorTest, ZeroSizeBuffers) {
- gpu::Platform* platform =
- gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
PoolAllocator pool(
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
- platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
.ValueOrDie()),
new NoopRounder, "pool");
@@ -44,12 +42,12 @@ TEST(PoolAllocatorTest, ZeroSizeBuffers) {
}
TEST(PoolAllocatorTest, ZeroSizePool) {
- gpu::Platform* platform =
- gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
PoolAllocator pool(
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
- platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
.ValueOrDie()),
new NoopRounder, "pool");
@@ -77,12 +75,12 @@ TEST(PoolAllocatorTest, ZeroSizePool) {
}
TEST(PoolAllocatorTest, Alignment) {
- gpu::Platform* platform =
- gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
PoolAllocator pool(
0 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
- platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
.ValueOrDie()),
new NoopRounder, "pool");
for (int i = 0; i < 16; ++i) {
@@ -123,12 +121,12 @@ TEST(PoolAllocatorTest, AutoResize) {
}
TEST(PoolAllocatorTest, CudaHostAllocator) {
- gpu::Platform* platform =
- gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
PoolAllocator pool(
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
- platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
.ValueOrDie()),
new NoopRounder, "pool");
@@ -200,12 +198,12 @@ TEST(PoolAllocatorTest, Pow2Rounder) {
}
TEST(PoolAllocatorTest, Name) {
- gpu::Platform* platform =
- gpu::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
+ se::Platform* platform =
+ se::MultiPlatformManager::PlatformWithName("cuda").ValueOrDie();
PoolAllocator pool(
2 /*pool_size_limit*/, false /*auto_resize*/,
new CUDAHostAllocator(
- platform->GetExecutor(gpu::StreamExecutorConfig(/*ordinal=*/0))
+ platform->GetExecutor(se::StreamExecutorConfig(/*ordinal=*/0))
.ValueOrDie()),
new NoopRounder, "pool");
EXPECT_EQ("pool", pool.Name());
diff --git a/tensorflow/core/common_runtime/gpu/process_state.cc b/tensorflow/core/common_runtime/gpu/process_state.cc
index 866a03d046..5ed01278c1 100644
--- a/tensorflow/core/common_runtime/gpu/process_state.cc
+++ b/tensorflow/core/common_runtime/gpu/process_state.cc
@@ -146,7 +146,7 @@ Allocator* ProcessState::GetGPUAllocator(const GPUOptions& options,
// If there are any pending AllocVisitors for this bus, add
// them now.
- gpu::StreamExecutor* se =
+ se::StreamExecutor* se =
GpuIdUtil::ExecutorForTfGpuId(tf_gpu_id).ValueOrDie();
int bus_id = se->GetDeviceDescription().numa_node();
if (bus_id >= 0 && bus_id < static_cast<int64>(gpu_visitors_.size())) {
@@ -257,7 +257,7 @@ Allocator* ProcessState::GetCUDAHostAllocator(int numa_node) {
// better source of information about which executor to use. For
// example, process_state could maybe save the first stream executor
// it knows is valid.
- gpu::StreamExecutor* se = nullptr;
+ se::StreamExecutor* se = nullptr;
for (int i = 0; i < static_cast<int>(gpu_allocators_.size()); ++i) {
if (gpu_allocators_[i] != nullptr) {
se = GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
@@ -305,7 +305,7 @@ void ProcessState::AddGPUAllocVisitor(int bus_id, AllocVisitor visitor) {
#if GOOGLE_CUDA
mutex_lock lock(mu_);
for (int i = 0; i < static_cast<int64>(gpu_allocators_.size()); ++i) {
- gpu::StreamExecutor* se =
+ se::StreamExecutor* se =
GpuIdUtil::ExecutorForTfGpuId(TfGpuId(i)).ValueOrDie();
if (gpu_allocators_[i] &&
(se->GetDeviceDescription().numa_node() + 1) == bus_id) {
diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h
index 38a18cd087..c92c5d1af3 100644
--- a/tensorflow/core/common_runtime/gpu_device_context.h
+++ b/tensorflow/core/common_runtime/gpu_device_context.h
@@ -25,16 +25,13 @@ class Stream;
namespace tensorflow {
-// TODO(b/77980417): Replace stream_executor:: with se:: once our namespace
-// migration is complete and the alias is available.
-
class GPUDeviceContext : public DeviceContext {
public:
// Does not take ownership of streams.
- GPUDeviceContext(int stream_id, stream_executor::Stream* stream,
- stream_executor::Stream* host_to_device_stream,
- stream_executor::Stream* device_to_host_stream,
- stream_executor::Stream* device_to_device_stream)
+ GPUDeviceContext(int stream_id, se::Stream* stream,
+ se::Stream* host_to_device_stream,
+ se::Stream* device_to_host_stream,
+ se::Stream* device_to_device_stream)
: stream_id_(stream_id),
stream_(stream),
host_to_device_stream_(host_to_device_stream),
@@ -43,14 +40,10 @@ class GPUDeviceContext : public DeviceContext {
~GPUDeviceContext() override {}
- stream_executor::Stream* stream() const override { return stream_; }
- stream_executor::Stream* host_to_device_stream() const {
- return host_to_device_stream_;
- }
- stream_executor::Stream* device_to_host_stream() const {
- return device_to_host_stream_;
- }
- stream_executor::Stream* device_to_device_stream() const {
+ se::Stream* stream() const override { return stream_; }
+ se::Stream* host_to_device_stream() const { return host_to_device_stream_; }
+ se::Stream* device_to_host_stream() const { return device_to_host_stream_; }
+ se::Stream* device_to_device_stream() const {
return device_to_device_stream_;
}
int stream_id() const { return stream_id_; }
@@ -63,20 +56,20 @@ class GPUDeviceContext : public DeviceContext {
Device* device, Tensor* cpu_tensor,
StatusCallback done) override;
- void MaintainLifetimeOnStream(
- const Tensor* t, perftools::gputools::Stream* stream) const override {}
+ void MaintainLifetimeOnStream(const Tensor* t,
+ se::Stream* stream) const override {}
private:
int stream_id_;
// The default primary stream to use for this context.
// All the memory belongs to this stream.
- stream_executor::Stream* stream_;
+ se::Stream* stream_;
// The stream to use for copy data from host into GPU.
- stream_executor::Stream* host_to_device_stream_;
+ se::Stream* host_to_device_stream_;
// The stream to use for copy data from GPU to host.
- stream_executor::Stream* device_to_host_stream_;
+ se::Stream* device_to_host_stream_;
// The stream to use for copy data between GPU.
- stream_executor::Stream* device_to_device_stream_;
+ se::Stream* device_to_device_stream_;
};
} // namespace tensorflow
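With the signature change above, callers pass the four se::Stream pointers straight through. A tiny sketch of constructing the context; the factory function is assumed, and ownership stays with the caller, matching the "Does not take ownership of streams" note in the header:

    // Illustrative only; the real wiring lives in gpu_device.cc.
    GPUDeviceContext* MakeGpuContext(int stream_id, se::Stream* compute,
                                     se::Stream* h2d, se::Stream* d2h,
                                     se::Stream* d2d) {
      return new GPUDeviceContext(stream_id, compute, h2d, d2h, d2d);
    }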
diff --git a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
index 64d8849475..7de1b80e2d 100644
--- a/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
+++ b/tensorflow/core/common_runtime/kernel_benchmark_testlib.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/test_benchmark.h"
diff --git a/tensorflow/core/common_runtime/local_device.cc b/tensorflow/core/common_runtime/local_device.cc
index ca7f1614f1..873182371e 100644
--- a/tensorflow/core/common_runtime/local_device.cc
+++ b/tensorflow/core/common_runtime/local_device.cc
@@ -19,6 +19,7 @@ limitations under the License.
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/common_runtime/eigen_thread_pool.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_feature_guard.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.cc b/tensorflow/core/common_runtime/process_function_library_runtime.cc
index d05f146f21..e61ed8c479 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.cc
@@ -181,12 +181,7 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
const string& function_key, const string& device_name,
FunctionLibraryRuntime::LocalHandle local_handle) {
mutex_lock l(mu_);
- FunctionLibraryRuntime::Handle h =
- gtl::FindWithDefault(table_, function_key, kInvalidHandle);
- if (h != kInvalidHandle) {
- if (function_data_.count(h) != 0) return h;
- }
- h = next_handle_;
+ auto h = next_handle_;
FunctionData* fd = new FunctionData(device_name, local_handle);
function_data_[h] = std::unique_ptr<FunctionData>(fd);
table_[function_key] = h;
@@ -197,12 +192,7 @@ FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::AddHandle(
FunctionLibraryRuntime::Handle ProcessFunctionLibraryRuntime::GetHandle(
const string& function_key) const {
mutex_lock l(mu_);
- FunctionLibraryRuntime::Handle h =
- gtl::FindWithDefault(table_, function_key, kInvalidHandle);
- if (h != kInvalidHandle) {
- if (function_data_.count(h) == 0) return kInvalidHandle;
- }
- return h;
+ return gtl::FindWithDefault(table_, function_key, kInvalidHandle);
}
bool ProcessFunctionLibraryRuntime::IsInstantiatedOnDevice(
@@ -272,13 +262,6 @@ Status ProcessFunctionLibraryRuntime::Instantiate(
return Status::OK();
}
-Status ProcessFunctionLibraryRuntime::RemoveHandle(
- FunctionLibraryRuntime::Handle handle) {
- mutex_lock l(mu_);
- function_data_.erase(handle);
- return Status::OK();
-}
-
Status ProcessFunctionLibraryRuntime::ReleaseHandle(
FunctionLibraryRuntime::Handle handle) {
FunctionLibraryRuntime* flr = nullptr;
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime.h b/tensorflow/core/common_runtime/process_function_library_runtime.h
index c7b8259f78..05e5770899 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime.h
+++ b/tensorflow/core/common_runtime/process_function_library_runtime.h
@@ -134,9 +134,6 @@ class ProcessFunctionLibraryRuntime {
// of the device where the function is registered.
string GetDeviceName(FunctionLibraryRuntime::Handle handle);
- // Removes handle from the state owned by this object.
- Status RemoveHandle(FunctionLibraryRuntime::Handle handle);
-
Status Clone(Env* env, int graph_def_version,
const OptimizerOptions& optimizer_options,
CustomKernelCreator custom_kernel_creator,
diff --git a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
index 4fbf2abc67..cc10e77ad2 100644
--- a/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
+++ b/tensorflow/core/common_runtime/process_function_library_runtime_test.cc
@@ -119,12 +119,13 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
EXPECT_GE(call_count, 1); // Test runner is used.
- // Release the handle and then try running the function. It shouldn't
- // succeed.
+ // Release the handle and then try running the function. It
+ // should still succeed.
status = proc_flr_->ReleaseHandle(handle);
if (!status.ok()) {
return status;
}
+
Notification done2;
proc_flr_->Run(opts, handle, args, &out,
[&status, &done2](const Status& s) {
@@ -132,10 +133,7 @@ class ProcessFunctionLibraryRuntimeTest : public ::testing::Test {
done2.Notify();
});
done2.WaitForNotification();
- EXPECT_TRUE(errors::IsNotFound(status));
- EXPECT_TRUE(str_util::StrContains(status.error_message(), "not found."));
-
- return Status::OK();
+ return status;
}
std::vector<Device*> devices_;
diff --git a/tensorflow/core/common_runtime/process_util.cc b/tensorflow/core/common_runtime/process_util.cc
index 22fd940d82..f8f3a1ecd7 100644
--- a/tensorflow/core/common_runtime/process_util.cc
+++ b/tensorflow/core/common_runtime/process_util.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include <string.h>
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tracing.h"
diff --git a/tensorflow/core/distributed_runtime/master_session.cc b/tensorflow/core/distributed_runtime/master_session.cc
index e3022f38a2..83afc5b1a4 100644
--- a/tensorflow/core/distributed_runtime/master_session.cc
+++ b/tensorflow/core/distributed_runtime/master_session.cc
@@ -89,6 +89,10 @@ class MasterSession::ReffedClientGraph : public core::RefCounted {
~ReffedClientGraph() override {
if (should_deregister_) {
DeregisterPartitions();
+ } else {
+ for (Part& part : partitions_) {
+ worker_cache_->ReleaseWorker(part.name, part.worker);
+ }
}
}
@@ -1174,14 +1178,8 @@ Status MasterSession::Create(GraphDef* graph_def,
TF_RETURN_IF_ERROR(GraphExecutionState::MakeForBaseGraph(
graph_def, execution_options, &execution_state_));
}
- // TODO(b/36574172): Remove these conditions when ClusterSpec
- // propagation is supported in all servers.
- if (options.cluster_def != nullptr ||
- session_opts_.config.isolate_session_state()) {
- should_delete_worker_sessions_ = true;
- return CreateWorkerSessions(options);
- }
- return Status::OK();
+ should_delete_worker_sessions_ = true;
+ return CreateWorkerSessions(options);
}
Status MasterSession::CreateWorkerSessions(
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
index d004abd1c1..cde6b785dc 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.cc
@@ -30,7 +30,7 @@ limitations under the License.
namespace tensorflow {
-namespace {
+namespace internal {
class GrpcCall {
public:
explicit GrpcCall(CallContainer<GrpcCall>* container, int index, bool try_rpc,
@@ -57,9 +57,10 @@ class GrpcCall {
container_->Done(s, index_);
}
+ CallOptions* call_opts() { return &call_opts_; }
+ int index() { return index_; }
const string& request() const { return *request_msg_; }
string* response() const { return response_msg_; }
- CallOptions* call_opts() { return &call_opts_; }
private:
CallContainer<GrpcCall>* const container_;
@@ -72,7 +73,9 @@ class GrpcCall {
string* status_message_;
};
-} // namespace
+} // namespace internal
+
+using internal::GrpcCall;
GrpcRPCFactory::GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
int64 timeout_in_ms)
@@ -110,28 +113,6 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
Tensor* response_t, Tensor* status_code_t,
Tensor* status_message_t,
AsyncOpKernel::DoneCallback done) {
- auto address = address_t.flat<string>();
- auto method = method_t.flat<string>();
- auto request = request_t.flat<string>();
-
- // Stubs are maintained by the GrpcRPCFactory class and will be
- // deleted when the class is destroyed.
- ::grpc::GenericStub* singleton_stub = nullptr;
- if (address.size() == 1) {
- singleton_stub = GetOrCreateStubForAddress(address(0));
- }
- auto get_stub = [&address, this,
- singleton_stub](int64 ix) -> ::grpc::GenericStub* {
- return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
- : singleton_stub;
- };
- auto get_method_ptr = [&method](int64 ix) -> const string* {
- return (method.size() > 1) ? &(method(ix)) : &(method(0));
- };
- auto get_request_ptr = [&request](int64 ix) -> const string* {
- return (request.size() > 1) ? &(request(ix)) : &(request(0));
- };
-
if (try_rpc) {
// In this case status_code will never be set in the response,
// so we just set it to OK.
@@ -140,49 +121,22 @@ void GrpcRPCFactory::Call(OpKernelContext* ctx, int64 num_elements,
static_cast<int>(errors::Code::OK));
}
- CancellationManager* cm = ctx->cancellation_manager();
- CancellationToken cancellation_token = cm->get_cancellation_token();
-
- // This object will delete itself when done.
- auto* container =
- new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
- std::move(done), cancellation_token);
-
- auto response = response_t->flat<string>();
- int32* status_code_ptr = nullptr;
- string* status_message_ptr = nullptr;
- if (try_rpc) {
- status_code_ptr = status_code_t->flat<int32>().data();
- status_message_ptr = status_message_t->flat<string>().data();
- }
- for (int i = 0; i < num_elements; ++i) {
- container->calls()->emplace_back(
- container, i, try_rpc, get_request_ptr(i), &response(i),
- (try_rpc) ? &status_code_ptr[i] : nullptr,
- (try_rpc) ? &status_message_ptr[i] : nullptr);
- }
+ CallContainer<GrpcCall>::CreateCallFn create_call_fn =
+ [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
+ CallContainer<GrpcCall>* container, int index) {
+ CreateCall(request_t, try_rpc, index, container, response_t,
+ status_code_t, status_message_t);
+ };
- int i = 0;
- for (GrpcCall& call : *(container->calls())) {
- // This object will delete itself when done.
- new RPCState<string>(get_stub(i), &completion_queue_, *get_method_ptr(i),
- call.request(), call.response(),
- /*done=*/[&call](const Status& s) { call.Done(s); },
- call.call_opts(), fail_fast_, timeout_in_ms_);
- ++i;
- }
+ CallContainer<GrpcCall>::StartCallFn start_call_fn =
+ [this, &address_t, &method_t](GrpcCall* call) {
+ StartCall(address_t, method_t, call);
+ };
- // Need to register this callback after all the RPCs are in
- // flight; otherwise we may try to cancel an RPC *before* it
- // launches, which is a no-op, and then fall into a deadlock.
- bool is_cancelled = !cm->RegisterCallback(
- cancellation_token, [container]() { container->StartCancel(); });
-
- if (is_cancelled) {
- ctx->SetStatus(errors::Cancelled("Operation has been cancelled."));
- // container's reference counter will take care of calling done().
- container->StartCancel();
- }
+ // This object will delete itself when done.
+ new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
+ std::move(done), std::move(create_call_fn),
+ std::move(start_call_fn));
}
::grpc::GenericStub* GrpcRPCFactory::GetOrCreateStubForAddress(
@@ -210,4 +164,53 @@ GrpcRPCFactory::ChannelPtr GrpcRPCFactory::CreateChannelForAddress(
/*target=*/address, ::grpc::InsecureChannelCredentials(), args);
}
+void GrpcRPCFactory::CreateCall(const Tensor& request_t, const bool try_rpc,
+ int index, CallContainer<GrpcCall>* container,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t) {
+ auto request = request_t.flat<string>();
+ auto get_request_ptr = [&request](int64 ix) -> const string* {
+ return (request.size() > 1) ? &(request(ix)) : &(request(0));
+ };
+ auto response = response_t->flat<string>();
+ int32* status_code_ptr = nullptr;
+ string* status_message_ptr = nullptr;
+ if (try_rpc) {
+ status_code_ptr = status_code_t->flat<int32>().data();
+ status_message_ptr = status_message_t->flat<string>().data();
+ }
+ container->RegisterCall(container, index, try_rpc, get_request_ptr(index),
+ &response(index),
+ (try_rpc) ? &status_code_ptr[index] : nullptr,
+ (try_rpc) ? &status_message_ptr[index] : nullptr);
+}
+
+void GrpcRPCFactory::StartCall(const Tensor& address_t, const Tensor& method_t,
+ GrpcCall* call) {
+ auto address = address_t.flat<string>();
+ auto method = method_t.flat<string>();
+ // Stubs are maintained by the GrpcRPCFactory class and will be
+ // deleted when the class is destroyed.
+ ::grpc::GenericStub* singleton_stub = nullptr;
+ if (address.size() == 1) {
+ singleton_stub = GetOrCreateStubForAddress(address(0));
+ }
+ auto get_stub = [&address, this,
+ singleton_stub](int64 ix) -> ::grpc::GenericStub* {
+ return (address.size() > 1) ? GetOrCreateStubForAddress(address(ix))
+ : singleton_stub;
+ };
+ auto get_method_ptr = [&method](int64 ix) -> const string* {
+ return (method.size() > 1) ? &(method(ix)) : &(method(0));
+ };
+
+ int index = call->index();
+ // This object will delete itself when done.
+ new RPCState<string>(get_stub(index), &completion_queue_,
+ *get_method_ptr(index), call->request(),
+ call->response(),
+ /*done=*/[call](const Status& s) { call->Done(s); },
+ call->call_opts(), fail_fast_, timeout_in_ms_);
+}
+
} // namespace tensorflow
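The Call() rewrite above moves the per-element work into two callbacks handed to CallContainer: create_call_fn registers each GrpcCall with the container and start_call_fn launches it, while the container now manages its own lifetime and, presumably, the cancellation registration that used to live here. A condensed sketch of that shape, using only the names visible in the hunks:

    // Inside GrpcRPCFactory::Call(), once the input tensors are in hand.
    CallContainer<GrpcCall>::CreateCallFn create_call_fn =
        [this, &request_t, &try_rpc, response_t, status_code_t, status_message_t](
            CallContainer<GrpcCall>* container, int index) {
          CreateCall(request_t, try_rpc, index, container, response_t,
                     status_code_t, status_message_t);
        };

    CallContainer<GrpcCall>::StartCallFn start_call_fn =
        [this, &address_t, &method_t](GrpcCall* call) {
          StartCall(address_t, method_t, call);
        };

    // Deletes itself once every call has completed.
    new CallContainer<GrpcCall>(ctx, num_elements, fail_fast_, try_rpc,
                                std::move(done), std::move(create_call_fn),
                                std::move(start_call_fn));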
diff --git a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
index 34ec235aaf..29394c84b5 100644
--- a/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
+++ b/tensorflow/core/distributed_runtime/rpc/grpc_rpc_factory.h
@@ -20,10 +20,16 @@ limitations under the License.
#include "tensorflow/core/distributed_runtime/rpc/grpc_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/util/rpc/call_container.h"
#include "tensorflow/core/util/rpc/rpc_factory.h"
namespace tensorflow {
+// Forward declaration of GrpcCall.
+namespace internal {
+class GrpcCall;
+} // namespace internal
+
class GrpcRPCFactory : public RPCFactory {
public:
explicit GrpcRPCFactory(OpKernelConstruction* ctx, bool fail_fast,
@@ -42,6 +48,18 @@ class GrpcRPCFactory : public RPCFactory {
virtual ChannelPtr CreateChannelForAddress(const string& address);
private:
+  // Creates a call and registers it with the given `container`. The `index`
+  // is used to index into the tensor arguments.
+ void CreateCall(const Tensor& request_t, const bool try_rpc, int index,
+ CallContainer<internal::GrpcCall>* container,
+ Tensor* response_t, Tensor* status_code_t,
+ Tensor* status_message_t);
+
+ // Asynchronously invokes the given `call`. The call completion is handled
+ // by the call container the call was previously registered with.
+ void StartCall(const Tensor& address_t, const Tensor& method_t,
+ internal::GrpcCall* call);
+
::grpc::GenericStub* GetOrCreateStubForAddress(const string& address);
bool fail_fast_;
diff --git a/tensorflow/core/framework/bfloat16.h b/tensorflow/core/framework/bfloat16.h
index 968c18bdd2..2f79d0fa70 100644
--- a/tensorflow/core/framework/bfloat16.h
+++ b/tensorflow/core/framework/bfloat16.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_FRAMEWORK_BFLOAT16_H_
#include "tensorflow/core/framework/numeric_types.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
#if defined(PLATFORM_WINDOWS)
diff --git a/tensorflow/core/framework/collective.h b/tensorflow/core/framework/collective.h
index e440424aac..f6fe12e7ef 100644
--- a/tensorflow/core/framework/collective.h
+++ b/tensorflow/core/framework/collective.h
@@ -80,7 +80,7 @@ struct CollInstanceParams {
// Task name prefix of corresponding device name.
std::vector<string> task_names;
// True if every task has the same number of devices.
- bool same_num_devices_per_task;
+ bool same_num_devices_per_task = false;
CollImplDetails impl_details;
string ToString() const;
CollInstanceParams& operator=(const struct CollInstanceParams& other);
@@ -99,9 +99,9 @@ struct CollectiveParams {
CollInstanceParams instance;
CollTaskParams task;
- string name; // node name used only for log or error messages
- int default_rank; // index of this op within device_names
- bool is_source; // broadcast only
+ string name; // node name used only for log or error messages
+ int default_rank; // index of this op within device_names
+ bool is_source = false; // broadcast only
// Rank of this device in each subdivision permutation.
std::vector<int> subdiv_rank;
std::unique_ptr<OpKernel> merge_op; // reduction only
diff --git a/tensorflow/compiler/jit/graph_to_functiondef.cc b/tensorflow/core/framework/graph_to_functiondef.cc
index 8f5e11dfa4..4ffa503379 100644
--- a/tensorflow/compiler/jit/graph_to_functiondef.cc
+++ b/tensorflow/core/framework/graph_to_functiondef.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include <unordered_map>
#include <unordered_set>
@@ -111,7 +111,7 @@ string NodeNameMapping::Renormalize(const string& name) const {
} // anonymous namespace
// Graph to FunctionDef conversion. This code is closely modeled on the Python
-// code in third_party/tensorflow/python/framework/function.py.
+// code in tensorflow/python/framework/function.py.
Status GraphToFunctionDef(const Graph& graph, const string& name,
FunctionDef* fdef) {
diff --git a/tensorflow/compiler/jit/graph_to_functiondef.h b/tensorflow/core/framework/graph_to_functiondef.h
index 3e1ae7bbbe..cb0e2b2fbd 100644
--- a/tensorflow/compiler/jit/graph_to_functiondef.h
+++ b/tensorflow/core/framework/graph_to_functiondef.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#ifndef TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
-#define TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
+#define TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/graph/graph.h"
@@ -23,11 +23,10 @@ limitations under the License.
namespace tensorflow {
// Converts 'graph' to a FunctionDef 'fdef', with name 'name'.
-// Closely modeled on the Python code in
-// third_party/tensorflow/python/framework/function.py
+// Closely modeled on the Python code in tensorflow/python/framework/function.py
Status GraphToFunctionDef(const Graph& graph, const string& name,
FunctionDef* fdef);
} // namespace tensorflow
-#endif // TENSORFLOW_COMPILER_JIT_GRAPH_TO_FUNCTIONDEF_H_
+#endif // TENSORFLOW_CORE_FRAMEWORK_GRAPH_TO_FUNCTIONDEF_H_
diff --git a/tensorflow/compiler/jit/graph_to_functiondef_test.cc b/tensorflow/core/framework/graph_to_functiondef_test.cc
index 676db7c4dd..587e2c07ac 100644
--- a/tensorflow/compiler/jit/graph_to_functiondef_test.cc
+++ b/tensorflow/core/framework/graph_to_functiondef_test.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-#include "tensorflow/compiler/jit/graph_to_functiondef.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/cc/ops/function_ops.h"
diff --git a/tensorflow/core/framework/remote_fused_graph_execute_info.proto b/tensorflow/core/framework/remote_fused_graph_execute_info.proto
index 389a08ac2f..946da40d0e 100644
--- a/tensorflow/core/framework/remote_fused_graph_execute_info.proto
+++ b/tensorflow/core/framework/remote_fused_graph_execute_info.proto
@@ -14,14 +14,6 @@ import "tensorflow/core/framework/types.proto";
// not valid across executions, but can be serialized back and forth from within
// a single run.
message RemoteFusedGraphExecuteInfo {
- enum NodeType {
- UNUSED = 0;
- GRAPH_INPUT = 1;
- GRAPH_OUTPUT = 2;
- FUSED_NODE = 3;
- BORDER_INPUT = 4;
- BORDER_OUTPUT = 5;
- }
message TensorShapeTypeProto {
DataType dtype = 1;
diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h
new file mode 100644
index 0000000000..872b8f8b30
--- /dev/null
+++ b/tensorflow/core/framework/resource_var.h
@@ -0,0 +1,58 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
+#define TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
+
+#include "tensorflow/core/framework/resource_mgr.h"
+
+namespace tensorflow {
+
+// Resource stored by variables in the resource manager
+// (new, resource-style version).
+class Var : public ResourceBase {
+ public:
+ explicit Var(DataType dtype) : tensor_(dtype) {}
+ // Not copyable or movable.
+ Var(const Var&) = delete;
+ Var& operator=(const Var&) = delete;
+
+ // TODO(ebrevdo): Use LockSet instead of exposing mu.
+ mutex* mu() { return &mu_; }
+ Tensor* tensor() { return &tensor_; }
+
+ string DebugString() override {
+ return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
+ tensor_.shape().DebugString());
+ }
+
+  // Only used in the resource variable path. In resource variables,
+  // tensor.IsInitialized() can be true (i.e. the tensor has memory allocated
+  // to it) while it does not yet hold a valid value because of a race
+  // condition, and it is possible to observe that state from
+  // variable.initialized_value(). So it is best to store directly whether the
+  // variable is initialized.
+  bool is_initialized = false;  // GUARDED_BY(mu_) but annotalysis doesn't like
+                                // it.
+
+ private:
+ mutex mu_;
+ Tensor tensor_;
+
+ ~Var() override {}
+};
+
+} // end namespace tensorflow
+
+#endif // TENSORFLOW_CORE_FRAMEWORK_RESOURCE_VAR_H_
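The new header only declares the resource; a short sketch of how a kernel might read it, assuming the usual resource-manager plumbing (the helper name and error text are illustrative, only the Var API shown above is taken from the change):

    #include "tensorflow/core/framework/resource_var.h"
    #include "tensorflow/core/lib/core/errors.h"

    namespace tensorflow {

    // Copies the variable's value under its mutex, honoring the
    // is_initialized flag described in the comment above.
    Status CopyVarValue(Var* var, Tensor* out) {
      mutex_lock l(*var->mu());
      if (!var->is_initialized) {
        return errors::FailedPrecondition("variable has not been initialized");
      }
      *out = *var->tensor();
      return Status::OK();
    }

    }  // namespace tensorflow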
diff --git a/tensorflow/core/grappler/clusters/utils.cc b/tensorflow/core/grappler/clusters/utils.cc
index 50d6e6468f..a7519725a5 100644
--- a/tensorflow/core/grappler/clusters/utils.cc
+++ b/tensorflow/core/grappler/clusters/utils.cc
@@ -32,6 +32,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/mem.h"
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index ddbf7f3697..35f11eac29 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -42,6 +42,8 @@ cc_library(
deps = [
":utils",
"//tensorflow/core/grappler/utils:topological_sort",
+ "//tensorflow/core/grappler:graph_view",
+ "//tensorflow/core/grappler:op_types",
"//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index a0125ce342..313f63149d 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -19,10 +19,13 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
#include "tensorflow/core/framework/common_shape_fns.h"
+#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/graph_view.h"
+#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/strings/str_util.h"
@@ -155,7 +158,7 @@ struct Processor<DimensionHandle> {
template <typename Handle>
class DisjointSet {
public:
- DisjointSet(const Processor<Handle>& processor) : processor_(processor) {}
+ DisjointSet() {}
~DisjointSet() {
for (auto rep : nodes_) {
delete rep.second;
@@ -253,16 +256,16 @@ typename DisjointSet<Handle>::Rep* DisjointSet<Handle>::Find(Handle value) {
return root;
}
-bool IsQueue(const Node& node) {
- return str_util::EndsWith(node.type_string(), "QueueV2");
+bool IsQueue(const NodeDef& node) {
+ return str_util::EndsWith(node.op(), "QueueV2");
}
// Returns true if the node is an Enter op AND its input is a Queue.
-bool IsEnterWithQueue(const Node& node) {
- if (node.IsEnter()) {
- const Node* in_node;
- TF_CHECK_OK(node.input_node(0, &in_node));
- return IsQueue(*in_node);
+bool IsEnterWithQueue(const NodeDef& node, const GraphView& graph) {
+ if (IsEnter(node)) {
+ GraphView::InputPort input(&node, 0);
+ GraphView::OutputPort fanin = graph.GetRegularFanin(input);
+ return IsQueue(*fanin.node);
}
return false;
}
@@ -279,8 +282,9 @@ bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
return false;
}
+// This really should be done by an external debugging tool.
void VerboseLogUnknownDimensionSources(
- const Graph& graph,
+ const GraphDef& graph,
const std::map<string, std::vector<OpInfo::TensorProperties>>&
input_properties_map,
const std::map<string, std::vector<OpInfo::TensorProperties>>&
@@ -295,17 +299,13 @@ void VerboseLogUnknownDimensionSources(
// do not have any unknown dimensions in their inputs, but
// we have some unknown dimensions in their outputs.
std::map<string, int> op_to_count;
- for (const Node* const node : graph.nodes()) {
- if (node->num_outputs() == 0) {
- continue;
- }
-
- const auto& input_properties = input_properties_map.at(node->name());
- const auto& output_properties = output_properties_map.at(node->name());
+ for (const NodeDef& node : graph.node()) {
+ const auto& input_properties = input_properties_map.at(node.name());
+ const auto& output_properties = output_properties_map.at(node.name());
bool has_unknown_inputs = false;
- for (int i = 0; i < node->num_inputs(); ++i) {
- if (HasAnyUnknownDimensions(input_properties[i].shape())) {
+ for (const auto& input_prop : input_properties) {
+ if (HasAnyUnknownDimensions(input_prop.shape())) {
has_unknown_inputs = true;
break;
}
@@ -315,26 +315,24 @@ void VerboseLogUnknownDimensionSources(
continue;
}
- for (int i = 0; i < node->num_outputs(); ++i) {
- if (HasAnyUnknownDimensions(output_properties[i].shape())) {
+ for (const auto& output_prop : output_properties) {
+ if (HasAnyUnknownDimensions(output_prop.shape())) {
string inputs = "input_shapes=[";
- for (int i = 0; i < node->num_inputs(); ++i) {
- inputs +=
- PartialTensorShape::DebugString(input_properties[i].shape());
+ for (const auto& input_prop : input_properties) {
+ inputs += PartialTensorShape::DebugString(input_prop.shape());
}
inputs += "]";
string outputs = "output_shapes=[";
- for (int i = 0; i < node->num_outputs(); ++i) {
- outputs +=
- PartialTensorShape::DebugString(output_properties[i].shape());
+ for (const auto& output_prop : output_properties) {
+ outputs += PartialTensorShape::DebugString(output_prop.shape());
}
outputs += "]";
- VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op()
- << ", " << inputs << ", " << outputs;
+ VLOG(2) << "Node: " << node.name() << ", Op: " << node.op() << ", "
+ << inputs << ", " << outputs;
- op_to_count[node->def().op()]++;
+ op_to_count[node.op()]++;
// don't log again for this node
break;
@@ -357,13 +355,13 @@ void VerboseLogUnknownDimensionSources(
// information is refined.
class TopoQueue {
public:
- explicit TopoQueue(const std::unordered_map<const Node*, int>& topo_order)
+ explicit TopoQueue(const std::unordered_map<const NodeDef*, int>& topo_order)
: queue_(CompareNodes(topo_order)) {}
- void push(const Node* n) { queue_.insert(n); }
- const Node* pop() {
+ void push(const NodeDef* n) { queue_.insert(n); }
+ const NodeDef* pop() {
CHECK(!empty());
auto it = queue_.begin();
- const Node* n = *it;
+ const NodeDef* n = *it;
queue_.erase(it);
return n;
}
@@ -376,16 +374,16 @@ class TopoQueue {
// use their id to ensure they're sorted topologically.
struct CompareNodes {
explicit CompareNodes(
- const std::unordered_map<const Node*, int>& topo_ordering)
+ const std::unordered_map<const NodeDef*, int>& topo_ordering)
: topo_order(topo_ordering) {}
- bool operator()(const Node* lhs, const Node* rhs) const {
+ bool operator()(const NodeDef* lhs, const NodeDef* rhs) const {
return topo_order.at(lhs) < topo_order.at(rhs);
}
private:
- const std::unordered_map<const Node*, int>& topo_order;
+ const std::unordered_map<const NodeDef*, int>& topo_order;
};
- std::set<const Node*, CompareNodes> queue_;
+ std::set<const NodeDef*, CompareNodes> queue_;
};
// Merge and relax symbolic shapes.
@@ -396,22 +394,41 @@ class TopoQueue {
class SymbolicShapeRefiner {
public:
explicit SymbolicShapeRefiner(
- const GraphDef& graph,
+ const GraphView& graph,
const std::unordered_map<string, std::unordered_set<int>>& fed_ports)
- : function_library_(OpRegistry::Global(), graph.library()),
+ : graph_(graph),
+ function_library_(OpRegistry::Global(), graph.GetGraph()->library()),
fed_ports_(fed_ports) {
- graph_def_version_ = graph.versions().producer();
- node_to_context_.reserve(graph.node_size());
+ graph_def_version_ = graph.GetGraph()->versions().producer();
+ node_to_context_.reserve(graph.GetGraph()->node_size());
+ }
+
+ const GraphView& graph() const { return graph_; }
+
+ struct NodeContext {
+ const OpRegistrationData* op_data;
+ DataTypeVector input_types;
+ DataTypeVector output_types;
+ std::unique_ptr<InferenceContext> inference_context;
+ std::vector<ShapeHandle> output_tensors_as_shapes;
+ };
+
+ NodeContext* GetNodeContext(const NodeDef* node) {
+ auto it = node_to_context_.find(node);
+ if (it == node_to_context_.end()) {
+ return nullptr;
+ }
+ return &it->second;
}
- InferenceContext* GetContext(const Node* node) {
+ InferenceContext* GetContext(const NodeDef* node) {
auto it = node_to_context_.find(node);
if (it == node_to_context_.end()) {
return nullptr;
}
return it->second.inference_context.get();
}
- Status UpdateNode(const Node* node, bool relax, bool* refined) {
+ Status UpdateNode(const NodeDef* node, bool relax, bool* refined) {
NodeContext* node_context = GetNodeContext(node);
if (node_context == nullptr) {
TF_RETURN_IF_ERROR(AddNode(node));
@@ -421,82 +438,84 @@ class SymbolicShapeRefiner {
// Check if the shapes of the nodes in the fan-in of this node have changed,
// and if they have, update the node input shapes.
InferenceContext* inference_context = node_context->inference_context.get();
- std::vector<Tensor> const_values(node->num_inputs());
- std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
- std::vector<ShapeHandle> input_tensors_as_shapes(node->num_inputs());
-
- for (const Edge* e : node->in_edges()) {
- if (e->IsControlEdge()) continue;
-
- int dst_input = e->dst_input();
- int src_output = e->src_output();
-
- Node* input = e->src();
- NodeContext* c = GetNodeContext(input);
- if (c == nullptr) {
- return errors::FailedPrecondition(
- "Input ", dst_input, " ('", input->name(), "') for '", node->name(),
- "' was not previously added to ShapeRefiner.");
- }
+ std::vector<Tensor> const_values(inference_context->num_inputs());
+ std::vector<const Tensor*> input_tensors(inference_context->num_inputs(),
+ nullptr);
+ std::vector<ShapeHandle> input_tensors_as_shapes(
+ inference_context->num_inputs());
+
+ for (int dst_input = 0; dst_input < inference_context->num_inputs();
+ ++dst_input) {
+ GraphView::InputPort port(node, dst_input);
+ for (const GraphView::OutputPort fanin : graph_.GetFanin(port)) {
+ int src_output = fanin.port_id;
+ const NodeDef* input = fanin.node;
+ NodeContext* c = GetNodeContext(input);
+ if (c == nullptr) {
+ return errors::FailedPrecondition(
+ "Input ", dst_input, " ('", input->name(), "') for '",
+ node->name(), "' was not previously added to ShapeRefiner.");
+ }
- if (input->IsConstant()) {
- // Convert constant value into tensors.
- if (const_values[dst_input].FromProto(
- input->def().attr().at("value").tensor())) {
- input_tensors[dst_input] = &const_values[dst_input];
- // Integer tensors of rank one can also be interpreted as a shape
- // provided all their values are >= -1.
- if (const_values[dst_input].dims() == 1 &&
- (const_values[dst_input].dtype() == DT_INT32 ||
- const_values[dst_input].dtype() == DT_INT64)) {
- ShapeHandle tensor_shape = inference_context->Vector(
- const_values[dst_input].NumElements());
- ShapeHandle shp;
- if (inference_context
- ->MakeShapeFromTensor(input_tensors[dst_input],
- tensor_shape, &shp)
- .ok()) {
- input_tensors_as_shapes[dst_input] = shp;
+ if (IsConstant(*input)) {
+ // Convert constant value into tensors.
+ if (const_values[dst_input].FromProto(
+ input->attr().at("value").tensor())) {
+ input_tensors[dst_input] = &const_values[dst_input];
+ // Integer tensors of rank one can also be interpreted as a shape
+ // provided all their values are >= -1.
+ if (const_values[dst_input].dims() == 1 &&
+ (const_values[dst_input].dtype() == DT_INT32 ||
+ const_values[dst_input].dtype() == DT_INT64)) {
+ ShapeHandle tensor_shape = inference_context->Vector(
+ const_values[dst_input].NumElements());
+ ShapeHandle shp;
+ if (inference_context
+ ->MakeShapeFromTensor(input_tensors[dst_input],
+ tensor_shape, &shp)
+ .ok()) {
+ input_tensors_as_shapes[dst_input] = shp;
+ }
}
}
}
- }
- if (c->output_tensors_as_shapes.size() > src_output) {
- input_tensors_as_shapes[dst_input] =
- c->output_tensors_as_shapes[src_output];
- }
-
- DCHECK_GE(dst_input, 0);
- if (!*refined && !inference_context->input(dst_input).SameHandle(
- c->inference_context->output(src_output))) {
- *refined = true;
- }
- inference_context->SetInput(dst_input,
- c->inference_context->output(src_output));
-
- if (!*refined &&
- inference_context->requested_input_tensor_as_partial_shape(
- dst_input)) {
- // The input value may have changed. Since we have no way to know if
- // that's indeed the case, err on the safe side.
- *refined = true;
- }
-
- // Also propagate handle shape and dtype of edges which are carrying
- // resource handles.
- if (e->src()->output_type(src_output) == DT_RESOURCE) {
- auto* outputs =
- c->inference_context->output_handle_shapes_and_types(src_output);
- if (!outputs) continue;
- auto* inputs =
- inference_context->input_handle_shapes_and_types(dst_input);
+ if (c->output_tensors_as_shapes.size() > src_output) {
+ input_tensors_as_shapes[dst_input] =
+ c->output_tensors_as_shapes[src_output];
+ }
- if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs)) {
+ DCHECK_GE(dst_input, 0);
+ if (!*refined && !inference_context->input(dst_input).SameHandle(
+ c->inference_context->output(src_output))) {
+ *refined = true;
+ }
+ inference_context->SetInput(dst_input,
+ c->inference_context->output(src_output));
+
+ if (!*refined &&
+ inference_context->requested_input_tensor_as_partial_shape(
+ dst_input)) {
+ // The input value may have changed. Since we have no way to know if
+ // that's indeed the case, err on the safe side.
*refined = true;
}
- inference_context->set_input_handle_shapes_and_types(dst_input,
- *outputs);
+
+ // Also propagate handle shape and dtype of edges which are carrying
+ // resource handles.
+ if (node_context->input_types[dst_input] == DT_RESOURCE) {
+ auto* outputs =
+ c->inference_context->output_handle_shapes_and_types(src_output);
+ if (!outputs) continue;
+ auto* inputs =
+ inference_context->input_handle_shapes_and_types(dst_input);
+
+ if (!inputs || !EquivalentShapesAndTypes(*outputs, *inputs)) {
+ *refined = true;
+ }
+ inference_context->set_input_handle_shapes_and_types(dst_input,
+ *outputs);
+ }
}
}
@@ -510,10 +529,10 @@ class SymbolicShapeRefiner {
input_tensors_as_shapes);
// Update the shapes of the outputs.
- return InferShapes(node, node_context);
+ return InferShapes(*node, node_context);
}
- Status SetUnknownShape(const Node* node, int output_port) {
+ Status SetUnknownShape(const NodeDef* node, int output_port) {
shape_inference::ShapeHandle shape =
GetUnknownOutputShape(node, output_port);
InferenceContext* ctx = GetContext(node);
@@ -525,7 +544,7 @@ class SymbolicShapeRefiner {
}
struct ShapeId {
- const Node* node;
+ const NodeDef* node;
int port_id;
bool operator==(const ShapeId& other) const {
return node == other.node && port_id == other.port_id;
@@ -533,12 +552,12 @@ class SymbolicShapeRefiner {
};
struct HashShapeId {
std::size_t operator()(const ShapeId& shp) const {
- return std::hash<const Node*>{}(shp.node) + shp.port_id;
+ return std::hash<const NodeDef*>{}(shp.node) + shp.port_id;
}
};
struct DimId {
- const Node* node;
+ const NodeDef* node;
int port_id;
int dim_index;
bool operator==(const DimId& other) const {
@@ -549,13 +568,14 @@ class SymbolicShapeRefiner {
struct HashDimId {
std::size_t operator()(const DimId& dim) const {
- return std::hash<const Node*>{}(dim.node) + dim.port_id + dim.dim_index;
+ return std::hash<const NodeDef*>{}(dim.node) + dim.port_id +
+ dim.dim_index;
}
};
// Compute the shape of the tensors output by node 'node' at output port
// 'port_index' as the intersection of shape1 and shape2.
- ShapeHandle OutputAsIntersection(const Node* node, int port_index,
+ ShapeHandle OutputAsIntersection(const NodeDef* node, int port_index,
ShapeHandle shape1, ShapeHandle shape2) {
if (shape1.SameHandle(shape2)) {
return shape1;
@@ -600,7 +620,7 @@ class SymbolicShapeRefiner {
// Compute the shape of the tensors output by node 'node' at output port
// 'port_index' as the union of shape1 and shape2.
- ShapeHandle OutputAsUnion(const Node* node, int port_index,
+ ShapeHandle OutputAsUnion(const NodeDef* node, int port_index,
ShapeHandle shape1, ShapeHandle shape2) {
if (shape1.SameHandle(shape2)) {
return shape1;
@@ -670,20 +690,24 @@ class SymbolicShapeRefiner {
return true;
}
- Status AddNode(const Node* node) {
+ Status AddNode(const NodeDef* node) {
+ NodeContext& node_ctx = node_to_context_[node];
+ TF_RETURN_IF_ERROR(function_library_.LookUp(node->op(), &node_ctx.op_data));
+
+ TF_RETURN_IF_ERROR(InOutTypesForNode(*node, node_ctx.op_data->op_def,
+ &node_ctx.input_types,
+ &node_ctx.output_types));
+
// Create the inference context for this node.
- std::vector<ShapeHandle> input_shapes(node->num_inputs());
+ const int num_inputs = node_ctx.input_types.size();
+ std::vector<ShapeHandle> input_shapes(num_inputs);
std::vector<std::unique_ptr<std::vector<ShapeAndType>>>
- input_handle_shapes_and_types(node->num_inputs());
- std::vector<const Tensor*> input_tensors(node->num_inputs(), nullptr);
+ input_handle_shapes_and_types(num_inputs);
+ std::vector<const Tensor*> input_tensors(num_inputs, nullptr);
std::vector<ShapeHandle> input_tensors_as_shapes;
- NodeContext& node_ctx = node_to_context_[node];
- TF_RETURN_IF_ERROR(
- function_library_.LookUp(node->type_string(), &node_ctx.op_data));
-
node_ctx.inference_context.reset(new InferenceContext(
- graph_def_version_, &node->def(), node->op_def(), input_shapes,
+ graph_def_version_, node, node_ctx.op_data->op_def, input_shapes,
input_tensors, input_tensors_as_shapes,
std::move(input_handle_shapes_and_types)));
const Status s = node_ctx.inference_context->construction_status();
@@ -696,7 +720,7 @@ class SymbolicShapeRefiner {
private:
// Return the one ShapeHandle used to denote a fully unknown shape for a node
// output.
- ShapeHandle GetUnknownOutputShape(const Node* node, int index) {
+ ShapeHandle GetUnknownOutputShape(const NodeDef* node, int index) {
ShapeId id{node, index};
auto it = unknown_shapes_.find(id);
if (it != unknown_shapes_.end()) {
@@ -709,7 +733,8 @@ class SymbolicShapeRefiner {
}
// Return the one ShapeHandle used to denote a fully unknown dimension for a
// node output.
- DimensionHandle GetUnknownOutputDim(const Node* node, int index, int dim_id) {
+ DimensionHandle GetUnknownOutputDim(const NodeDef* node, int index,
+ int dim_id) {
DimId id{node, index, dim_id};
auto it = unknown_dims_.find(id);
if (it != unknown_dims_.end()) {
@@ -721,31 +746,25 @@ class SymbolicShapeRefiner {
return dim;
}
- struct NodeContext {
- const OpRegistrationData* op_data;
- std::unique_ptr<InferenceContext> inference_context;
- std::vector<ShapeHandle> output_tensors_as_shapes;
- };
-
- Status InferShapes(const Node* node, NodeContext* c) {
+ Status InferShapes(const NodeDef& node, NodeContext* c) {
InferenceContext* ic = c->inference_context.get();
- auto it = fed_ports_.find(node->name());
+ auto it = fed_ports_.find(node.name());
const bool is_fed = it != fed_ports_.end();
// Propagate shape tensors unless the node is fed.
// TODO(bsteiner) We should still propagate the shapes to the ports that
// aren't fed in the case of a ShapeN node.
if (!is_fed) {
- if (node->type_string() == "Shape") {
+ if (IsShape(node)) {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = c->inference_context->input(0);
- } else if (node->type_string() == "ShapeN") {
+ } else if (IsShapeN(node)) {
c->output_tensors_as_shapes.resize(c->inference_context->num_inputs());
for (int i = 0; i < c->inference_context->num_inputs(); ++i) {
c->output_tensors_as_shapes[i] = c->inference_context->input(i);
}
- } else if (node->type_string() == "ConcatV2") {
+ } else if (node.op() == "ConcatV2") {
bool valid = true;
ShapeHandle result;
for (int i = 0; i < ic->num_inputs() - 1; ++i) {
@@ -763,7 +782,7 @@ class SymbolicShapeRefiner {
c->output_tensors_as_shapes.resize(1);
c->output_tensors_as_shapes[0] = result;
}
- } else if (node->type_string() == "Slice") {
+ } else if (IsSlice(node)) {
ShapeHandle input = ic->input_tensors_as_shapes()[0];
bool valid = ic->RankKnown(input);
const Tensor* slice_offset = ic->input_tensor(1);
@@ -800,22 +819,16 @@ class SymbolicShapeRefiner {
// It is possible to feed node output ports with tensors of any shape: as
// a result, the shape of a fed port is completely unknown.
for (const int output_port : it->second) {
- status.Update(SetUnknownShape(node, output_port));
+ status.Update(SetUnknownShape(&node, output_port));
}
}
return status;
}
- NodeContext* GetNodeContext(const Node* node) {
- auto it = node_to_context_.find(node);
- if (it == node_to_context_.end()) {
- return nullptr;
- }
- return &it->second;
- }
-
+ private:
+ const GraphView& graph_;
int graph_def_version_;
- std::unordered_map<const Node*, NodeContext> node_to_context_;
+ std::unordered_map<const NodeDef*, NodeContext> node_to_context_;
std::unordered_map<ShapeId, ShapeHandle, HashShapeId> unknown_shapes_;
std::unordered_map<DimId, DimensionHandle, HashDimId> unknown_dims_;
FunctionLibraryDefinition function_library_;
@@ -827,7 +840,7 @@ class SymbolicShapeRefiner {
// dims, and consolidate the information globally.
class SymbolicShapeManager {
public:
- SymbolicShapeManager() : shapes_(shape_processor_), dims_(dim_processor_) {}
+ SymbolicShapeManager() {}
Status Merge(ShapeHandle s1, ShapeHandle s2) {
if (!s1.IsSet() || !s2.IsSet()) {
@@ -867,14 +880,12 @@ class SymbolicShapeManager {
}
private:
- Processor<ShapeHandle> shape_processor_;
DisjointSet<shape_inference::ShapeHandle> shapes_;
- Processor<DimensionHandle> dim_processor_;
DisjointSet<shape_inference::DimensionHandle> dims_;
};
Status GraphProperties::MergeEnqueueShapesAndTypes(
- SymbolicShapeRefiner* shape_refiner, const Node* qnode,
+ SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<ShapeAndType>& shapes_and_types,
std::vector<ShapeAndType>* queue_shapes_and_types) {
if (shapes_and_types.size() != queue_shapes_and_types->size()) {
@@ -897,7 +908,7 @@ Status GraphProperties::MergeEnqueueShapesAndTypes(
}
Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
- SymbolicShapeRefiner* shape_refiner, const Node* qnode,
+ SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<ShapeAndType>& shapes_and_types,
std::vector<ShapeAndType>* queue_shapes_and_types) {
if (shapes_and_types.size() != queue_shapes_and_types->size()) {
@@ -925,7 +936,7 @@ Status GraphProperties::RelaxEnqueueShapesAndMergeTypes(
// inputs are UnknownShapes. So we need to ignore the input from NextIteration
// nodes to propagate any known shape from the Merge node.
Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
- const Node* node, bool relax,
+ const NodeDef* node, bool relax,
bool* new_shapes) const {
InferenceContext* c = shape_refiner->GetContext(node);
if (!c) {
@@ -942,25 +953,24 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
ShapeHandle out;
bool out_initialized = false;
- for (const Edge* e : node->in_edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
+ for (const GraphView::Edge fanin :
+ shape_refiner->graph().GetFaninEdges(*node, false)) {
// Skip back edges during the initial propagation phase. This is equivalent
// to assuming that all the inputs to the merge nodes are fed by the same
// shape, and will be corrected as needed in the relaxation phase.
- if (!relax && e->src()->IsNextIteration()) {
+ if (!relax && IsNextIteration(*fanin.src.node)) {
continue;
}
- InferenceContext* in = shape_refiner->GetContext(e->src());
+ InferenceContext* in = shape_refiner->GetContext(fanin.src.node);
if (!relax && !in) {
// Handling a loop for the first time, the back edge won't have any shape
// info.
continue;
}
- ShapeHandle input = in->output(e->src_output());
- c->SetInput(e->dst_input(), input);
+ ShapeHandle input = in->output(fanin.src.port_id);
+ CHECK_EQ(fanin.tgt.node, node);
+ c->SetInput(fanin.tgt.port_id, input);
if (!out_initialized) {
out_initialized = true;
out = input;
@@ -984,7 +994,7 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
// Manually propagate the input shape for Enter nodes and update any Merge node
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
- const Node* node, bool relax,
+ const NodeDef* node, bool relax,
bool* new_shapes) {
auto enter_ctx = shape_refiner->GetContext(node);
if (!enter_ctx) {
@@ -992,33 +1002,27 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
enter_ctx = shape_refiner->GetContext(node);
}
- for (const Edge* e : node->in_edges()) {
- if (e->IsControlEdge()) {
- continue;
- }
- InferenceContext* in = shape_refiner->GetContext(e->src());
- ShapeHandle input = in->output(e->src_output());
- if (!enter_ctx->output(0).SameHandle(input)) {
- if (relax) {
- enter_ctx->RelaxInput(0, input);
- } else {
- enter_ctx->MergeInput(0, input);
- }
- enter_ctx->set_output(0, input);
- *new_shapes = true;
- }
+ GraphView::InputPort inp(node, 0);
+ GraphView::OutputPort fanin = shape_refiner->graph().GetRegularFanin(inp);
+
+ InferenceContext* in = shape_refiner->GetContext(fanin.node);
+ ShapeHandle input = in->output(fanin.port_id);
+ if (!enter_ctx->output(0).SameHandle(input)) {
+ enter_ctx->SetInput(0, input);
+ enter_ctx->set_output(0, input);
+ *new_shapes = true;
}
return Status::OK();
}
-Status GraphProperties::UpdateShapes(
- SymbolicShapeRefiner* shape_refiner, bool relax,
- const Node* n, bool* new_shapes) const {
- if (n->IsEnter()) {
+Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
+ bool relax, const NodeDef* n,
+ bool* new_shapes) const {
+ if (IsEnter(*n)) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
TF_RETURN_IF_ERROR(UpdateEnter(shape_refiner, n, relax, new_shapes));
- } else if (n->IsMerge()) {
+ } else if (IsMerge(*n)) {
// Properly handle merge nodes.
TF_RETURN_IF_ERROR(UpdateMergeNode(shape_refiner, n, relax, new_shapes));
} else {
@@ -1028,7 +1032,7 @@ Status GraphProperties::UpdateShapes(
if (updated) {
// We want to avoid propagating through loops on the merge pass because
// the shapes are not guaranteed to converge.
- if (relax || !n->IsNextIteration()) {
+ if (relax || !IsNextIteration(*n)) {
*new_shapes = true;
}
}
@@ -1039,8 +1043,8 @@ Status GraphProperties::UpdateShapes(
// Propagates the shapes in the transitive fan-out of <new_shapes>.
Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
- const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
- resources,
+ const std::unordered_map<const NodeDef*,
+ std::unordered_set<const NodeDef*>>& resources,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most
@@ -1062,15 +1066,13 @@ Status GraphProperties::PropagateShapes(
int64 num_loop_iterations = 0;
while (!new_shapes->empty() &&
num_loop_iterations++ < max_loop_iterations) {
- const Node* n = new_shapes->pop();
+ const NodeDef* n = new_shapes->pop();
bool updated = false;
TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, n, &updated));
if (updated) {
- for (const Edge* e : n->out_edges()) {
- if (!e->IsControlEdge()) {
- const Node* fanout = e->dst();
- new_shapes->push(fanout);
- }
+ for (const GraphView::InputPort fanout :
+ shape_refiner->graph().GetFanouts(*n, false)) {
+ new_shapes->push(fanout.node);
}
}
}
@@ -1080,7 +1082,7 @@ Status GraphProperties::PropagateShapes(
// fanout of the queues, we need to manually propagate the shapes from
// enqueue node to the corresponding queue.
TF_RETURN_IF_ERROR(UpdateResource(resource.first, resource.second,
- shape_refiner, relax, new_shapes));
+ shape_refiner, new_shapes));
}
} while (!new_shapes->empty() &&
num_resource_iterations++ < max_resource_iterations);
@@ -1093,10 +1095,11 @@ Status GraphProperties::PropagateShapes(
}
Status GraphProperties::UpdateResource(
- const Node* qnode, const std::unordered_set<const Node*>& queue_inputs,
- SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes) {
+ const NodeDef* qnode,
+ const std::unordered_set<const NodeDef*>& queue_inputs,
+ SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes) {
// Proceed only if qnode is a queue or an Enter with queue input.
- if (!IsQueue(*qnode) && !IsEnterWithQueue(*qnode)) {
+ if (!IsQueue(*qnode) && !IsEnterWithQueue(*qnode, shape_refiner->graph())) {
return Status::OK();
}
auto qctx = shape_refiner->GetContext(qnode);
@@ -1108,31 +1111,24 @@ Status GraphProperties::UpdateResource(
// Merge all inputs into the enqueue node, regardless of which phase we
// are in.
std::vector<ShapeAndType> queue_shapes_and_types;
- if (queue_handle_data) {
- queue_shapes_and_types = *queue_handle_data;
- }
for (const auto& node : queue_inputs) {
- auto ctx = shape_refiner->GetContext(node);
+ auto ctx = shape_refiner->GetNodeContext(node);
if (!ctx) {
continue;
}
// TODO(bsteiner): handle EnqueueMany as well.
- if (node->type_string().find("Enqueue") != std::string::npos &&
- node->type_string().find("EnqueueMany") == std::string::npos) {
+ if (node->op().find("Enqueue") != std::string::npos &&
+ node->op().find("EnqueueMany") == std::string::npos) {
std::vector<ShapeAndType> shapes_and_types;
- for (int i = 1; i < ctx->num_inputs(); ++i) {
- shapes_and_types.push_back({ctx->input(i), node->input_type(i)});
+ for (int i = 1; i < ctx->input_types.size(); ++i) {
+ shapes_and_types.push_back(
+ {ctx->inference_context->input(i), ctx->input_types[i]});
}
if (queue_shapes_and_types.empty()) {
queue_shapes_and_types = shapes_and_types;
} else {
- if (relax) {
- TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
- shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
- } else {
- TF_RETURN_IF_ERROR(MergeEnqueueShapesAndTypes(
- shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
- }
+ TF_RETURN_IF_ERROR(RelaxEnqueueShapesAndMergeTypes(
+ shape_refiner, qnode, shapes_and_types, &queue_shapes_and_types));
}
}
}
@@ -1142,11 +1138,9 @@ Status GraphProperties::UpdateResource(
queue_shapes_and_types)) {
qctx->set_output_handle_shapes_and_types(0, queue_shapes_and_types);
- for (const Edge* e : qnode->out_edges()) {
- if (!e->IsControlEdge()) {
- const Node* fanout = e->dst();
- new_shapes->push(fanout);
- }
+ for (const GraphView::InputPort fanout :
+ shape_refiner->graph().GetFanouts(*qnode, false)) {
+ new_shapes->push(fanout.node);
}
}
@@ -1156,18 +1150,6 @@ Status GraphProperties::UpdateResource(
Status GraphProperties::InferStatically(bool assume_valid_feeds) {
FunctionLibraryDefinition function_library(OpRegistry::Global(),
item_.graph.library());
- Graph graph(function_library);
- graph_ = &graph;
- ImportGraphDefOptions options;
- // Graph optimization happens at the late stage of graph execution,
- // when colocation constraints are already validated previously and
- // the device placement of nodes has also completed, so there
- // is no need to validate colocation constraints again.
- options.validate_colocation_constraints = false;
- options.validate_shape = false;
- Status s = ImportGraphDef(options, item_.graph, &graph, nullptr);
- TF_RETURN_IF_ERROR(s);
-
std::unordered_map<string, std::unordered_set<int>> fed_ports;
if (!assume_valid_feeds) {
for (const auto& feed : item_.feed) {
@@ -1180,46 +1162,45 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
std::unordered_map<const NodeDef*, int> topo_order;
TF_RETURN_IF_ERROR(ComputeTopologicalOrder(item_.graph, &topo_order));
- std::unordered_map<string, int> order_by_name;
- for (const auto topo : topo_order) {
- order_by_name[topo.first->name()] = topo.second;
- }
+ GraphView graph_view(&item_.graph);
- // List the resources and the nodes using them. Also collect the Enter and
- // Merge nodes.
- std::unordered_map<const Node*, int> graph_topo_order;
- std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
- std::unordered_set<const Node*> merge_nodes;
- std::unordered_set<const Node*> fed_nodes;
- std::unordered_set<const Node*> primary_inputs;
+ // List the resources and the nodes using them. Also collect the Merge nodes,
+ // fed nodes, and primary inputs.
+ std::unordered_map<const NodeDef*, std::unordered_set<const NodeDef*>>
+ resources;
+ std::unordered_set<const NodeDef*> merge_nodes;
+ std::unordered_set<const NodeDef*> fed_nodes;
+ std::unordered_set<const NodeDef*> primary_inputs;
int num_loops = 0;
- for (const Node* const node : graph.nodes()) {
- auto it = order_by_name.find(node->name());
- if (it == order_by_name.end()) {
- continue;
- }
- graph_topo_order[node] = it->second;
-
- for (int i = 0; i < node->num_inputs(); ++i) {
- if (node->input_type(i) == DataType::DT_RESOURCE) {
- const Node* resource;
- TF_CHECK_OK(node->input_node(i, &resource));
- resources[resource].insert(node);
- }
- }
- if (node->num_inputs() == 0) {
- primary_inputs.insert(node);
- } else if (node->IsMerge()) {
- merge_nodes.insert(node);
- } else if (node->IsNextIteration()) {
+ for (const NodeDef& node : item_.graph.node()) {
+ if (NumNonControlInputs(node) == 0) {
+ primary_inputs.insert(&node);
+ } else if (IsMerge(node)) {
+ merge_nodes.insert(&node);
+ } else if (IsNextIteration(node)) {
++num_loops;
+ } else {
+ const OpRegistrationData* op_data;
+ TF_RETURN_IF_ERROR(function_library.LookUp(node.op(), &op_data));
+ DataTypeVector input_types;
+ DataTypeVector output_types;
+ TF_RETURN_IF_ERROR(InOutTypesForNode(node, op_data->op_def, &input_types,
+ &output_types));
+ for (int i = 0; i < input_types.size(); ++i) {
+ if (input_types[i] == DataType::DT_RESOURCE) {
+ GraphView::InputPort input(&node, i);
+ const GraphView::OutputPort resource =
+ graph_view.GetRegularFanin(input);
+ resources[resource.node].insert(&node);
+ }
+ }
}
- if (fed_ports.find(node->name()) != fed_ports.end()) {
- fed_nodes.insert(node);
+ if (fed_ports.find(node.name()) != fed_ports.end()) {
+ fed_nodes.insert(&node);
}
}
- SymbolicShapeRefiner refiner(item_.graph, fed_ports);
+ SymbolicShapeRefiner refiner(graph_view, fed_ports);
// We propagate shapes through the graph in two phases. In the first phase, we
// exclusively merge shapes but we do not propagate shapes through the
@@ -1227,19 +1208,19 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// we exclusively relax shapes and propagate shapes through loops until
// reaching fixed point.
for (int relax = 0; relax < 2; relax++) {
- TopoQueue new_shapes(graph_topo_order);
+ TopoQueue new_shapes(topo_order);
// Seed the propagation of shapes through merge nodes.
if (relax) {
- for (const Node* node : merge_nodes) {
+ for (const NodeDef* node : merge_nodes) {
new_shapes.push(node);
}
}
// Also seed the propagation of shapes in the fanout of primary inputs.
- for (const Node* node : primary_inputs) {
+ for (const NodeDef* node : primary_inputs) {
new_shapes.push(node);
}
// Also seed the propagation of shapes in the fanout of fed nodes.
- for (const Node* node : fed_nodes) {
+ for (const NodeDef* node : fed_nodes) {
new_shapes.push(node);
}
// Propagate shapes normally.
@@ -1250,14 +1231,14 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
// Track shapes globally across the graph.
SymbolicShapeManager shape_manager;
bool found_error = false;
- for (const Node* const node : graph.nodes()) {
- auto node_ctx = refiner.GetContext(node);
+ for (const NodeDef& node : item_.graph.node()) {
+ auto node_ctx = refiner.GetContext(&node);
if (!node_ctx) {
continue;
}
// Skip any information that comes from fed nodes.
- if (fed_ports.find(node->name()) != fed_ports.end()) {
- VLOG(2) << "Skipping feed node shape: " << node->name();
+ if (fed_ports.find(node.name()) != fed_ports.end()) {
+ VLOG(2) << "Skipping feed node shape: " << node.name();
continue;
}
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
@@ -1281,61 +1262,56 @@ Status GraphProperties::InferStatically(bool assume_valid_feeds) {
}
}
- for (const Node* const node : graph.nodes()) {
- VLOG(3) << "Filling in graph properties for node: " << node->name();
- auto ctx = refiner.GetContext(node);
+ for (const NodeDef& node : item_.graph.node()) {
+ VLOG(3) << "Filling in graph properties for node: " << node.name();
+ auto ctx = refiner.GetNodeContext(&node);
if (!ctx) {
continue;
}
// Fill input properties.
{
- CHECK_EQ(ctx->num_inputs(), node->num_inputs());
- auto& input_properties = input_properties_[node->name()];
+ // CHECK_EQ(ctx->num_inputs(), node.num_inputs());
+ auto& input_properties = input_properties_[node.name()];
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(input_properties.size(), 0);
- input_properties.resize(ctx->num_inputs());
- for (int i = 0; i < ctx->num_inputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->input(i), node->input_type(i),
+ input_properties.resize(ctx->inference_context->num_inputs());
+ GraphView::InputPort input(&node, -1);
+ for (int i = 0; i < ctx->inference_context->num_inputs(); ++i) {
+ shape_manager.AsTensorProperties(ctx->inference_context->input(i),
+ ctx->input_types[i],
&input_properties[i]);
- }
- for (const auto& edge : node->in_edges()) {
- if (edge->IsControlEdge()) {
- continue;
- }
- if (!edge->src()->IsConstant()) {
- continue;
- }
- const int input_id = edge->dst_input();
- if (input_id >= input_properties.size()) {
+ input.port_id = i;
+ GraphView::OutputPort fanin = graph_view.GetRegularFanin(input);
+ if (!IsConstant(*fanin.node)) {
continue;
}
- const NodeDef& node = edge->src()->def();
- const TensorProto& raw_val = node.attr().at("value").tensor();
- *input_properties[input_id].mutable_value() = raw_val;
+ const TensorProto& raw_val = fanin.node->attr().at("value").tensor();
+ *input_properties[i].mutable_value() = raw_val;
}
}
// Fill output properties.
{
- CHECK_EQ(ctx->num_outputs(), node->num_outputs());
- auto& output_properties = output_properties_[node->name()];
+ // CHECK_EQ(ctx->num_outputs(), node->num_outputs());
+ auto& output_properties = output_properties_[node.name()];
// Should always be empty, node names in graph are supposed to be unique.
CHECK_EQ(output_properties.size(), 0);
- output_properties.resize(ctx->num_outputs());
- for (int i = 0; i < ctx->num_outputs(); ++i) {
- shape_manager.AsTensorProperties(ctx->output(i), node->output_type(i),
+ output_properties.resize(ctx->inference_context->num_outputs());
+ for (int i = 0; i < ctx->inference_context->num_outputs(); ++i) {
+ shape_manager.AsTensorProperties(ctx->inference_context->output(i),
+ ctx->output_types[i],
&output_properties[i]);
}
}
}
// Help trace the unknown dimensions to their origins.
- VerboseLogUnknownDimensionSources(graph, input_properties_,
+ VerboseLogUnknownDimensionSources(item_.graph, input_properties_,
output_properties_);
return Status::OK();
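
To make the merge/relax distinction above concrete, here is a simplified scalar model (not part of the patch) of what OutputAsIntersection and OutputAsUnion compute for a single dimension, with -1 standing for an unknown dimension; the real code operates on DimensionHandles and records unknowns per node and port.

  // Simplified model only; helper names are hypothetical.
  int64_t MergeDim(int64_t d1, int64_t d2) {   // intersection, merge phase
    if (d1 == d2) return d1;  // identical dimensions are kept
    if (d1 < 0) return d2;    // an unknown dimension is refined by a known one
    if (d2 < 0) return d1;
    return -1;                // conflicting known values degrade to unknown
  }
  int64_t RelaxDim(int64_t d1, int64_t d2) {   // union, relax phase
    return d1 == d2 ? d1 : -1;  // any disagreement relaxes to unknown
  }
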
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 4c3f3f5f53..485324c466 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/core/grappler/grappler_item.h"
namespace tensorflow {
-class Graph;
namespace grappler {
@@ -79,40 +78,41 @@ class GraphProperties {
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status MergeEnqueueShapesAndTypes(
- SymbolicShapeRefiner* shape_refiner, const Node* qnode,
+ SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
std::vector<shape_inference::ShapeAndType>* queue_shapes_and_types);
// Relaxes shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status RelaxEnqueueShapesAndMergeTypes(
- SymbolicShapeRefiner* shape_refiner, const Node* qnode,
+ SymbolicShapeRefiner* shape_refiner, const NodeDef* qnode,
const std::vector<shape_inference::ShapeAndType>& shapes_and_types,
std::vector<shape_inference::ShapeAndType>* queue_shapes_and_types);
// Update the shapes for qnode. If output shapes of qnode have changed,
// enqueue its fanout in 'new_shapes'.
static Status UpdateResource(
- const Node* qnode, const std::unordered_set<const Node*>& queue_inputs,
- SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes);
+ const NodeDef* qnode,
+ const std::unordered_set<const NodeDef*>& queue_inputs,
+ SymbolicShapeRefiner* shape_refiner, TopoQueue* new_shapes);
// Update the output shapes of a Merge node, and enqueue its fanout in
// new_shapes if needed.
- Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner, const Node* node,
- bool relax, bool* new_shapes) const;
+ Status UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
+ const NodeDef* node, bool relax,
+ bool* new_shapes) const;
// Process the Enter node, and enqueue its fanout in new_shapes if needed.
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
- const Node* node, bool relax, bool* new_shapes);
+ const NodeDef* node, bool relax, bool* new_shapes);
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
- Status UpdateShapes(
- SymbolicShapeRefiner* shape_refiner, bool relax,
- const Node* n, bool* new_shapes) const;
+ Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax,
+ const NodeDef* n, bool* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
- const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
- resources,
+ const std::unordered_map<const NodeDef*,
+ std::unordered_set<const NodeDef*>>& resources,
int num_loops) const;
// Data members
@@ -120,8 +120,6 @@ class GraphProperties {
std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
const std::vector<OpInfo::TensorProperties> missing_properties_;
-
- Graph* graph_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index 3de697bd37..afe334dfa2 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
@@ -955,6 +956,11 @@ TEST_F(GraphPropertiesTest, Performance) {
string filename = io::JoinPath(testing::TensorFlowSrcRoot(), kTestDataPath,
"large_graph.pbtxt.html");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
+ TF_CHECK_OK(AddDefaultAttrsToGraphDef(
+ &item.graph,
+ FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library()), 0,
+ true));
+
GraphProperties properties(item);
TF_CHECK_OK(properties.InferStatically(false));
}
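
For reference, a minimal sketch (not part of the patch) of how GraphProperties is typically driven after this change, mirroring the test above; GetOutputProperties is part of the existing GraphProperties interface and the graph population step is elided.

  GrapplerItem item;
  // ... populate item.graph (and optionally item.feed) ...
  GraphProperties properties(item);
  TF_CHECK_OK(properties.InferStatically(/*assume_valid_feeds=*/false));
  for (const NodeDef& node : item.graph.node()) {
    // One TensorProperties entry per output port, with the inferred shape/type.
    for (const OpInfo::TensorProperties& prop :
         properties.GetOutputProperties(node.name())) {
      VLOG(1) << node.name() << ": " << prop.shape().DebugString();
    }
  }
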
diff --git a/tensorflow/core/grappler/costs/utils.cc b/tensorflow/core/grappler/costs/utils.cc
index f318e3911c..be54d98534 100644
--- a/tensorflow/core/grappler/costs/utils.cc
+++ b/tensorflow/core/grappler/costs/utils.cc
@@ -44,7 +44,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/bits.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index 0e5c654acf..7f68272950 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -850,14 +850,16 @@ Costs VirtualScheduler::Summary() const {
VLOG(1) << "Expected max per-op streaming buffers: "
<< graph_costs_.max_per_op_streaming;
- VLOG(1) << "Per-op execution time:";
+ VLOG(1) << "Per-op execution time / compute time / memory time:";
for (const auto& op_cost_pair : op_to_cost_) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
+ const auto& compute_cost = op_cost_pair.second.compute_time.count();
+ const auto& memory_cost = op_cost_pair.second.memory_time.count();
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (cost) { // Skip printing out zero-cost ops.
VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
- << cost;
+ << cost << " / " << compute_cost << " / " << memory_cost;
}
}
@@ -898,7 +900,8 @@ Costs VirtualScheduler::Summary() const {
<< ", at the end: "
<< strings::HumanReadableNumBytes(state.memory_usage);
- VLOG(1) << "Per-op execution time (and memory usage at peak memory usage):";
+ VLOG(1) << "Per-op execution time compute time / memory time "
+ "(and memory usage at peak memory usage):";
// Profile non-persistent op memory usage.
for (const auto& node_port : state.mem_usage_snapshot_at_peak) {
@@ -912,6 +915,8 @@ Costs VirtualScheduler::Summary() const {
for (const auto& op_cost_pair : state.op_to_cost) {
const auto& op = op_cost_pair.first;
const auto& cost = op_cost_pair.second.execution_time.count();
+ const auto& compute_cost = op_cost_pair.second.compute_time.count();
+ const auto& memory_cost = op_cost_pair.second.memory_time.count();
total_compute_time_ns += op_cost_pair.second.execution_time;
const bool is_op_cost_accurate = !op_cost_pair.second.inaccurate;
if (!is_op_cost_accurate) {
@@ -930,8 +935,9 @@ Costs VirtualScheduler::Summary() const {
if (cost || mem_usage_percent > 1.0) {
// Print out only non-zero cost ops or ops with > 1% memory usage.
VLOG(1) << " + " << op << " : " << (is_op_cost_accurate ? "" : "~")
- << cost << " (" << strings::HumanReadableNumBytes(op_mem_usage)
- << " [" << mem_usage_percent << "%] "
+ << cost << " / " << compute_cost << " / " << memory_cost << " ("
+ << strings::HumanReadableNumBytes(op_mem_usage) << " ["
+ << mem_usage_percent << "%] "
<< (persisent_ops.count(op) > 0 ? ": persistent op)" : ")");
}
}
diff --git a/tensorflow/core/grappler/devices.cc b/tensorflow/core/grappler/devices.cc
index b318ac22d4..3268697671 100644
--- a/tensorflow/core/grappler/devices.cc
+++ b/tensorflow/core/grappler/devices.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include <memory>
#include "tensorflow/core/grappler/devices.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#if GOOGLE_CUDA
@@ -30,15 +31,14 @@ int GetNumAvailableGPUs() {
int num_eligible_gpus = 0;
#if GOOGLE_CUDA
if (ValidateGPUMachineManager().ok()) {
- perftools::gputools::Platform* gpu_manager = GPUMachineManager();
+ se::Platform* gpu_manager = GPUMachineManager();
if (gpu_manager != nullptr) {
int num_gpus = gpu_manager->VisibleDeviceCount();
for (int i = 0; i < num_gpus; i++) {
auto exec_status = gpu_manager->ExecutorForDevice(i);
if (exec_status.ok()) {
- perftools::gputools::StreamExecutor* se = exec_status.ValueOrDie();
- const perftools::gputools::DeviceDescription& desc =
- se->GetDeviceDescription();
+ se::StreamExecutor* se = exec_status.ValueOrDie();
+ const se::DeviceDescription& desc = se->GetDeviceDescription();
int min_gpu_core_count = 8;
if (desc.core_count() >= min_gpu_core_count) {
num_eligible_gpus++;
@@ -56,10 +56,9 @@ int GetNumAvailableGPUs() {
int64 AvailableGPUMemory(int gpu_id) {
#if GOOGLE_CUDA
// Look up the device, to see its attributes.
- perftools::gputools::Platform* gpu_platform = GPUMachineManager();
+ se::Platform* gpu_platform = GPUMachineManager();
CHECK_LT(gpu_id, gpu_platform->VisibleDeviceCount());
- perftools::gputools::StreamExecutor* se =
- gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
+ se::StreamExecutor* se = gpu_platform->ExecutorForDevice(gpu_id).ValueOrDie();
int64 total_memory, available_memory;
CHECK(se->DeviceMemoryUsage(&available_memory, &total_memory));
diff --git a/tensorflow/core/grappler/graph_view.cc b/tensorflow/core/grappler/graph_view.cc
index 0d3f94854b..3e448216f9 100644
--- a/tensorflow/core/grappler/graph_view.cc
+++ b/tensorflow/core/grappler/graph_view.cc
@@ -173,5 +173,54 @@ int GraphView::NumFanins(const NodeDef& node,
return count;
}
+std::unordered_set<GraphView::Edge, GraphView::HashEdge>
+GraphView::GetFanoutEdges(const NodeDef& node,
+ bool include_controlled_edges) const {
+ std::unordered_set<Edge, HashEdge> result;
+ OutputPort port;
+ port.node = const_cast<NodeDef*>(&node);
+ const int first_port_id = include_controlled_edges ? -1 : 0;
+ auto it = num_regular_outputs_.find(&node);
+ const int last_port_id = (it != num_regular_outputs_.end()) ? it->second : -1;
+
+ for (int i = first_port_id; i <= last_port_id; ++i) {
+ port.port_id = i;
+ auto it = fanouts_.find(port);
+ if (it != fanouts_.end()) {
+ Edge fanout;
+ fanout.src.node = const_cast<NodeDef*>(&node);
+ fanout.src.port_id = i;
+ for (auto itr = it->second.begin(); itr != it->second.end(); ++itr) {
+ fanout.tgt = *itr;
+ result.insert(fanout);
+ }
+ }
+ }
+ return result;
+}
+
+std::unordered_set<GraphView::Edge, GraphView::HashEdge>
+GraphView::GetFaninEdges(const NodeDef& node,
+ bool include_controlling_edges) const {
+ std::unordered_set<Edge, HashEdge> result;
+ for (int i = 0; i < node.input_size(); ++i) {
+ Edge fanin;
+ fanin.tgt.node = const_cast<NodeDef*>(&node);
+ fanin.tgt.port_id = i;
+ string fanin_name = ParseNodeName(node.input(i), &fanin.src.port_id);
+ if (fanin.src.port_id < 0) {
+ if (!include_controlling_edges) {
+ break;
+ }
+ }
+ auto it = nodes_.find(fanin_name);
+ if (it != nodes_.end()) {
+ fanin.src.node = it->second;
+ result.insert(fanin);
+ }
+ }
+ return result;
+}
+
} // end namespace grappler
} // end namespace tensorflow
diff --git a/tensorflow/core/grappler/graph_view.h b/tensorflow/core/grappler/graph_view.h
index 173ce9c09c..c3baad0987 100644
--- a/tensorflow/core/grappler/graph_view.h
+++ b/tensorflow/core/grappler/graph_view.h
@@ -29,6 +29,8 @@ namespace grappler {
class GraphView {
public:
struct Port {
+ Port() : node(nullptr), port_id(-1) {}
+ Port(NodeDef* n, int port) : node(n), port_id(port) {}
NodeDef* node = nullptr;
int port_id = -1;
@@ -36,8 +38,16 @@ class GraphView {
return node == other.node && port_id == other.port_id;
}
};
- struct InputPort : public Port {};
- struct OutputPort : public Port {};
+ struct InputPort : public Port {
+ InputPort() = default;
+ InputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
+ InputPort(const NodeDef* n, int port_id)
+ : Port(const_cast<NodeDef*>(n), port_id) {}
+ };
+ struct OutputPort : public Port {
+ OutputPort() = default;
+ OutputPort(NodeDef* n, int port_id) : Port(n, port_id) {}
+ };
struct HashPort {
std::size_t operator()(const Port& port) const {
@@ -45,6 +55,20 @@ class GraphView {
}
};
+ struct Edge {
+ OutputPort src;
+ InputPort tgt;
+
+ bool operator==(const Edge& other) const {
+ return src == other.src && tgt == other.tgt;
+ }
+ };
+ struct HashEdge {
+ std::size_t operator()(const Edge& edge) const {
+ return HashPort()(edge.src) + HashPort()(edge.tgt);
+ }
+ };
+
explicit GraphView(GraphDef* graph);
GraphDef* GetGraph() const { return graph_; }
NodeDef* GetNode(const string& node_name) const;
@@ -63,6 +87,7 @@ class GraphView {
const OutputPort& port) const;
std::unordered_set<OutputPort, HashPort> GetFanin(
const InputPort& port) const;
+
// Special case: regular (i.e. non-control) input ports can only have one
// fanin.
const OutputPort GetRegularFanin(const InputPort& port) const;
@@ -79,6 +104,13 @@ class GraphView {
// controlling nodes iff include_controlling_nodes is true.
int NumFanins(const NodeDef& node, bool include_controlling_nodes) const;
+ // Get all the edges in the immediate fanout (resp. fanin) of a node. Include
+ // control edges iff the corresponding include_*_edges argument is true.
+ std::unordered_set<Edge, HashEdge> GetFanoutEdges(
+ const NodeDef& node, bool include_controlled_edges) const;
+ std::unordered_set<Edge, HashEdge> GetFaninEdges(
+ const NodeDef& node, bool include_controlling_edges) const;
+
private:
GraphDef* graph_;
std::unordered_map<string, NodeDef*> nodes_;
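
A short sketch (not part of the patch) of how the new edge accessors might be used; the graph construction and the node name "concat" are hypothetical.

  GraphDef graph_def;
  // ... build graph_def ...
  GraphView view(&graph_def);
  const NodeDef* node = view.GetNode("concat");  // hypothetical node name
  for (const GraphView::Edge& edge :
       view.GetFanoutEdges(*node, /*include_controlled_edges=*/false)) {
    VLOG(1) << edge.src.node->name() << ":" << edge.src.port_id << " -> "
            << edge.tgt.node->name() << ":" << edge.tgt.port_id;
  }
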
diff --git a/tensorflow/core/grappler/op_types.cc b/tensorflow/core/grappler/op_types.cc
index 9c45aed62f..f595cf6456 100644
--- a/tensorflow/core/grappler/op_types.cc
+++ b/tensorflow/core/grappler/op_types.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/str_util.h"
+#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace grappler {
@@ -451,43 +452,101 @@ OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
bool IsInvolution(const NodeDef& node) {
- const std::unordered_set<string> involution_ops{
- "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"};
- return involution_ops.count(node.op()) > 0;
+ static const std::unordered_set<string>* involution_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Conj", "Reciprocal", "Invert", "Neg", "LogicalNot"}));
+ return involution_ops->count(node.op()) > 0;
}
bool IsValueAndOrderPreserving(const NodeDef& node) {
if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
return true;
}
- const std::unordered_set<string> value_and_order_preserving_ops{
- "CheckNumerics",
- "DebugGradientIdentity",
- "DeepCopy"
- "Enter",
- "Exit",
- "ExpandDims",
- "Identity",
- "IdentityN",
- "PreventGradient",
- "Print",
- "Reshape",
- "Snapshot",
- "Squeeze",
- "StopGradient",
- };
- return value_and_order_preserving_ops.count(node.op()) > 0;
+ static const std::unordered_set<string>* value_and_order_preserving_ops =
+ CHECK_NOTNULL((new const std::unordered_set<string>{
+ "CheckNumerics",
+ "DebugGradientIdentity",
+ "DeepCopy"
+ "Enter",
+ "Exit",
+ "ExpandDims",
+ "Identity",
+ "IdentityN",
+ "PreventGradient",
+ "Print",
+ "Reshape",
+ "Snapshot",
+ "Squeeze",
+ "StopGradient",
+ }));
+ return value_and_order_preserving_ops->count(node.op()) > 0;
}
bool IsValuePreserving(const NodeDef& node) {
- const std::unordered_set<string> value_preserving_ops{
- "InvertPermutation",
- "Reverse",
- "Roll",
- "Transpose",
- };
+ static const std::unordered_set<string>* value_preserving_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "InvertPermutation",
+ "Reverse",
+ "Roll",
+ "Transpose",
+ }));
return IsValueAndOrderPreserving(node) ||
- value_preserving_ops.count(node.op()) > 0;
+ value_preserving_ops->count(node.op()) > 0;
+}
+
+bool IsUnaryElementWise(const NodeDef& node) {
+ static const std::unordered_set<string>* element_wise_ops =
+ CHECK_NOTNULL((new std::unordered_set<string>{
+ "Abs",
+ "Acos",
+ "Acosh",
+ "Asin",
+ "Asinh",
+ "Atan",
+ "Atan2",
+ "Atanh",
+ "Ceil",
+ "ComplexAbs",
+ "Conj",
+ "Cos",
+ "Cosh",
+ "Digamma",
+ "Elu"
+ "Erf",
+ "Erfc",
+ "Exp",
+ "Expm1",
+ "Floor",
+ "Inv",
+ "Invert",
+ "Isinf",
+ "Isnan",
+ "Isfinite",
+ "Lgamma",
+ "Log",
+ "Log1p",
+ "LogicalNot",
+ "Neg",
+ "Reciprocal",
+ "Relu",
+ "Relu6",
+ "Rint",
+ "Round",
+ "Selu",
+ "Rsqrt",
+ "Sigmoid",
+ "Sign",
+ "Sin",
+ "SinH",
+ "Softplus",
+ "Softsign",
+ "Sqrt",
+ "Square",
+ "Tan"
+ "Tanh",
+ }));
+ return element_wise_ops->count(node.op()) > 0 ||
+ (!IsIdentityN(node) && IsValueAndOrderPreserving(node));
}
bool HasOpDef(const NodeDef& node) {
diff --git a/tensorflow/core/grappler/op_types.h b/tensorflow/core/grappler/op_types.h
index 79fd05e187..7f5da19d90 100644
--- a/tensorflow/core/grappler/op_types.h
+++ b/tensorflow/core/grappler/op_types.h
@@ -177,6 +177,8 @@ bool IsValueAndOrderPreserving(const NodeDef& node);
// function returns true if the op commutes with all element-wise operations.
bool IsValuePreserving(const NodeDef& node);
+bool IsUnaryElementWise(const NodeDef& node);
+
// Returns true if we can find an opdef corresponding to the op of the node.
bool HasOpDef(const NodeDef& node);
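
A sketch (not part of the patch) of the kind of check the new IsUnaryElementWise helper enables, e.g. walking backwards along a chain of single-input ops as the concat-hoisting stage below does; the helper name and fixed chain length are hypothetical.

  bool ChainIsUnaryElementWise(const GraphView& view, const NodeDef* node,
                               int length) {
    for (int i = 0; i < length; ++i) {
      if (node == nullptr || !IsUnaryElementWise(*node)) return false;
      // A regular (non-control) input port has exactly one fanin.
      node = view.GetRegularFanin(GraphView::InputPort(node, 0)).node;
    }
    return true;
  }
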
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 42c3580d40..ad2db685fc 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -243,6 +243,7 @@ cc_library(
deps = [
":graph_optimizer",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
],
)
@@ -517,13 +518,11 @@ cc_library(
":loop_optimizer",
":memory_optimizer",
":model_pruner",
- "//tensorflow/core:core_cpu_base",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler/utils:colocation",
- "//tensorflow/core/grappler/utils:functions",
"//tensorflow/core/grappler/utils:topological_sort",
],
)
@@ -540,11 +539,9 @@ tf_cuda_cc_test(
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
- "//tensorflow/core:testlib",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
- "//tensorflow/core/grappler/utils:grappler_test",
],
)
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index ed199c1ac8..c0bd0bda95 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -208,8 +208,7 @@ bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
graph_properties.GetOutputProperties(reshape.name());
const std::vector<OpInfo::TensorProperties>& input_props =
graph_properties.GetOutputProperties(input.name());
- if (reshape_props.empty() || input_props.empty() ||
- input_props.size() <= output_pos) {
+ if (reshape_props.empty() || input_props.size() <= output_pos) {
return false;
}
@@ -1340,6 +1339,182 @@ class RemoveNegationStage : public ArithmeticOptimizerStage {
}
};
+// This optimization hoists the common prefix of unary ops of the inputs to
+// concat out of the concat.
+// For example: Concat([Exp(Sin(x)), Exp(Sin(y)), Exp(Sin(z))]) ->
+// Exp(Sin(Concat([x, y, z]))).
+// TODO(rmlarsen): Support casting. We would have to change the type attribute
+// on the concat node.
+class HoistCWiseUnaryFromConcatStage : public ArithmeticOptimizerStage {
+ public:
+ explicit HoistCWiseUnaryFromConcatStage(
+ const GraphOptimizerContext& ctx,
+ const ArithmeticOptimizerContext& ctx_ext)
+ : ArithmeticOptimizerStage("", ctx, ctx_ext) {}
+
+ ~HoistCWiseUnaryFromConcatStage() override = default;
+
+ bool IsSupported(const NodeDef* node) const override {
+ if (!IsConcat(*node)) return false;
+ const int n = node->attr().at("N").i();
+ return n > 1;
+ }
+
+ Status TrySimplify(NodeDef* concat_node,
+ string* simplified_node_name) override {
+ int prefix_length;
+ std::set<string> ctrl_inputs;
+ TF_RETURN_IF_ERROR(
+ FindCommonUnaryOpPrefix(*concat_node, &prefix_length, &ctrl_inputs));
+ if (prefix_length > 0) {
+ TF_RETURN_IF_ERROR(
+ HoistUnaryOpPrefix(prefix_length, &ctrl_inputs, concat_node));
+ AddToOptimizationQueue(concat_node);
+ }
+ return Status::OK();
+ }
+
+ private:
+ void RemoveControlInputs(std::set<string>* removed_ctrl_inputs,
+ NodeDef* node) const {
+ const int num_inputs = node->input_size();
+ for (int idx = num_inputs - 1; idx >= 0; --idx) {
+ const string& input = node->input(idx);
+ if (IsControlInput(input)) {
+ removed_ctrl_inputs->insert(input);
+ ctx().node_map->RemoveOutput(NodeName(input), node->name());
+ node->mutable_input()->RemoveLast();
+ } else {
+ break;
+ }
+ }
+ }
+
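+ // Attaches the control inputs collected in 'new_ctrl_inputs' to 'node',
+ // skipping any that the node already carries.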
+ void AddControlInputs(std::set<string>* new_ctrl_inputs,
+ NodeDef* node) const {
+ for (int idx = node->input_size() - 1; idx >= 0; --idx) {
+ const string& existing_input = node->input(idx);
+ if (IsControlInput(existing_input)) {
+ new_ctrl_inputs->erase(existing_input);
+ } else {
+ break;
+ }
+ }
+ for (const string& new_input : *new_ctrl_inputs) {
+ ctx().node_map->AddOutput(NodeName(new_input), node->name());
+ node->add_input(new_input);
+ }
+ }
+
+ // Computes the length of the common unary prefix chain of ops that can be
+ // hoisted out of concat and returns it in *prefix_length.
+ Status FindCommonUnaryOpPrefix(const NodeDef& concat_node, int* prefix_length,
+ std::set<string>* ctrl_inputs) const {
+ *prefix_length = 0;
+ const int n = concat_node.attr().at("N").i();
+ // Follow the chains backwards from each concat input as long as all the
+ // following conditions hold:
+ // 1. The ops in all chains are the same.
+ // 2. The op is a unary element-wise op.
+ // 3. The op output has only a single consumer.
+ std::vector<NodeDef*> tail(n, nullptr);
+ const int start = concat_node.op() == "Concat" ? 1 : 0;
+ const int end = start + n;
+ // Set up tail pointers to point to the immediate inputs to Concat.
+ for (int i = start; i < end; ++i) {
+ if (IsControlInput(concat_node.input(i))) {
+ return errors::FailedPrecondition("Got control input ",
+ concat_node.input(i),
+ " where normal input was expected.");
+ }
+ TF_RETURN_IF_ERROR(GetInputNode(concat_node.input(i), &tail[i - start]));
+ }
+
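+ // Walk all chains backwards in lockstep, one op per iteration. Stop as soon
+ // as an op is not unary element-wise, the ops across chains differ, an op
+ // has more than one consumer, or a chain cannot be followed any further.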
+ bool stop = false;
+ ctrl_inputs->clear();
+ while (!stop) {
+ const NodeDef* tail0 = tail[0];
+ if (!IsUnaryElementWise(*tail0)) break;
+ for (int chain = 0; chain < n; ++chain) {
+ // TODO(rmlarsen): Allow and hoist outgoing control edges.
+ if (tail[chain]->op() != tail0->op() ||
+ ctx().node_map->GetOutputs(tail[chain]->name()).size() > 1) {
+ stop = true;
+ break;
+ }
+ }
+ if (stop) break;
+ // We found one more op that can be hoisted.
+ ++(*prefix_length);
+ for (int chain = 0; chain < n; ++chain) {
+ RemoveControlInputs(ctrl_inputs, tail[chain]);
+ }
+ // Advance tail pointers to the next level.
+ for (int chain = 0; chain < n; ++chain) {
+ if (tail[chain]->input_size() == 0 ||
+ IsControlInput(tail[chain]->input(0))) {
+ stop = true;
+ break;
+ } else {
+ NodeDef* new_tail = nullptr;
+ TF_RETURN_IF_ERROR(GetInputNode(tail[chain]->input(0), &new_tail));
+ tail[chain] = new_tail;
+ }
+ }
+ }
+ return Status::OK();
+ }
+
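+ // Rewires the graph so that the concat directly consumes the inputs that
+ // feed each hoisted chain of unary ops, the ops of the first chain are
+ // reused to process the concat output, and the original consumers of the
+ // concat are redirected to the final op of that chain.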
+ Status HoistUnaryOpPrefix(const int prefix_length,
+ std::set<string>* ctrl_inputs,
+ NodeDef* concat_node) {
+ const int n = concat_node->attr().at("N").i();
+ const int start = concat_node->op() == "Concat" ? 1 : 0;
+ const int end = start + n;
+ const std::set<NodeDef*> consumers =
+ ctx().node_map->GetOutputs(concat_node->name());
+ AddControlInputs(ctrl_inputs, concat_node);
+ for (int chain = 0; chain < (end - start); ++chain) {
+ NodeDef* tail = nullptr;
+ const string concat_input = concat_node->input(chain + start);
+ for (int distance = 0; distance < prefix_length; ++distance) {
+ if (distance == 0) {
+ TF_RETURN_IF_ERROR(GetInputNode(concat_input, &tail));
+ } else {
+ TF_RETURN_IF_ERROR(GetInputNode(tail->input(0), &tail));
+ }
+ }
+
+ // Hook the node following tail directly into the concat node.
+ const string tail_input = tail->input(0);
+ concat_node->set_input(chain + start, tail_input);
+ ctx().node_map->UpdateInput(concat_node->name(), concat_input,
+ tail_input);
+
+ if (chain == 0) {
+ // Reuse nodes in the first chain to process output of concat.
+ tail->set_input(0, concat_node->name());
+ ctx().node_map->UpdateInput(tail->name(), tail_input,
+ concat_node->name());
+
+ // Update the consumers of concat to consume the end of the chain
+ // instead.
+ for (NodeDef* consumer : consumers) {
+ for (int idx = 0; idx < consumer->input_size(); ++idx) {
+ if (consumer->input(idx) == concat_node->name()) {
+ consumer->set_input(idx, concat_input);
+ ctx().node_map->UpdateInput(consumer->name(), concat_node->name(),
+ concat_input);
+ }
+ }
+ AddToOptimizationQueue(consumer);
+ }
+ }
+ }
+ return Status::OK();
+ }
+};
+
} // namespace
class UniqueNodes {
@@ -1995,6 +2170,8 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps(bool can_use_shapes) {
pipeline.AddStage<RemoveRedundantCastStage>(ctx, ctx_ext);
if (options_.remove_negation)
pipeline.AddStage<RemoveNegationStage>(ctx, ctx_ext);
+ if (options_.hoist_unary_out_of_concat)
+ pipeline.AddStage<HoistCWiseUnaryFromConcatStage>(ctx, ctx_ext);
VLOG(1) << "Run " << pipeline.NumStages() << " arithmetic optimizer stages: "
<< str_util::Join(pipeline.StageNames(), ", ");
@@ -2062,17 +2239,18 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
nodes_to_preserve_ = item.NodesToPreserve();
fetch_nodes_known_ = !item.fetch.empty();
*optimized_graph = item.graph;
- optimized_graph_ = optimized_graph;
+ GrapplerItem optimized_item(item, optimized_graph);
+ optimized_graph_ = &optimized_item.graph;
node_map_.reset(new NodeMap(optimized_graph_));
- DedupComputations();
+ if (options_.dedup_computations) {
+ DedupComputations();
+ }
// Perform topological sort on the graph in order to help AddOpsRewrite to
// optimize larger subgraphs starting from the roots with more inputs.
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph_));
- GrapplerItem optimized_item(item, optimized_graph);
- optimized_graph_ = &optimized_item.graph;
graph_properties_.reset(new GraphProperties(optimized_item));
const Status status = graph_properties_->InferStatically(false);
const bool can_use_shapes = status.ok();
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
index c0fe8839ca..375f13acc1 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.h
@@ -56,24 +56,24 @@ class ArithmeticOptimizer : public GraphOptimizer {
struct ArithmeticOptimizerOptions {
// TODO(ezhulenev): flag to disable TrySimplifyAndReplaceUses in tests.
// Remove when all optimizers have been migrated to separate stages.
+ bool dedup_computations = true;
bool enable_try_simplify_and_replace = true;
- bool combine_add_to_addn = false;
+ bool combine_add_to_addn = true;
bool hoist_common_factor_out_of_aggregation = true;
- bool minimize_broadcasts = false;
+ bool minimize_broadcasts = true;
bool remove_identity_transpose = true;
bool remove_redundant_bitcast = true;
bool remove_redundant_cast = true;
bool remove_negation = true;
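+ // Hoisting unary ops out of Concat is disabled by default; Default() below
+ // turns it on only at the AGGRESSIVE optimization level.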
+ bool hoist_unary_out_of_concat = false;
// Choose which arithmetic optimizer stages will be enabled for a given
// optimization level by default.
static ArithmeticOptimizerOptions Default(
RewriterConfig::Toggle opt_level) {
ArithmeticOptimizerOptions options;
- // TODO(ezhulenev): enable by default after 1.8 release cut
if (opt_level == RewriterConfig::AGGRESSIVE) {
- options.combine_add_to_addn = true;
- options.minimize_broadcasts = true;
+ options.hoist_unary_out_of_concat = true;
}
return options;
}
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index cb1f2ea732..df10dbdf48 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -98,6 +98,7 @@ class ArithmeticOptimizerTest : public GrapplerTest {
// should explicitly enable the required optimization for test isolation
void DisableAllStages(ArithmeticOptimizer* optimizer) {
ArithmeticOptimizer::ArithmeticOptimizerOptions options;
+ options.dedup_computations = false;
options.enable_try_simplify_and_replace = false;
options.combine_add_to_addn = false;
options.hoist_common_factor_out_of_aggregation = false;
@@ -147,6 +148,10 @@ class ArithmeticOptimizerTest : public GrapplerTest {
DisableAllStages(optimizer);
optimizer->options_.remove_negation = true;
}
+ void EnableOnlyHoistCWiseUnaryFromConcat(ArithmeticOptimizer* optimizer) {
+ DisableAllStages(optimizer);
+ optimizer->options_.hoist_unary_out_of_concat = true;
+ }
};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -2086,5 +2091,102 @@ TEST_F(ArithmeticOptimizerTest, MinimizeBroadcasts_BuildTreeUp) {
EXPECT_EQ("mul1", mul3_node->input(1));
}
+TEST_F(ArithmeticOptimizerTest, HoistCWiseUnaryFromConcat) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output a = ops::Variable(s.WithOpName("a"), {32}, DT_FLOAT);
+ Output b = ops::Variable(s.WithOpName("b"), {32}, DT_FLOAT);
+ Output c = ops::Variable(s.WithOpName("c"), {32}, DT_FLOAT);
+ Output axis = ops::Const(s.WithOpName("axis"), 0, {});
+ Output ctrl1 = ops::Const(s.WithOpName("ctrl1"), 1, {});
+ Output ctrl2 = ops::Const(s.WithOpName("ctrl2"), 2, {});
+ Output ctrl3 = ops::Const(s.WithOpName("ctrl3"), 3, {});
+ // Test case with chains of length 1.
+ Output sin_a =
+ ops::Sin(s.WithOpName("sin_a").WithControlDependencies(ctrl3), a);
+ Output exp_a =
+ ops::Exp(s.WithOpName("exp_a").WithControlDependencies(ctrl1), sin_a);
+ Output exp_b = ops::Exp(s.WithOpName("exp_b"), b);
+ Output exp_c =
+ ops::Exp(s.WithOpName("exp_c").WithControlDependencies(ctrl2), c);
+ Output concat =
+ ops::Concat(s.WithOpName("concat"), {exp_a, exp_b, exp_c}, axis);
+ Output id = ops::Identity(s.WithOpName("id"), concat);
+
+ // Test case with chains of length 2.
+ Output exp_a2 =
+ ops::Exp(s.WithOpName("exp_a2").WithControlDependencies(ctrl1), sin_a);
+ Output exp_b2 = ops::Exp(s.WithOpName("exp_b2"), b);
+ Output exp_c2 =
+ ops::Exp(s.WithOpName("exp_c2").WithControlDependencies(ctrl2), c);
+ Output cos_exp_a2 = ops::Cos(
+ s.WithOpName("cos_exp_a2").WithControlDependencies(ctrl1), exp_a2);
+ Output cos_exp_b2 = ops::Cos(
+ s.WithOpName("cos_exp_b2").WithControlDependencies(ctrl3), exp_b2);
+ Output cos_exp_c2 = ops::Cos(s.WithOpName("cos_exp_c2"), exp_c2);
+ Output concat2 = ops::Concat(s.WithOpName("concat2"),
+ {cos_exp_a2, cos_exp_b2, cos_exp_c2}, axis);
+ Output id2 = ops::Identity(s.WithOpName("id2"), concat2);
+ GrapplerItem item;
+ item.fetch = {"id", "id2"};
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ GraphDef output;
+ ArithmeticOptimizer optimizer;
+ EnableOnlyHoistCWiseUnaryFromConcat(&optimizer);
+
+ OptimizeAndPrune(&optimizer, &item, &output);
+ int found = 0;
+ for (const NodeDef& node : output.node()) {
+ if (node.name() == "concat") {
+ EXPECT_EQ(6, node.input_size());
+ EXPECT_EQ("sin_a", node.input(0));
+ EXPECT_EQ("b", node.input(1));
+ EXPECT_EQ("c", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ EXPECT_EQ("^ctrl1", node.input(4));
+ EXPECT_EQ("^ctrl2", node.input(5));
+ found++;
+ }
+ if (node.name() == "exp_a") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("concat", node.input(0));
+ found++;
+ }
+ if (node.name() == "id") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("exp_a", node.input(0));
+ found++;
+ }
+
+ if (node.name() == "concat2") {
+ EXPECT_EQ(7, node.input_size());
+ EXPECT_EQ("sin_a", node.input(0));
+ EXPECT_EQ("b", node.input(1));
+ EXPECT_EQ("c", node.input(2));
+ EXPECT_EQ("axis", node.input(3));
+ EXPECT_EQ("^ctrl1", node.input(4));
+ EXPECT_EQ("^ctrl2", node.input(5));
+ EXPECT_EQ("^ctrl3", node.input(6));
+ found++;
+ }
+ if (node.name() == "exp_a2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("concat2", node.input(0));
+ found++;
+ }
+ if (node.name() == "cos_exp_a2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("exp_a2", node.input(0));
+ found++;
+ }
+ if (node.name() == "id2") {
+ EXPECT_EQ(1, node.input_size());
+ EXPECT_EQ("cos_exp_a2", node.input(0));
+ found++;
+ }
+ }
+ EXPECT_EQ(7, found);
+}
+
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index e29aaa25fe..45bb188e8d 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -36,6 +36,7 @@ limitations under the License.
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/setround.h"
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index 1acce05909..25693c5c60 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -520,6 +520,25 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_UnknownOutputShape) {
EXPECT_EQ("Mul", node.op()) << node.name();
}
}
+
+ const std::vector<string> fetch = {"mul_0", "mul_4", "mul_8"};
+ auto x_known_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto x_partially_unknown_t =
+ GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 7}));
+ auto expected_tensors =
+ EvaluateNodes(item.graph, fetch,
+ {{"x_known", x_known_t},
+ {"x_partially_unknown", x_partially_unknown_t},
+ {"x_unknown", x_unknown_t}});
+ EXPECT_EQ(fetch.size(), expected_tensors.size());
+ auto tensors = EvaluateNodes(output, fetch,
+ {{"x_known", x_known_t},
+ {"x_partially_unknown", x_partially_unknown_t},
+ {"x_unknown", x_unknown_t}});
+ EXPECT_EQ(fetch.size(), tensors.size());
+ for (int i = 0; i < tensors.size(); i++)
+ test::ExpectTensorNear<float>(expected_tensors[i], tensors[i], 1e-5);
}
TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
@@ -572,6 +591,20 @@ TEST_F(ConstantFoldingTest, NeutralElement_PartialShape_KnownOutputShape) {
EXPECT_TRUE(IsControlInput(node.input(1)));
}
}
+ const std::vector<string> fetch = {"addn1"};
+ auto x_partially_unknown_t =
+ GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto x_unknown_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto expected_tensors =
+ EvaluateNodes(item.graph, fetch,
+ {{"x_partially_unknown", x_partially_unknown_t},
+ {"x_unknown", x_unknown_t}});
+ EXPECT_EQ(1, expected_tensors.size());
+ auto tensors = EvaluateNodes(output, fetch,
+ {{"x_partially_unknown", x_partially_unknown_t},
+ {"x_unknown", x_unknown_t}});
+ EXPECT_EQ(1, tensors.size());
+ test::ExpectTensorNear<float>(expected_tensors[0], tensors[0], 1e-5);
}
TEST_F(ConstantFoldingTest, CreateConstNodes) {
@@ -1064,6 +1097,20 @@ TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN) {
}
}
EXPECT_EQ(9, found);
+
+ auto v1_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({3, 4}));
+ auto v2_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({5, 6}));
+ auto v3_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({4, 6}));
+ const std::vector<string> fetch_nodes = {"i1a", "i1b", "i2a", "i2b",
+ "i2c", "i3a", "i3b"};
+ auto tensors_expected = EvaluateNodes(
+ item.graph, fetch_nodes, {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
+ auto tensors = EvaluateNodes(output, fetch_nodes,
+ {{"v1", v1_t}, {"v2", v2_t}, {"v3", v3_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors.size());
+ for (int i = 0; i < fetch_nodes.size(); i++)
+ test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}
TEST_F(ConstantFoldingTest, ShapeMaterializationShapeN_MultipleOutputs) {
@@ -1930,6 +1977,14 @@ TEST_F(ConstantFoldingTest, Packing) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ const std::vector<string> fetch_nodes = {"i1", "i2"};
+ auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes);
+ EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
+ auto tensors = EvaluateNodes(output, fetch_nodes);
+ EXPECT_EQ(fetch_nodes.size(), tensors.size());
+ for (int i = 0; i < fetch_nodes.size(); i++)
+ test::ExpectTensorNear<float>(tensors_expected[i], tensors[i], 1e-5);
+
// Make sure that the representation of the folded constant is space
// efficient: in particular, the whole message should be smaller than 8k
// (the size needed to naively encode 1000 floats folded twice).
@@ -1965,6 +2020,13 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
+ std::vector<string> fetch_nodes = {"o1", "o2", "p1", "p2"};
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1, 5}));
+ auto g_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({1}));
+ auto tensors_expected =
+ EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}, {"g", g_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
+
// Run a second time to make sure the optimization is idempotent.
item.graph.Swap(&output);
status = optimizer.Optimize(nullptr, item, &output);
@@ -2005,6 +2067,11 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs) {
}
}
EXPECT_EQ(6, found);
+
+ auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}, {"g", g_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors.size());
+ for (int i = 0; i < fetch_nodes.size(); i++)
+ test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}
TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
@@ -2024,6 +2091,11 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
GrapplerItem item;
TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ std::vector<string> fetch_nodes = {"o1", "o2"};
+ auto a_t = GenerateRandomTensor<DT_FLOAT>(TensorShape({2, 2}));
+ auto tensors_expected = EvaluateNodes(item.graph, fetch_nodes, {{"a", a_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors_expected.size());
+
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
@@ -2078,6 +2150,10 @@ TEST_F(ConstantFoldingTest, MaterializeBroadcastGradientArgs_InfiniteLoop) {
}
}
EXPECT_EQ(7, found);
+ auto tensors = EvaluateNodes(output, fetch_nodes, {{"a", a_t}});
+ EXPECT_EQ(fetch_nodes.size(), tensors.size());
+ for (int i = 0; i < fetch_nodes.size(); i++)
+ test::ExpectTensorEqual<int>(tensors_expected[i], tensors[i]);
}
TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
@@ -2452,7 +2528,6 @@ TEST_F(ConstantFoldingTest, PartialFolding_IdentityN) {
ConstantFolding optimizer(nullptr /* cpu_device */);
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
- LOG(INFO) << output.DebugString();
TF_EXPECT_OK(status);
EXPECT_EQ(8, output.node_size());
for (const auto& node : output.node()) {
@@ -2539,6 +2614,8 @@ TEST_F(ConstantFoldingTest, TrivialPack) {
EXPECT_EQ(tensors_expected[0].shape(), tensors[0].shape());
}
+// The test does not evaluate the original and optimized graphs to check that
+// their outputs are the same. See b/78233179.
TEST_F(ConstantFoldingTest, Enter) {
GrapplerItem item;
AttrValue frame_name;
@@ -2555,7 +2632,7 @@ TEST_F(ConstantFoldingTest, Enter) {
value_tensor.AsProtoTensorContent(value.mutable_tensor());
GraphDef& graph = item.graph;
- AddNode("x", "Placeholder", {}, {{"T", type}}, &graph);
+ AddNode("x", "Placeholder", {}, {{"dtype", type}}, &graph);
AddNode("c1", "Const", {"^x"}, {{"value", value}, {"dtype", type}}, &graph);
AddNode("enter1", "Enter", {"x"},
{{"T", type},
diff --git a/tensorflow/core/grappler/optimizers/function_optimizer.cc b/tensorflow/core/grappler/optimizers/function_optimizer.cc
index 950933b933..47e7dc0a96 100644
--- a/tensorflow/core/grappler/optimizers/function_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/function_optimizer.cc
@@ -180,7 +180,7 @@ FunctionDefLibrary TrimFunctionLibrary(const FunctionLibraryDefinition& flib,
const string& func_name = func->signature().name();
keep_funcs.insert(func_name);
- // Find all the functions that called from the function body.
+ // Find all the functions called from the function body.
const auto& func_body = func->node_def();
std::for_each(func_body.begin(), func_body.end(), add_node_to_func_queue);
@@ -541,7 +541,7 @@ Status InlineSymbolicGradient(const NodeDef& node, SymbolicGradientEnv* env,
Status FunctionOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- VLOG(2) << "Optimize function library: id=" << item.id;
+ VLOG(1) << "Optimize Grappler item: id=" << item.id;
// Nothing to do here.
if (item.graph.library().function_size() == 0) {
diff --git a/tensorflow/core/grappler/optimizers/loop_optimizer.cc b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
index fff06dd2ac..f7994221bb 100644
--- a/tensorflow/core/grappler/optimizers/loop_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/loop_optimizer.cc
@@ -320,42 +320,50 @@ Status LoopInvariantNodeMotionOptimizer::RevertInvariantNodes() {
return Status::OK();
}
-Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(NodeDef* node) {
- auto consumers = node_map_->GetOutputs(node->name());
- invariant_nodes_.insert(std::make_pair(node, consumers.size()));
- for (auto* consumer : consumers) {
- if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
- continue;
- }
- bool is_invariant = true;
- for (const auto& input : consumer->input()) {
- if (!IsControlInput(input)) {
- const string name = NodeName(input);
- auto* producer = node_map_->GetNode(name);
- if (!invariant_nodes_.count(producer)) {
- if (IsConstant(*producer)) {
- invariant_nodes_.insert(
- std::make_pair(producer, node_map_->GetOutputs(name).size()));
- } else {
- is_invariant = false;
- break;
- }
- }
+Status LoopInvariantNodeMotionOptimizer::FindInvariantNodes(
+ NodeDef* start_node) {
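+  // Iteratively walk the consumers with an explicit stack; this replaces the
+  // previous recursive traversal.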
+ std::vector<NodeDef*> stack;
+ stack.reserve(32);
+ stack.push_back(start_node);
+ while (!stack.empty()) {
+ NodeDef* node = stack.back();
+ stack.pop_back();
+ auto consumers = node_map_->GetOutputs(node->name());
+ invariant_nodes_.emplace(node, consumers.size());
+ for (auto* consumer : consumers) {
+ if (invariant_nodes_.count(consumer) || ModifiesFrameInfo(*consumer)) {
+ continue;
}
- }
- if (is_invariant) {
- std::set<NodeDef*> producers;
+ bool is_invariant = true;
for (const auto& input : consumer->input()) {
- auto* producer = node_map_->GetNode(input);
- producers.insert(producer);
+ if (!IsControlInput(input)) {
+ const string name = NodeName(input);
+ auto* producer = node_map_->GetNode(name);
+ if (!invariant_nodes_.count(producer)) {
+ if (IsConstant(*producer)) {
+ invariant_nodes_.insert(
+ std::make_pair(producer, node_map_->GetOutputs(name).size()));
+ } else {
+ is_invariant = false;
+ break;
+ }
+ }
+ }
}
- for (auto* producer : producers) {
- auto iter = invariant_nodes_.find(producer);
- if (iter != invariant_nodes_.end()) {
- --iter->second;
+ if (is_invariant) {
+ std::set<NodeDef*> producers;
+ for (const auto& input : consumer->input()) {
+ auto* producer = node_map_->GetNode(input);
+ producers.insert(producer);
+ }
+ for (auto* producer : producers) {
+ auto iter = invariant_nodes_.find(producer);
+ if (iter != invariant_nodes_.end()) {
+ --iter->second;
+ }
}
+ stack.push_back(consumer);
}
- TF_RETURN_IF_ERROR(FindInvariantNodes(consumer));
}
}
return Status::OK();
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.cc b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
index cdc4698c34..c98eef1a6a 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.cc
@@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
-#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
@@ -30,7 +29,6 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/utils/colocation.h"
-#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/lib/core/status.h"
@@ -63,15 +61,15 @@ int NumIterations(const RewriterConfig& cfg) {
}
// Check if optimizer is allowed to run only once.
-int IsRunOnceOptimizer(const string& name) { return name == "layout"; }
+bool IsRunOnceOptimizer(const string& name) { return name == "layout"; }
} // namespace
-std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
- const string& optimizer) const {
#define MK_OPT(NAME, VALUE) \
if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)
+std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
+ const string& optimizer) const {
MK_OPT("pruning", new ModelPruner());
MK_OPT("function", new FunctionOptimizer(cfg_.function_optimization()));
MK_OPT("constfold", new ConstantFolding(cpu_device_));
@@ -84,9 +82,10 @@ std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
MK_OPT("debug_stripper", new DebugStripper());
return std::unique_ptr<GraphOptimizer>();
-#undef MK_OPT
}
+#undef MK_OPT
+
Status MetaOptimizer::InitializeOptimizers(
std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
if (!cfg_.disable_model_pruning()) {
@@ -161,14 +160,18 @@ Status MetaOptimizer::InitializeOptimizersByName(
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
- VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id;
-
std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
- bool register_by_name = !cfg_.optimizers().empty();
- TF_RETURN_IF_ERROR(register_by_name ? InitializeOptimizersByName(&optimizers)
- : InitializeOptimizers(&optimizers));
+ if (cfg_.optimizers().empty()) {
+ TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
+ } else {
+ TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
+ }
+
+ VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
+ << " num_optimizers=" << optimizers.size();
if (optimizers.empty()) {
+ VLOG(3) << "Skip graph optimization, no optimizers registered";
*optimized_graph = item.graph;
return Status::OK();
}
@@ -178,6 +181,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
GrapplerItem optimized_item = item;
optimized_graph->Swap(&optimized_item.graph);
+ bool is_optimized = false;
GraphOptimizationResult optimization_result(item.id);
for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
@@ -201,7 +205,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
optimized_graph->Swap(&optimized_item.graph);
result = status.ToString();
} else {
- optimization_result.is_optimized = true;
+ is_optimized = true;
float duration_ms = (end_us - start_us) / 1000.0f;
result = strings::StrCat(
PrintSizesBeforeAfter(optimized_item.graph, *optimized_graph),
@@ -217,7 +221,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
// Record graph optimization result.
optimization_results_.push_back(optimization_result);
- if (optimization_result.is_optimized) {
+ if (is_optimized) {
TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
ReassignColocation(optimized_graph);
// Make sure that the optimizers preserved the graph version.
@@ -231,71 +235,7 @@ Status MetaOptimizer::OptimizeGraph(Cluster* cluster, const GrapplerItem& item,
Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* optimized_graph) {
optimization_results_.clear();
-
- // 1. Optimize main graph
TF_RETURN_IF_ERROR(OptimizeGraph(cluster, item, optimized_graph));
-
- // 2. Optimize function library
- FunctionLibraryDefinition flib(OpRegistry::Global(),
- optimized_graph->library());
-
- // Optimize each function only once.
- std::unordered_set<string> optimized_funcs;
- bool optimize_function_library = true;
-
- // TODO(ezhulenev): turn it on after fixing ranklab: tune_tf_test.
- cfg_.set_constant_folding(RewriterConfig::OFF);
- cfg_.set_arithmetic_optimization(RewriterConfig::OFF);
-
- while (optimize_function_library) {
- optimize_function_library = false;
-
- for (const FunctionDef& func : optimized_graph->library().function()) {
- const string& func_name = func.signature().name();
-
- // Skip already optimized functions.
- if (optimized_funcs.find(func_name) != optimized_funcs.end()) continue;
-
- // Skip parametrized functions (function type or body is defined only at
- // function call time by caller node attributes).
- if (IsParametrized(func)) continue;
-
- VLOG(3) << "Optimize function: function=" << func_name;
-
- // Function optimization might specialize nested function calls, so we
- // have to reset the flag and do at least one more pass over the library.
- optimize_function_library = true;
- optimized_funcs.insert(func_name);
-
- // Make a GrapplerItem from a FunctionDef.
- GrapplerFunctionItem func_item;
- TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(func, flib, &func_item));
-
- // Optimize function body graph.
- GraphDef optimized_func_graph;
- TF_RETURN_IF_ERROR(
- OptimizeGraph(cluster, func_item, &optimized_func_graph));
-
- // Function body optimization might have created new specialized
- // functions, add them to the library.
- TF_RETURN_IF_ERROR(flib.AddLibrary(optimized_func_graph.library()));
-
- // Convert optimized graph back to FunctionDef.
- FunctionDef optimized_func;
- func_item.SwapFunctionBody(std::move(optimized_func_graph));
- TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
-
- // Replace optimized function with a new FunctionDef.
- TF_RETURN_IF_ERROR(flib.RemoveFunction(func_name));
- TF_RETURN_IF_ERROR(flib.AddFunctionDef(optimized_func));
- }
-
- // If optimized at least one function, update the graph library.
- if (optimize_function_library) {
- *optimized_graph->mutable_library() = flib.ToProto();
- }
- }
-
return Status::OK();
}
@@ -303,8 +243,7 @@ void MetaOptimizer::PrintResult() {
for (const GraphOptimizationResult& graph_result : optimization_results_) {
LOG(INFO) << "Optimization results for grappler item: " << graph_result.id;
for (const OptimizerResult& result : graph_result.results) {
- LOG(INFO) << "Return status of optimizer " << result.optimizer_name
- << ": " << result.result;
+ LOG(INFO) << " " << result.optimizer_name << ": " << result.result;
}
}
}
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer.h b/tensorflow/core/grappler/optimizers/meta_optimizer.h
index 7cf9a40c2d..b8d4666248 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer.h
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer.h
@@ -69,7 +69,6 @@ class MetaOptimizer : public GraphOptimizer {
struct GraphOptimizationResult {
explicit GraphOptimizationResult(const string& id) : id(id) {}
string id;
- bool is_optimized = false;
std::vector<OptimizerResult> results;
};
diff --git a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
index 887a988af9..9fcf07651b 100644
--- a/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/meta_optimizer_test.cc
@@ -16,14 +16,11 @@ limitations under the License.
#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
-#include "tensorflow/core/framework/function_testlib.h"
-#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/inputs/trivial_test_graph_input_yielder.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/utils.h"
-#include "tensorflow/core/grappler/utils/grappler_test.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
@@ -31,8 +28,6 @@ namespace tensorflow {
namespace grappler {
namespace {
-constexpr char kDevice[] = "/device:CPU:0";
-
class TestOptimizer : public CustomGraphOptimizer {
public:
static void SetOptimized(const bool flag_value) { optimized_ = flag_value; }
@@ -64,9 +59,7 @@ bool TestOptimizer::optimized_;
REGISTER_GRAPH_OPTIMIZER(TestOptimizer);
-class MetaOptimizerTest : public GrapplerTest {};
-
-TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
+TEST(MetaOptimizerTest, RunsCustomOptimizer) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -82,7 +75,7 @@ TEST_F(MetaOptimizerTest, RunsCustomOptimizer) {
EXPECT_TRUE(TestOptimizer::IsOptimized());
}
-TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
+TEST(MetaOptimizerTest, RunOptimizersTwice) {
TrivialTestGraphInputYielder fake_input(4, 1, 10, false, {"CPU:0"});
GrapplerItem item;
CHECK(fake_input.NextItem(&item));
@@ -96,167 +89,6 @@ TEST_F(MetaOptimizerTest, RunOptimizersTwice) {
TF_EXPECT_OK(status);
}
-TEST_F(MetaOptimizerTest, OptimizeFunctionLibrary) {
- using test::function::NDef;
-
- // Enable ony function optimization.
- RewriterConfig rewriter_config;
- rewriter_config.set_meta_optimizer_iterations(RewriterConfig::TWO);
- rewriter_config.set_function_optimization(RewriterConfig::ON);
- rewriter_config.add_optimizers("function");
-
- MetaOptimizer optimizer(nullptr, rewriter_config);
-
- // Define function library:
- //
- // MyMul(x, y) = x * y
- // *MySquare(x) = MyMul(x, x)
- // *MyQuadratic(x) = MySquare(MySquare(x))
- //
- // * - marked as noinline
-
- FunctionDef mul_func = FunctionDefHelper::Create(
- "MyMul", {"x:T", "y:T"}, {"z:T"}, {"T: {float, double}"},
- {{{"mul"}, "Mul", {"x", "y"}, {{"T", "$T"}}}},
- /* Mapping between function returns and function node outputs. */
- {{"z", "mul:z:0"}});
-
- FunctionDef square_func = FunctionDefHelper::Create(
- "MySquare", {"x:T"}, {"z:T"}, {"T: {float, double}"},
- {{{"my_mul"}, "MyMul", {"x", "x"}, {{"T", "$T"}}}},
- /* Mapping between function returns and function node outputs. */
- {{"z", "my_mul:z:0"}});
- (*square_func.mutable_attr())["_noinline"].set_b(true);
-
- FunctionDef quadratic_func = FunctionDefHelper::Create(
- "MyQuadratic", {"x:T"}, {"z:T"}, {"T: {float, double}"},
- {{{"square"}, "MySquare", {"x"}, {{"T", "$T"}}},
- {{"quadratic"}, "MySquare", {"square:z"}, {{"T", "$T"}}}},
- /* Mapping between function returns and function node outputs. */
- {{"z", "quadratic:z:0"}});
- (*quadratic_func.mutable_attr())["_noinline"].set_b(true);
-
- // Tensorflow graph:
- //
- // a = tf.Placeholder(tf.float);
- // b = tf.Placeholder(tf.int32);
- //
- // square = MySquare(a); // a^2
- // quadratic = MyQuadratic(b); // b^4
- GrapplerItem item;
- item.graph = test::function::GDef(
- {NDef("a", "Placeholder", {}, {{"dtype", DT_FLOAT}}, kDevice),
- NDef("b", "Placeholder", {}, {{"dtype", DT_INT32}}, kDevice),
- // Calls into function library
- NDef("square", "MySquare", {"a"}, {{"T", DT_FLOAT}}, kDevice),
- NDef("quadratic", "MyQuadratic", {"b"}, {{"T", DT_INT32}}, kDevice),
- // Forward outputs
- NDef("out_s", "Identity", {"square:0"}, {{"T", DT_FLOAT}}, kDevice),
- NDef("out_q", "Identity", {"quadratic:0"}, {{"T", DT_INT32}}, kDevice)},
- // FunctionLib
- {mul_func, square_func, quadratic_func});
-
- GraphDef output;
- TF_EXPECT_OK(optimizer.Optimize(nullptr, item, &output));
-
- FunctionLibraryDefinition optimized_flib(OpRegistry::Global(),
- output.library());
-
- // Specialized and optimized functions should be added to the graph.
- EXPECT_EQ(6, optimized_flib.num_functions());
-
- // MyQuadratic should be specialized once:
- // 0. 'quadratic' node in the main graph
- const string optimized_0 = "MyQuadratic_specialized_for_quadratic";
-
- // MySquare should be specialized and optimized for 3 instantiations:
- // 1. 'square' node in the main graph
- // 2. 'square' node in the MyQuadratic specialization
- // 3. 'quadratic' node in the MyQuadratic specialization
-
- const string optimized_1 = "MySquare_specialized_for_square";
- const string optimized_2 = "MySquare_specialized_for_square_1";
- const string optimized_3 = "MySquare_specialized_for_quadratic";
-
- const FunctionDef* optimized_func_0 = optimized_flib.Find(optimized_0);
- const FunctionDef* optimized_func_1 = optimized_flib.Find(optimized_1);
- const FunctionDef* optimized_func_2 = optimized_flib.Find(optimized_2);
- const FunctionDef* optimized_func_3 = optimized_flib.Find(optimized_3);
-
- ASSERT_NE(optimized_func_0, nullptr);
- ASSERT_NE(optimized_func_1, nullptr);
- ASSERT_NE(optimized_func_2, nullptr);
- ASSERT_NE(optimized_func_3, nullptr);
-
- // Graph should call optimized function.
- int count = 0;
- for (const NodeDef& node : output.node()) {
- if (node.name() == "square" && count++) {
- EXPECT_EQ("MySquare_specialized_for_square", node.op());
- } else if (node.name() == "quadratic" && count++) {
- EXPECT_EQ("MyQuadratic_specialized_for_quadratic", node.op());
- }
- }
- EXPECT_EQ(2, count);
-
- // Specialized MySquare should call specialized functions.
- count = 0;
- for (const NodeDef& node : optimized_func_0->node_def()) {
- if (node.name() == "square" && count++) {
- EXPECT_EQ(optimized_2, node.op());
- } else if (node.name() == "quadratic" && count++) {
- EXPECT_EQ(optimized_3, node.op());
- }
- }
- EXPECT_EQ(2, count);
-
- const std::vector<const FunctionDef*> optimized_funcs = {
- optimized_func_1, optimized_func_1, optimized_func_3};
-
- // MyMul should be inlined into all optimized versions of MySquare.
- for (const FunctionDef* optimized_func : optimized_funcs) {
- count = 0;
- for (const NodeDef& node : optimized_func->node_def()) {
- if (node.name() == "my_mul/inlined_inputs" && count++) {
- EXPECT_EQ("IdentityN", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("x:0", node.input(0));
- EXPECT_EQ("x:0", node.input(1));
- } else if (node.name() == "my_mul/x" && count++) {
- EXPECT_EQ("Identity", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("my_mul/inlined_inputs:output:0", node.input(0));
- } else if (node.name() == "my_mul/y" && count++) {
- EXPECT_EQ("Identity", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("my_mul/inlined_inputs:output:1", node.input(0));
- } else if (node.name() == "my_mul/mul" && count++) {
- EXPECT_EQ("Mul", node.op());
- EXPECT_EQ(2, node.input_size());
- EXPECT_EQ("my_mul/x:output:0", node.input(0));
- EXPECT_EQ("my_mul/y:output:0", node.input(1));
- } else if (node.name() == "my_mul" && count++) {
- EXPECT_EQ("IdentityN", node.op());
- EXPECT_EQ(1, node.input_size());
- EXPECT_EQ("my_mul/mul:z:0", node.input(0));
- }
- EXPECT_TRUE(node.device().empty());
- }
- EXPECT_EQ(5, count);
- }
-
- item.fetch = {"out_s", "out_q"};
- item.feed.emplace_back("a", test::AsScalar<float>(2.0f));
- item.feed.emplace_back("b", test::AsScalar<int>(4));
- auto tensors_expected = EvaluateFetchNodes(item);
-
- GrapplerItem optimized(item, std::move(output));
- auto tensors = EvaluateFetchNodes(optimized);
-
- test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
- test::ExpectTensorEqual<int>(tensors_expected[1], tensors[1]);
-}
-
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 835b8bbb47..6355f13654 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -943,6 +943,7 @@ tf_kernel_library(
srcs = ["cudnn_rnn_ops.cc"],
visibility = ["//visibility:public"],
deps = [
+ ":gpu_util_hdrs",
"//tensorflow/core:cudnn_rnn_ops_op_lib",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -2038,6 +2039,17 @@ tf_kernel_library(
],
)
+tf_kernel_library(
+ name = "partitioned_function_ops",
+ prefix = "partitioned_function_ops",
+ deps = [
+ "//tensorflow/core:core_cpu_internal",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:functional_ops_op_lib",
+ "//tensorflow/core:lib",
+ ],
+)
+
cc_library(
name = "image",
deps = [
@@ -5153,6 +5165,7 @@ filegroup(
"decode_proto_op.cc",
"encode_proto_op.cc",
"rpc_op.cc",
+ "partitioned_function_ops.cc",
# Excluded due to experimental status:
"debug_ops.*",
"scatter_nd_op*",
@@ -5939,6 +5952,7 @@ tf_cc_test(
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:protos_all_cc",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/kernels/avgpooling_op.cc b/tensorflow/core/kernels/avgpooling_op.cc
index c581d1451f..ba38e1a188 100644
--- a/tensorflow/core/kernels/avgpooling_op.cc
+++ b/tensorflow/core/kernels/avgpooling_op.cc
@@ -156,10 +156,10 @@ class AvgPoolingOp<GPUDevice, T> : public UnaryOp<T> {
TensorShape output_shape = params.forward_output_shape();
if (data_format_ == FORMAT_NCHW) {
- DnnPoolingOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
- stride_, padding_, data_format_, tensor_in, output_shape,
- /*propagate_nans=*/false);
+ DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, ksize_,
+ stride_, padding_, data_format_, tensor_in,
+ output_shape,
+ /*propagate_nans=*/false);
} else {
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
@@ -417,10 +417,10 @@ class AvgPoolingGradOp<GPUDevice, T> : public OpKernel {
output_shape.AddDim(shape_vec(i));
}
- DnnPoolingGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
- stride_, padding_, data_format_, nullptr, nullptr, out_backprop,
- output_shape, /*propagate_nans=*/false);
+ DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
+ ksize_, stride_, padding_, data_format_,
+ nullptr, nullptr, out_backprop, output_shape,
+ /*propagate_nans=*/false);
}
private:
@@ -547,10 +547,10 @@ class AvgPoolingGradOpCustomGPUKernel : public OpKernel {
output->flat<T>().data(), // bottom_diff
context->eigen_gpu_device()); // d
} else {
- DnnPoolingGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kAverage, ksize_,
- stride_, padding_, data_format_, nullptr, nullptr, out_backprop,
- output_shape, /*propagate_nans=*/false);
+ DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kAverage,
+ ksize_, stride_, padding_, data_format_,
+ nullptr, nullptr, out_backprop, output_shape,
+ /*propagate_nans=*/false);
}
}
diff --git a/tensorflow/core/kernels/batch_matmul_op_impl.h b/tensorflow/core/kernels/batch_matmul_op_impl.h
index 43e716c542..a1c03f9918 100644
--- a/tensorflow/core/kernels/batch_matmul_op_impl.h
+++ b/tensorflow/core/kernels/batch_matmul_op_impl.h
@@ -245,35 +245,35 @@ struct LaunchBatchMatMul<CPUDevice, Scalar> {
namespace {
template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
-class CublasScratchAllocator : public perftools::gputools::ScratchAllocator {
+class CublasScratchAllocator : public se::ScratchAllocator {
public:
- using Stream = ::perftools::gputools::Stream;
- using DeviceMemoryBytes = ::perftools::gputools::DeviceMemory<uint8>;
+ using Stream = se::Stream;
+ using DeviceMemoryBytes = se::DeviceMemory<uint8>;
CublasScratchAllocator(OpKernelContext* context) : context_(context) {}
int64 GetMemoryLimitInBytes(Stream* stream) override { return -1; }
- perftools::gputools::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
+ se::port::StatusOr<DeviceMemoryBytes> AllocateBytes(
Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
Status allocation_status(context_->allocate_temp(
DT_UINT8, TensorShape({byte_size}), &temporary_memory));
if (!allocation_status.ok()) {
- return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
+ return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(nullptr, 0));
}
// Hold references to the allocated tensors for the lifetime of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
- return perftools::gputools::port::StatusOr<DeviceMemoryBytes>(
+ return se::port::StatusOr<DeviceMemoryBytes>(
DeviceMemoryBytes::MakeFromByteSize(
temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
@@ -289,12 +289,11 @@ template <typename Scalar>
struct LaunchBatchMatMul<GPUDevice, Scalar> {
static void Launch(OpKernelContext* context, const Tensor& in_x,
const Tensor& in_y, bool adj_x, bool adj_y, Tensor* out) {
- constexpr perftools::gputools::blas::Transpose kTranspose =
- is_complex<Scalar>::value
- ? perftools::gputools::blas::Transpose::kConjugateTranspose
- : perftools::gputools::blas::Transpose::kTranspose;
- perftools::gputools::blas::Transpose trans[] = {
- perftools::gputools::blas::Transpose::kNoTranspose, kTranspose};
+ constexpr se::blas::Transpose kTranspose =
+ is_complex<Scalar>::value ? se::blas::Transpose::kConjugateTranspose
+ : se::blas::Transpose::kTranspose;
+ se::blas::Transpose trans[] = {se::blas::Transpose::kNoTranspose,
+ kTranspose};
const uint64 m = in_x.dim_size(adj_x ? 2 : 1);
const uint64 k = in_x.dim_size(adj_x ? 1 : 2);
const uint64 n = in_y.dim_size(adj_y ? 1 : 2);
@@ -305,7 +304,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- typedef perftools::gputools::DeviceMemory<Scalar> DeviceMemoryType;
+ typedef se::DeviceMemory<Scalar> DeviceMemoryType;
std::vector<DeviceMemoryType> a_device_memory;
std::vector<DeviceMemoryType> b_device_memory;
std::vector<DeviceMemoryType> c_device_memory;
@@ -340,19 +339,16 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
// This is a regular matrix*matrix or matrix*vector multiply. Avoid the
// overhead of the scratch allocator and the batch interface.
if (n == 1 &&
- blas_transpose_b !=
- perftools::gputools::blas::Transpose::kConjugateTranspose &&
- blas_transpose_a !=
- perftools::gputools::blas::Transpose::kConjugateTranspose) {
+ blas_transpose_b != se::blas::Transpose::kConjugateTranspose &&
+ blas_transpose_a != se::blas::Transpose::kConjugateTranspose) {
// This is a matrix*vector multiply so use GEMV to compute A * b.
// Here we are multiplying in the natural order, so we have to flip
// the transposition flag to compensate for the tensor being stored
// row-major. Since GEMV doesn't provide a way to just conjugate an
// argument, we have to defer those cases to GEMM below.
- auto gemv_trans_a =
- blas_transpose_a == perftools::gputools::blas::Transpose::kTranspose
- ? perftools::gputools::blas::Transpose::kNoTranspose
- : perftools::gputools::blas::Transpose::kTranspose;
+ auto gemv_trans_a = blas_transpose_a == se::blas::Transpose::kTranspose
+ ? se::blas::Transpose::kNoTranspose
+ : se::blas::Transpose::kTranspose;
bool blas_launch_status =
stream
->ThenBlasGemv(gemv_trans_a, adj_x ? m : k, adj_x ? k : m,
diff --git a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
index 339d792302..f5ced95feb 100644
--- a/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/adaptive_shared_batch_scheduler.h
@@ -28,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
diff --git a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
index 139475389d..b4bce90841 100644
--- a/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
+++ b/tensorflow/core/kernels/batching_util/shared_batch_scheduler.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/thread_annotations.h"
diff --git a/tensorflow/core/kernels/bias_op.cc b/tensorflow/core/kernels/bias_op.cc
index 368993c827..9fda7169a8 100644
--- a/tensorflow/core/kernels/bias_op.cc
+++ b/tensorflow/core/kernels/bias_op.cc
@@ -393,8 +393,8 @@ class BiasGradOp<GPUDevice, T> : public OpKernel {
if (channel == 0) return;
auto* stream = context->op_device_context()->stream();
OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
- perftools::gputools::DeviceMemoryBase output_ptr(
- output->flat<T>().data(), output->NumElements() * sizeof(T));
+ se::DeviceMemoryBase output_ptr(output->flat<T>().data(),
+ output->NumElements() * sizeof(T));
stream->ThenMemZero(&output_ptr, output->NumElements() * sizeof(T));
if (output_backprop.NumElements() > 0) {
BiasGradGPU<T>::compute(context->template eigen_device<Device>(),
diff --git a/tensorflow/core/kernels/cast_op.h b/tensorflow/core/kernels/cast_op.h
index fd4e75d26f..16d2e0e0a5 100644
--- a/tensorflow/core/kernels/cast_op.h
+++ b/tensorflow/core/kernels/cast_op.h
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/check_numerics_op.cc b/tensorflow/core/kernels/check_numerics_op.cc
index d3b67f4614..c3c0c50007 100644
--- a/tensorflow/core/kernels/check_numerics_op.cc
+++ b/tensorflow/core/kernels/check_numerics_op.cc
@@ -139,7 +139,7 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
OP_REQUIRES_ASYNC(context, stream != nullptr,
errors::Internal("No GPU stream available."), done);
- perftools::gputools::DeviceMemoryBase abnormal_detected_ptr(
+ se::DeviceMemoryBase abnormal_detected_ptr(
abnormal_detected.flat<int>().data(),
abnormal_detected.flat<int>().size());
stream->ThenMemset32(&abnormal_detected_ptr, 0,
@@ -174,8 +174,8 @@ class CheckNumericsOp<GPUDevice, T> : public AsyncOpKernel {
TensorReference abnormal_detected_ref(abnormal_detected);
auto check_cb = [this, stream, abnormal_detected_ref,
abnormal_detected_host, context, done]() {
- ::perftools::gputools::cuda::ScopedActivateExecutorContext
- scoped_activation{stream->parent()};
+ se::cuda::ScopedActivateExecutorContext scoped_activation{
+ stream->parent()};
auto abnormal_detected_host_flat = abnormal_detected_host.flat<int>();
int is_nan = abnormal_detected_host_flat(0);
int is_inf = abnormal_detected_host_flat(1);
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index f3b91494b9..ef1e73e5ab 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -532,7 +532,7 @@ struct ConvBackwardFilterAutoTuneGroup {
static string name() { return "ConvBwdFilter"; }
};
typedef AutoTuneSingleton<ConvBackwardFilterAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConvBwdFilter;
// Backprop for filter.
@@ -636,9 +636,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
const Tensor& out_backprop, const Tensor& input, int row_dilation,
int col_dilation, int row_stride, int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format) {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
std::vector<int32> dilations(4, 1);
dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
@@ -721,9 +721,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
bool blas_launch_status =
stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose, n,
- m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
+ ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
+ se::blas::Transpose::kTranspose, n, m, k, 1.0f,
+ a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
@@ -751,9 +751,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
bool blas_launch_status =
stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose, n,
- m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
+ ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
+ se::blas::Transpose::kTranspose, n, m, k, 1.0f,
+ b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
ctx->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
@@ -787,24 +787,24 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
CHECK(padding_rows >= 0 && padding_cols >= 0)
<< "Negative row or col paddings: (" << padding_rows << ", "
<< padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
+ se::dnn::BatchDescriptor input_desc;
input_desc.set_count(dims.batch_size)
.set_height(GetTensorDim(compatible_input, data_format, 'H'))
.set_width(GetTensorDim(compatible_input, data_format, 'W'))
.set_feature_map_count(dims.in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc;
output_desc.set_count(dims.batch_size)
.set_height(dims.spatial_dims[0].output_size)
.set_width(dims.spatial_dims[1].output_size)
.set_feature_map_count(dims.out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
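Note on the rename pattern above (it repeats in every file that follows): each fully qualified perftools::gputools:: spelling is shortened to se::. A minimal sketch of the namespace alias this presumably relies on, stated as an assumption rather than something introduced by this hunk:

// Hedged sketch: the short name se is assumed to be an alias for the
// StreamExecutor namespace, made visible to TensorFlow code roughly as:
namespace stream_executor {}        // provided by the StreamExecutor library
namespace tensorflow {
namespace se = ::stream_executor;   // so se::dnn::AlgorithmConfig, se::blas::Transpose, ... resolve
}  // namespace tensorflow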
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index 66d15c6e78..35f2676023 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -604,7 +604,7 @@ struct ConvBackwardDataAutoTuneGroup {
static string name() { return "ConvBwdData"; }
};
typedef AutoTuneSingleton<ConvBackwardDataAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConvBwdData;
// Backprop for input.
@@ -705,9 +705,9 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
const Tensor& out_backprop, const Tensor& filter, int row_dilation,
int col_dilation, int row_stride, int col_stride, const Padding& padding,
Tensor* in_backprop, TensorFormat data_format) {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
std::vector<int32> strides(4, 1);
std::vector<int32> dilations(4, 1);
@@ -778,8 +778,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto transpose = se::blas::Transpose::kTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
@@ -810,8 +810,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto transpose = se::blas::Transpose::kTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
@@ -841,24 +841,24 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
CHECK(padding_rows >= 0 && padding_cols >= 0)
<< "Negative row or col paddings: (" << padding_rows << ", "
<< padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
+ se::dnn::BatchDescriptor input_desc;
input_desc.set_count(dims.batch_size)
.set_height(GetTensorDim(compatible_input_shape, data_format, 'H'))
.set_width(GetTensorDim(compatible_input_shape, data_format, 'W'))
.set_feature_map_count(dims.in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc;
output_desc.set_count(dims.batch_size)
.set_height(dims.spatial_dims[0].output_size)
.set_width(dims.spatial_dims[1].output_size)
.set_feature_map_count(dims.out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(dims.spatial_dims[0].filter_size)
.set_input_filter_width(dims.spatial_dims[1].filter_size)
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
.set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
.set_vertical_filter_stride(dims.spatial_dims[0].stride)
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index 092e859a5b..9edc6d416e 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -35,7 +35,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
-using perftools::gputools::dnn::DimIndex;
+using stream_executor::dnn::DimIndex;
#endif
namespace tensorflow {
@@ -468,7 +468,7 @@ struct Conv3dBackwardDataAutoTuneGroup {
static string name() { return "Conv3dBwdData"; }
};
typedef AutoTuneSingleton<Conv3dBackwardDataAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConv3dBwdData;
template <typename T>
@@ -554,8 +554,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto transpose = se::blas::Transpose::kTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
@@ -582,8 +582,8 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
auto c_ptr = AsDeviceMemory(in_backprop->template flat<T>().data(),
in_backprop->template flat<T>().size());
- auto transpose = perftools::gputools::blas::Transpose::kTranspose;
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto transpose = se::blas::Transpose::kTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
@@ -629,27 +629,27 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(batch)
.set_spatial_dim(DimIndex::X, compatible_input_shape.dim_size(4))
.set_spatial_dim(DimIndex::Y, compatible_input_shape.dim_size(3))
.set_spatial_dim(DimIndex::Z, compatible_input_shape.dim_size(2))
.set_feature_map_count(in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc(3);
output_desc.set_count(batch)
.set_spatial_dim(DimIndex::X, output_cols)
.set_spatial_dim(DimIndex::Y, output_rows)
.set_spatial_dim(DimIndex::Z, output_planes)
.set_feature_map_count(out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc(3);
filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
.set_spatial_dim(DimIndex::Y, filter_size[1])
.set_spatial_dim(DimIndex::Z, filter_size[0])
.set_input_feature_map_count(in_depth)
.set_output_feature_map_count(out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ se::dnn::ConvolutionDescriptor conv_desc(3);
conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
.set_dilation_rate(DimIndex::Y, dilations[1])
.set_dilation_rate(DimIndex::Z, dilations[0])
@@ -725,9 +725,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
device_id,
};
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdData::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
@@ -839,7 +839,7 @@ struct Conv3dBackwardFilterAutoTuneGroup {
static string name() { return "Conv3dBwdFilter"; }
};
typedef AutoTuneSingleton<Conv3dBackwardFilterAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConv3dBwdFilter;
template <typename T>
@@ -941,9 +941,9 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
bool blas_launch_status =
stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose,
- n, m, k, 1.0f, a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
+ ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
+ se::blas::Transpose::kTranspose, n, m, k, 1.0f,
+ a_ptr, n, b_ptr, m, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
@@ -967,9 +967,9 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
bool blas_launch_status =
stream
- ->ThenBlasGemm(perftools::gputools::blas::Transpose::kNoTranspose,
- perftools::gputools::blas::Transpose::kTranspose,
- n, m, k, 1.0f, b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
+ ->ThenBlasGemm(se::blas::Transpose::kNoTranspose,
+ se::blas::Transpose::kTranspose, n, m, k, 1.0f,
+ b_ptr, n, a_ptr, m, 0.0f, &c_ptr, n)
.ok();
if (!blas_launch_status) {
context->SetStatus(errors::Internal("Blas SGEMM launch failed : m=", m,
@@ -1014,7 +1014,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
CHECK(padding_rows >= 0 && padding_cols >= 0 && padding_planes >= 0)
<< "Negative paddings: (" << padding_rows << ", " << padding_cols
<< ", " << padding_planes << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(batch)
.set_spatial_dim(DimIndex::X,
GetTensorDim(compatible_input, data_format_, '2'))
@@ -1023,21 +1023,21 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
.set_spatial_dim(DimIndex::Z,
GetTensorDim(compatible_input, data_format_, '0'))
.set_feature_map_count(in_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc(3);
output_desc.set_count(batch)
.set_spatial_dim(DimIndex::X, output_cols)
.set_spatial_dim(DimIndex::Y, output_rows)
.set_spatial_dim(DimIndex::Z, output_planes)
.set_feature_map_count(out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc(3);
filter_desc.set_spatial_dim(DimIndex::X, filter_size[2])
.set_spatial_dim(DimIndex::Y, filter_size[1])
.set_spatial_dim(DimIndex::Z, filter_size[0])
.set_input_feature_map_count(in_depth)
.set_output_feature_map_count(out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ se::dnn::ConvolutionDescriptor conv_desc(3);
conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
.set_dilation_rate(DimIndex::Y, dilations[1])
.set_dilation_rate(DimIndex::Z, dilations[0])
@@ -1121,9 +1121,9 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
device_id,
};
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
if (cudnn_use_autotune_ && !AutoTuneConv3dBwdFilter::GetInstance()->Find(
conv_parameters, &algorithm_config)) {
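The Find() call just above is the query half of the autotune cache that the AutoTuneConv3dBwdFilter typedef wires up. As a rough, self-contained analogue of that singleton pattern (names and members here are illustrative, not the TensorFlow implementation):

#include <map>
#include <mutex>

// Hedged sketch of a process-wide parameters -> best-algorithm cache in the
// spirit of AutoTuneSingleton; Params needs operator< to serve as a map key.
template <typename Group, typename Params, typename Config>
class AutoTuneMapSketch {
 public:
  static AutoTuneMapSketch* GetInstance() {
    static AutoTuneMapSketch* instance = new AutoTuneMapSketch;
    return instance;
  }
  bool Find(const Params& params, Config* config) {
    std::lock_guard<std::mutex> lock(mu_);
    auto it = cache_.find(params);
    if (it == cache_.end()) return false;
    *config = it->second;
    return true;
  }
  void Insert(const Params& params, const Config& config) {
    std::lock_guard<std::mutex> lock(mu_);
    cache_[params] = config;
  }

 private:
  std::mutex mu_;
  std::map<Params, Config> cache_;
};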
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index f0888c655f..c6d36b40fe 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -475,7 +475,7 @@ struct ConvAutoTuneGroup {
static string name() { return "Conv"; }
};
typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConv;
template <typename T>
@@ -484,9 +484,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
const Tensor& input_param, const Tensor& filter, int row_dilation,
int col_dilation, int row_stride, int col_stride, const Padding& padding,
Tensor* output, TensorFormat data_format) {
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
auto* stream = ctx->op_device_context()->stream();
OP_REQUIRES(ctx, stream, errors::Internal("No GPU stream available."));
@@ -514,7 +514,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
@@ -543,7 +543,7 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr, n,
@@ -629,24 +629,24 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
CHECK(padding_rows >= 0 && padding_cols >= 0)
<< "Negative row or col paddings: (" << padding_rows << ", "
<< padding_cols << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc;
+ se::dnn::BatchDescriptor input_desc;
input_desc.set_count(in_batch)
.set_feature_map_count(in_depths)
.set_height(in_rows)
.set_width(in_cols)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc;
output_desc.set_count(out_batch)
.set_height(out_rows)
.set_width(out_cols)
.set_feature_map_count(out_depths)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc;
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc;
filter_desc.set_input_filter_height(filter.dim_size(0))
.set_input_filter_width(filter.dim_size(1))
.set_input_feature_map_count(filter.dim_size(2))
.set_output_feature_map_count(filter.dim_size(3));
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
+ se::dnn::ConvolutionDescriptor conv_desc;
conv_desc.set_vertical_dilation_rate(row_dilation)
.set_horizontal_dilation_rate(col_dilation)
.set_vertical_filter_stride(row_stride)
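A side note on the layout constant repeated through these descriptor hunks: kBatchDepthYX is StreamExecutor's name for NCHW ordering, which is why these kernels transform NHWC tensors before building descriptors for cuDNN. A compact sketch of the chained-setter style used above, with n/c/h/w as illustrative placeholders:

// Hedged sketch: building an NCHW (kBatchDepthYX) batch descriptor with the
// same chained setters used in the hunks above; n, c, h, w are placeholders.
se::dnn::BatchDescriptor input_desc;
input_desc.set_count(n)
    .set_feature_map_count(c)
    .set_height(h)
    .set_width(w)
    .set_layout(se::dnn::DataLayout::kBatchDepthYX);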
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 48dd3c9eb0..9ec16be67d 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -34,7 +34,7 @@ limitations under the License.
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
-using perftools::gputools::dnn::DimIndex;
+using stream_executor::dnn::DimIndex;
#endif
namespace tensorflow {
@@ -192,7 +192,7 @@ struct Conv3dAutoTuneGroup {
static string name() { return "Conv3d"; }
};
typedef AutoTuneSingleton<Conv3dAutoTuneGroup, ConvParameters,
- perftools::gputools::dnn::AlgorithmConfig>
+ se::dnn::AlgorithmConfig>
AutoTuneConv3d;
// TODO(mjanusz): Share logic with 2d implementation as much as possible.
@@ -250,7 +250,7 @@ struct LaunchConvOp<GPUDevice, T> {
auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
@@ -277,7 +277,7 @@ struct LaunchConvOp<GPUDevice, T> {
auto c_ptr = AsDeviceMemory(output->template flat<T>().data(),
output->template flat<T>().size());
- auto no_transpose = perftools::gputools::blas::Transpose::kNoTranspose;
+ auto no_transpose = se::blas::Transpose::kNoTranspose;
bool blas_launch_status =
stream
->ThenBlasGemm(no_transpose, no_transpose, n, m, k, 1.0f, b_ptr,
@@ -346,27 +346,27 @@ struct LaunchConvOp<GPUDevice, T> {
CHECK(pad_rows >= 0 && pad_cols >= 0 && pad_planes >= 0)
<< "Negative paddings: (" << pad_rows << ", " << pad_cols << ", "
<< pad_planes << ")";
- perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(in_batch)
.set_feature_map_count(in_depth)
.set_spatial_dim(DimIndex::X, in_cols)
.set_spatial_dim(DimIndex::Y, in_rows)
.set_spatial_dim(DimIndex::Z, in_planes)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc(3);
output_desc.set_count(in_batch)
.set_spatial_dim(DimIndex::X, out_cols)
.set_spatial_dim(DimIndex::Y, out_rows)
.set_spatial_dim(DimIndex::Z, out_planes)
.set_feature_map_count(out_depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::FilterDescriptor filter_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::FilterDescriptor filter_desc(3);
filter_desc.set_spatial_dim(DimIndex::X, filter_cols)
.set_spatial_dim(DimIndex::Y, filter_rows)
.set_spatial_dim(DimIndex::Z, filter_planes)
.set_input_feature_map_count(in_depth)
.set_output_feature_map_count(out_depth);
- perftools::gputools::dnn::ConvolutionDescriptor conv_desc(3);
+ se::dnn::ConvolutionDescriptor conv_desc(3);
conv_desc.set_dilation_rate(DimIndex::X, dilations[2])
.set_dilation_rate(DimIndex::Y, dilations[1])
.set_dilation_rate(DimIndex::Z, dilations[0])
@@ -424,9 +424,9 @@ struct LaunchConvOp<GPUDevice, T> {
device_id,
};
- using perftools::gputools::dnn::AlgorithmConfig;
- using perftools::gputools::dnn::AlgorithmDesc;
- using perftools::gputools::dnn::ProfileResult;
+ using se::dnn::AlgorithmConfig;
+ using se::dnn::AlgorithmDesc;
+ using se::dnn::ProfileResult;
AlgorithmConfig algorithm_config;
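The 3-D hunks above also show the DimIndex convention for volumetric descriptors: X is the innermost spatial dimension (columns), Y is rows, and Z is planes, each set explicitly instead of through the 2-D set_height/set_width helpers. A minimal sketch mirroring the conv_ops_3d.cc hunk, with the dimension variables as placeholders:

// Hedged sketch of a rank-3 batch descriptor; batch, in_depth, in_cols,
// in_rows and in_planes are placeholders taken from the hunk above.
se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(batch)
    .set_feature_map_count(in_depth)
    .set_spatial_dim(se::dnn::DimIndex::X, in_cols)
    .set_spatial_dim(se::dnn::DimIndex::Y, in_rows)
    .set_spatial_dim(se::dnn::DimIndex::Z, in_planes)
    .set_layout(se::dnn::DataLayout::kBatchDepthYX);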
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index e8da5298e6..d2c8020bb6 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -36,25 +36,23 @@ int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
// A class to provide scratch-space allocator for Stream-Executor Cudnn
// callback. TensorFlow is responsible for releasing the temporary buffers after
// the kernel finishes.
-class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
+class CudnnScratchAllocator : public se::ScratchAllocator {
public:
virtual ~CudnnScratchAllocator() {}
CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
- int64 GetMemoryLimitInBytes(perftools::gputools::Stream* stream) override {
+ int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return memory_limit_;
}
- perftools::gputools::port::StatusOr<perftools::gputools::DeviceMemory<uint8>>
- AllocateBytes(perftools::gputools::Stream* stream, int64 byte_size) override {
+ se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
+ se::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
if (byte_size < 0) {
- return perftools::gputools::port::Status{
- perftools::gputools::port::error::INVALID_ARGUMENT,
- "Requested negative byte size!"};
+ return se::port::Status{se::port::error::INVALID_ARGUMENT,
+ "Requested negative byte size!"};
}
if (byte_size > memory_limit_) {
- return perftools::gputools::port::StatusOr<
- perftools::gputools::DeviceMemory<uint8>>();
+ return se::port::StatusOr<se::DeviceMemory<uint8>>();
}
AllocationAttributes allocation_attr;
allocation_attr.no_retry_on_failure = true;
@@ -62,15 +60,13 @@ class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
- return perftools::gputools::port::StatusOr<
- perftools::gputools::DeviceMemory<uint8>>();
+ return se::port::StatusOr<se::DeviceMemory<uint8>>();
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
- return perftools::gputools::port::StatusOr<
- perftools::gputools::DeviceMemory<uint8>>(
+ return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
@@ -141,7 +137,7 @@ class ConvParameters {
// for certain input parameters so as to avoid a bug in cuDNNv5 and cuDNNv6.
template <typename T>
bool ShouldIncludeWinogradNonfusedAlgo(
- perftools::gputools::StreamExecutor* stream_exec) const {
+ se::StreamExecutor* stream_exec) const {
// Skip this check for cuDNN 7 and newer.
auto version = stream_exec->AsDnn()->GetVersion();
if (version.ok() && version.ValueOrDie().major_version() >= 7) {
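For orientation, CudnnScratchAllocator above is built once per launch and passed by pointer into the cuDNN call so StreamExecutor can obtain temporary workspace through TensorFlow's allocator; allocation failure is signaled by returning an empty StatusOr rather than an error status. A hedged sketch of the assumed call site (ThenConvolveWithAlgorithm and ConvolveScratchSize are assumptions about the caller, not part of this hunk):

// Hedged usage sketch; the descriptors and device pointers would come from
// code like the conv_ops.cc hunks earlier in this diff.
CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
bool cudnn_launch_status =
    stream
        ->ThenConvolveWithAlgorithm(input_desc, input_ptr, filter_desc,
                                    filter_ptr, conv_desc, output_desc,
                                    &output_ptr, &scratch_allocator,
                                    algorithm_config,
                                    /*output_profile_result=*/nullptr)
        .ok();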
diff --git a/tensorflow/core/kernels/crop_and_resize_op.cc b/tensorflow/core/kernels/crop_and_resize_op.cc
index 45cc2fbbb8..54ef9c6fb4 100644
--- a/tensorflow/core/kernels/crop_and_resize_op.cc
+++ b/tensorflow/core/kernels/crop_and_resize_op.cc
@@ -39,17 +39,16 @@ limitations under the License.
#include "tensorflow/core/platform/cuda.h"
#include "tensorflow/core/platform/stream_executor.h"
-using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+using stream_executor::cuda::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA
namespace tensorflow {
+namespace {
typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;
using Callback = std::function<void()>;
-namespace {
-
static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
const Tensor& box_index,
int* num_boxes) {
@@ -753,8 +752,7 @@ inline void RunIfBoxIndexIsValid<GPUDevice>(
context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
&isvalid_host_tensor, alloc_attr),
done);
- perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
- sizeof(bool));
+ se::DeviceMemoryBase wrapped(isvalid_dev.data(), sizeof(bool));
const bool status =
stream
->ThenMemcpy(
diff --git a/tensorflow/core/kernels/cuda_device_array.h b/tensorflow/core/kernels/cuda_device_array.h
index e7a5db0683..74dc298c7a 100644
--- a/tensorflow/core/kernels/cuda_device_array.h
+++ b/tensorflow/core/kernels/cuda_device_array.h
@@ -80,7 +80,7 @@ class CudaDeviceArrayOnHost {
TensorReference tensor_ref(out_of_line_values_on_host_);
TF_RETURN_IF_ERROR(context_->allocate_temp(
DT_INT8, TensorShape{total_bytes_}, &out_of_line_values_on_gpu_));
- perftools::gputools::DeviceMemoryBase output_values_base{
+ se::DeviceMemoryBase output_values_base{
out_of_line_values_on_gpu_.flat<int8>().data(),
static_cast<uint64>(total_bytes_)};
stream->ThenMemcpy(&output_values_base,
diff --git a/tensorflow/core/kernels/cuda_solvers.cc b/tensorflow/core/kernels/cuda_solvers.cc
index 6cec032f94..a857bd3ce4 100644
--- a/tensorflow/core/kernels/cuda_solvers.cc
+++ b/tensorflow/core/kernels/cuda_solvers.cc
@@ -35,8 +35,6 @@
#include "tensorflow/core/platform/stream_executor.h"
#include "tensorflow/core/platform/types.h"
-using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
-
// The CUDA cublas_api.h API contains const-correctness errors. Instead of
// casting away constness on our data, we instead reinterpret the CuBLAS
// functions as what they were clearly meant to be, and thus we can call
@@ -80,10 +78,12 @@ using matinv_Z = cublasStatus_t(cublasContext*, int, const double2* const*, int,
namespace tensorflow {
namespace {
+using se::cuda::ScopedActivateExecutorContext;
+
inline bool CopyHostToDevice(OpKernelContext* context, void* dst,
const void* src, uint64 bytes) {
auto stream = context->op_device_context()->stream();
- perftools::gputools::DeviceMemoryBase wrapped_dst(dst);
+ se::DeviceMemoryBase wrapped_dst(dst);
return stream->ThenMemcpy(&wrapped_dst, src, bytes).ok();
}
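CopyHostToDevice above shows the wrap-then-copy idiom that recurs in several of these files: a raw device pointer is wrapped in se::DeviceMemoryBase so a Stream can move bytes to or from it. Both directions sketched below, with illustrative pointer and size names and only the ThenMemcpy overloads visible in these hunks:

// Host -> device: wrap the destination device pointer.
se::DeviceMemoryBase dst_wrapped(device_dst_ptr, num_bytes);
bool h2d_ok = stream->ThenMemcpy(&dst_wrapped, host_src_ptr, num_bytes).ok();

// Device -> host: wrap the source device pointer (see crop_and_resize_op.cc
// and cuda_solvers.h above for the same pattern).
se::DeviceMemoryBase src_wrapped(device_src_ptr, num_bytes);
bool d2h_ok = stream->ThenMemcpy(host_dst_ptr, src_wrapped, num_bytes).ok();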
diff --git a/tensorflow/core/kernels/cuda_solvers.h b/tensorflow/core/kernels/cuda_solvers.h
index ecfa23750c..b2e8ee23a9 100644
--- a/tensorflow/core/kernels/cuda_solvers.h
+++ b/tensorflow/core/kernels/cuda_solvers.h
@@ -398,7 +398,7 @@ class DeviceLapackInfo : public ScratchSpace<int> {
CHECK(success != nullptr);
HostLapackInfo copy(context(), size(), debug_info());
auto stream = context()->op_device_context()->stream();
- perftools::gputools::DeviceMemoryBase wrapped_src(
+ se::DeviceMemoryBase wrapped_src(
static_cast<void*>(const_cast<int*>(this->data())));
*success =
stream->ThenMemcpy(copy.mutable_data(), wrapped_src, this->bytes())
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.cc b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
index 5939ecdf62..d2b9c9edaa 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.cc
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.cc
@@ -31,12 +31,13 @@ namespace tensorflow {
#if GOOGLE_CUDA
template <typename T>
-void DnnPooling3dOp<T>::Compute(
- OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
- const std::array<int64, 3>& window, const std::array<int64, 3>& stride,
- const std::array<int64, 3>& padding, TensorFormat data_format,
- const Tensor& tensor_in, Tensor* output) {
+void DnnPooling3dOp<T>::Compute(OpKernelContext* context,
+ se::dnn::PoolingMode pooling_mode,
+ const std::array<int64, 3>& window,
+ const std::array<int64, 3>& stride,
+ const std::array<int64, 3>& padding,
+ TensorFormat data_format,
+ const Tensor& tensor_in, Tensor* output) {
const auto in_shape = tensor_in.shape();
const auto out_shape = output->shape();
@@ -67,18 +68,18 @@ void DnnPooling3dOp<T>::Compute(
transformed_output = *output;
}
- perftools::gputools::dnn::PoolingDescriptor pooling_desc(3);
+ se::dnn::PoolingDescriptor pooling_desc(3);
pooling_desc.set_pooling_mode(pooling_mode);
- perftools::gputools::dnn::BatchDescriptor input_desc(3);
+ se::dnn::BatchDescriptor input_desc(3);
input_desc.set_count(in_batch)
.set_feature_map_count(in_features)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc(3);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
+ se::dnn::BatchDescriptor output_desc(3);
output_desc.set_count(in_batch)
.set_feature_map_count(in_features)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
for (size_t i = 0; i < window.size(); ++i) {
- const auto dim_i = static_cast<perftools::gputools::dnn::DimIndex>(i);
+ const auto dim_i = static_cast<se::dnn::DimIndex>(i);
pooling_desc.set_window(dim_i, window[i]);
pooling_desc.set_stride(dim_i, stride[i]);
pooling_desc.set_padding(dim_i, padding[i]);
@@ -115,14 +116,13 @@ void DnnPooling3dOp<T>::Compute(
template <typename T>
void DnnPooling3dGradOp<T>::Compute(
- OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ OpKernelContext* context, se::dnn::PoolingMode pooling_mode,
const std::array<int64, 3>& window, const std::array<int64, 3>& stride,
const std::array<int64, 3>& padding,
const std::array<int64, 3>& output_size, TensorFormat data_format,
const Tensor& out_backprop, const TensorShape& tensor_in_shape,
const Tensor* tensor_in, const Tensor* tensor_out, Tensor* input_backprop) {
- CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) ||
+ CHECK((pooling_mode != se::dnn::PoolingMode::kMaximum) ||
(tensor_in && tensor_out))
<< "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
"specified";
@@ -186,21 +186,21 @@ void DnnPooling3dGradOp<T>::Compute(
transformed_output_backprop.tensor<T, 5>());
}
- perftools::gputools::dnn::PoolingDescriptor pooling_desc(3);
+ se::dnn::PoolingDescriptor pooling_desc(3);
pooling_desc.set_pooling_mode(pooling_mode);
- perftools::gputools::dnn::BatchDescriptor orig_output_desc(3);
+ se::dnn::BatchDescriptor orig_output_desc(3);
orig_output_desc.set_count(in_batch)
.set_feature_map_count(in_features)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor orig_input_desc(3);
+ se::dnn::BatchDescriptor orig_input_desc(3);
orig_input_desc.set_count(in_batch)
.set_feature_map_count(in_features)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
for (size_t i = 0; i < window.size(); ++i) {
- const auto dim_i = static_cast<perftools::gputools::dnn::DimIndex>(i);
+ const auto dim_i = static_cast<se::dnn::DimIndex>(i);
pooling_desc.set_window(dim_i, window[i]);
pooling_desc.set_stride(dim_i, stride[i]);
pooling_desc.set_padding(dim_i, padding[i]);
diff --git a/tensorflow/core/kernels/cudnn_pooling_gpu.h b/tensorflow/core/kernels/cudnn_pooling_gpu.h
index ff4de75845..280d697fc2 100644
--- a/tensorflow/core/kernels/cudnn_pooling_gpu.h
+++ b/tensorflow/core/kernels/cudnn_pooling_gpu.h
@@ -38,7 +38,7 @@ template <typename T>
class DnnPooling3dOp {
public:
static void Compute(OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ se::dnn::PoolingMode pooling_mode,
const std::array<int64, 3>& size,
const std::array<int64, 3>& stride,
const std::array<int64, 3>& padding,
@@ -52,7 +52,7 @@ template <typename T>
class DnnPooling3dGradOp {
public:
static void Compute(OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ se::dnn::PoolingMode pooling_mode,
const std::array<int64, 3>& window,
const std::array<int64, 3>& stride,
const std::array<int64, 3>& padding,
diff --git a/tensorflow/core/kernels/cudnn_rnn_ops.cc b/tensorflow/core/kernels/cudnn_rnn_ops.cc
index a21f13a4dd..25560b7c28 100644
--- a/tensorflow/core/kernels/cudnn_rnn_ops.cc
+++ b/tensorflow/core/kernels/cudnn_rnn_ops.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/gpu_utils.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
@@ -43,6 +44,7 @@ limitations under the License.
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/env_var.h"
+#include "tensorflow/core/util/use_cudnn.h"
#if GOOGLE_CUDA
#include "tensorflow/core/platform/stream_executor.h"
@@ -78,7 +80,9 @@ using CPUDevice = Eigen::ThreadPoolDevice;
#if GOOGLE_CUDA
using GPUDevice = Eigen::GpuDevice;
-using ::perftools::gputools::StreamExecutor;
+using se::Stream;
+using se::StreamExecutor;
+using se::dnn::RnnDescriptor;
template <typename Device, typename T, typename Index>
class CudnnRNNParamsSizeOp;
@@ -95,6 +99,12 @@ class CudnnRNNForwardOp;
template <typename Device, typename T>
class CudnnRNNBackwardOp;
+template <typename Device, typename T>
+class CudnnRNNForwardOpV2;
+
+template <typename Device, typename T>
+class CudnnRNNBackwardOpV2;
+
enum class TFRNNInputMode {
kRNNLinearInput = 0,
kRNNSkipInput = 1,
@@ -102,21 +112,111 @@ enum class TFRNNInputMode {
};
namespace {
-using ::perftools::gputools::DeviceMemory;
-using ::perftools::gputools::DeviceMemoryBase;
-using ::perftools::gputools::ScratchAllocator;
-using ::perftools::gputools::Stream;
-using ::perftools::gputools::dnn::AlgorithmConfig;
-using ::perftools::gputools::dnn::AlgorithmDesc;
-using ::perftools::gputools::dnn::ProfileResult;
-using ::perftools::gputools::dnn::RnnDescriptor;
-using ::perftools::gputools::dnn::RnnDirectionMode;
-using ::perftools::gputools::dnn::RnnInputMode;
-using ::perftools::gputools::dnn::RnnMode;
-using ::perftools::gputools::dnn::RnnSequenceTensorDescriptor;
-using ::perftools::gputools::dnn::RnnStateTensorDescriptor;
-using ::perftools::gputools::dnn::ToDataType;
-using ::perftools::gputools::port::StatusOr;
+using se::DeviceMemory;
+using se::DeviceMemoryBase;
+using se::ScratchAllocator;
+using se::dnn::AlgorithmConfig;
+using se::dnn::AlgorithmDesc;
+using se::dnn::ProfileResult;
+using se::dnn::RnnDirectionMode;
+using se::dnn::RnnInputMode;
+using se::dnn::RnnMode;
+using se::dnn::RnnSequenceTensorDescriptor;
+using se::dnn::RnnStateTensorDescriptor;
+using se::dnn::ToDataType;
+using se::port::StatusOr;
+
+uint64 HashList(const std::vector<int>& list) {
+ if (list.empty()) {
+ return 0;
+ }
+ uint64 hash_code = list[0];
+ for (int i = 1; i < list.size(); i++) {
+ hash_code = Hash64Combine(hash_code, list[i]);
+ }
+ return hash_code;
+}
+
+// Encapsulate all the shape information that is used in both forward and
+// backward rnn operations.
+class CudnnRnnParameters {
+ public:
+ CudnnRnnParameters(int num_layers, int input_size, int num_units,
+ int seq_length, int batch_size, int dir_count,
+ bool has_dropout, bool is_training, RnnMode rnn_mode,
+ TFRNNInputMode rnn_input_mode, DataType dtype)
+ : num_layers_(num_layers),
+ input_size_(input_size),
+ num_units_(num_units),
+ seq_length_(seq_length),
+ batch_size_(batch_size),
+ dir_count_(dir_count),
+ has_dropout_(has_dropout),
+ is_training_(is_training),
+ rnn_mode_(rnn_mode),
+ rnn_input_mode_(rnn_input_mode),
+ dtype_(dtype) {
+ hash_code_ = HashList(
+ {num_layers, input_size, num_units, seq_length, batch_size, dir_count,
+ static_cast<int>(has_dropout), static_cast<int>(is_training),
+ static_cast<int>(rnn_mode), static_cast<int>(rnn_input_mode), dtype});
+ }
+
+ bool operator==(const CudnnRnnParameters& other) const {
+ return this->get_data_as_tuple() == other.get_data_as_tuple();
+ }
+
+ bool operator!=(const CudnnRnnParameters& other) const {
+ return !(*this == other);
+ }
+ uint64 hash() const { return hash_code_; }
+
+ string ToString() const {
+ std::vector<string> fields = {
+ std::to_string(num_layers_),
+ std::to_string(input_size_),
+ std::to_string(num_units_),
+ std::to_string(seq_length_),
+ std::to_string(batch_size_),
+ std::to_string(dir_count_),
+ std::to_string(has_dropout_),
+ std::to_string(is_training_),
+ std::to_string(static_cast<int>(rnn_mode_)),
+ std::to_string(static_cast<int>(rnn_input_mode_)),
+ std::to_string(static_cast<int>(dtype_))};
+ return str_util::Join(fields, ", ");
+ }
+
+ private:
+ using ParameterDataType = std::tuple<int, int, int, int, int, int, bool, bool,
+ RnnMode, TFRNNInputMode, DataType>;
+
+ ParameterDataType get_data_as_tuple() const {
+ return std::make_tuple(num_layers_, input_size_, num_units_, seq_length_,
+ batch_size_, dir_count_, has_dropout_, is_training_,
+ rnn_mode_, rnn_input_mode_, dtype_);
+ }
+
+ const int num_layers_;
+ const int input_size_;
+ const int num_units_;
+ const int seq_length_;
+ const int batch_size_;
+ const int dir_count_;
+ const bool has_dropout_;
+ const bool is_training_;
+ const RnnMode rnn_mode_;
+ const TFRNNInputMode rnn_input_mode_;
+ const DataType dtype_;
+ uint64 hash_code_;
+};
+
+struct RnnAutoTuneGroup {
+ static string name() { return "Rnn"; }
+};
+
+using AutoTuneRnnConfigMap =
+ AutoTuneSingleton<RnnAutoTuneGroup, CudnnRnnParameters, AlgorithmConfig>;
Status ParseRNNMode(const string& str, RnnMode* rnn_mode) {
if (str == "rnn_relu") {
@@ -213,25 +313,22 @@ DeviceMemoryBase SliceDeviceMemory(const DeviceMemoryBase& device_memory,
return DeviceMemoryBase(offset_ptr, size);
}
-inline Status FromExecutorStatus(const perftools::gputools::port::Status& s) {
+inline Status FromExecutorStatus(const se::port::Status& s) {
return s.ok() ? Status::OK()
- : Status(static_cast<tensorflow::error::Code>(
- static_cast<int>(s.code())),
+ : Status(static_cast<error::Code>(static_cast<int>(s.code())),
s.error_message());
}
template <typename T>
-inline Status FromExecutorStatus(
- const perftools::gputools::port::StatusOr<T>& s) {
+inline Status FromExecutorStatus(const se::port::StatusOr<T>& s) {
return FromExecutorStatus(s.status());
}
-inline perftools::gputools::port::Status ToExecutorStatus(const Status& s) {
- return s.ok() ? perftools::gputools::port::Status::OK()
- : perftools::gputools::port::Status(
- static_cast<perftools::gputools::port::error::Code>(
- static_cast<int>(s.code())),
- s.error_message());
+inline se::port::Status ToExecutorStatus(const Status& s) {
+ return s.ok() ? se::port::Status::OK()
+ : se::port::Status(static_cast<se::port::error::Code>(
+ static_cast<int>(s.code())),
+ s.error_message());
}
template <typename>
@@ -414,24 +511,29 @@ struct CudnnRnnModelShapes {
}
};
-// Utility class for using CudnnRnnModelShapes as a hash table key.
-struct CudnnRnnModelShapesHasher {
- uint64 operator()(const CudnnRnnModelShapes& to_hash) const {
- uint64 hash = static_cast<uint64>(to_hash.num_layers);
- hash = tensorflow::FingerprintCat64(
- hash, static_cast<uint64>(to_hash.input_size));
- hash = tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.num_units));
- return tensorflow::FingerprintCat64(hash,
- static_cast<uint64>(to_hash.dir_count));
+// Utility class for using a CudnnRnnModelShapes and AlgorithmDesc pair as a
+// hash table key.
+struct CudnnRnnConfigHasher {
+ uint64 operator()(
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& to_hash) const {
+ auto& shapes = to_hash.first;
+ auto& algo_desc = to_hash.second;
+
+ uint64 hash =
+ HashList({shapes.num_layers, shapes.input_size, shapes.num_units,
+ shapes.dir_count, shapes.batch_size});
+ hash = Hash64Combine(hash, algo_desc.hash());
+ return hash;
}
};
-// Utility class for using CudnnRnnModelShapes as a hash table key.
-struct CudnnRnnModelShapesComparator {
- bool operator()(const CudnnRnnModelShapes& first,
- const CudnnRnnModelShapes& second) const {
- return first.IsCompatibleWith(second);
+// Utility class for comparing (CudnnRnnModelShapes, AlgorithmDesc) pairs used
+// as hash table keys.
+struct CudnnRnnConfigComparator {
+ bool operator()(
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& lhs,
+ const std::pair<CudnnRnnModelShapes, AlgorithmDesc>& rhs) const {
+ return lhs.first.IsCompatibleWith(rhs.first) && lhs.second == rhs.second;
}
};
@@ -503,7 +605,7 @@ Status CreateForwardAndBackwardIODescriptors(
std::unique_ptr<RnnStateTensorDescriptor>* state_desc,
std::unique_ptr<RnnSequenceTensorDescriptor>* output_desc) {
StreamExecutor* executor = context->op_device_context()->stream()->parent();
- ::perftools::gputools::dnn::DataType data_type = ToDataType<T>::value;
+ se::dnn::DataType data_type = ToDataType<T>::value;
const TensorShape& input_shape = model_shapes.input_shape;
const TensorShape& hidden_state_shape = model_shapes.hidden_state_shape;
@@ -719,7 +821,7 @@ class CudnnRNNKernelCommon : public OpKernel {
RnnDirectionMode rnn_direction_mode() const {
return model_types_.rnn_direction_mode;
}
- CudnnModelTypes model_types() const { return model_types_; }
+ const CudnnModelTypes& model_types() const { return model_types_; }
float dropout() const { return dropout_; }
uint64 seed() { return (static_cast<uint64>(seed_) << 32) | seed2_; }
bool ResetRndGenState() { return reset_rnd_gen_state_; }
@@ -755,9 +857,9 @@ class CudnnRNNKernelCommon : public OpKernel {
// random number generator, therefore set state_allocator to nullptr.
const AlgorithmConfig algo_config;
auto rnn_desc_s = stream->parent()->createRnnDescriptor(
- num_layers, num_units, input_size, input_mode, rnn_direction_mode(),
- rnn_mode(), ToDataType<T>::value, algo_config, dropout(), seed(),
- nullptr /* state_allocator */);
+ num_layers, num_units, input_size, /*batch_size=*/0, input_mode,
+ rnn_direction_mode(), rnn_mode(), ToDataType<T>::value, algo_config,
+ dropout(), seed(), /* state_allocator=*/nullptr);
if (!rnn_desc_s.ok()) {
return FromExecutorStatus(rnn_desc_s);
}
@@ -773,11 +875,12 @@ class CudnnRNNKernelCommon : public OpKernel {
ScratchAllocator* dropout_state_allocator,
std::unique_ptr<RnnDescriptor>* rnn_desc) {
StreamExecutor* executor = context->op_device_context()->stream()->parent();
- ::perftools::gputools::dnn::DataType data_type = ToDataType<T>::value;
+ se::dnn::DataType data_type = ToDataType<T>::value;
auto rnn_desc_s = executor->createRnnDescriptor(
model_shapes.num_layers, model_shapes.num_units,
- model_shapes.input_size, input_mode, rnn_direction_mode(), rnn_mode(),
- data_type, algo_config, dropout(), seed(), dropout_state_allocator);
+ model_shapes.input_size, model_shapes.batch_size, input_mode,
+ rnn_direction_mode(), rnn_mode(), data_type, algo_config, dropout(),
+ seed(), dropout_state_allocator);
TF_RETURN_IF_ERROR(rnn_desc_s.status());
*rnn_desc = rnn_desc_s.ConsumeValueOrDie();
@@ -785,8 +888,9 @@ class CudnnRNNKernelCommon : public OpKernel {
}
using RnnStateCache =
- gtl::FlatMap<CudnnRnnModelShapes, RnnScratchSpace,
- CudnnRnnModelShapesHasher, CudnnRnnModelShapesComparator>;
+ gtl::FlatMap<std::pair<CudnnRnnModelShapes, AlgorithmDesc>,
+ RnnScratchSpace, CudnnRnnConfigHasher,
+ CudnnRnnConfigComparator>;
// Returns a raw rnn descriptor pointer. The cache owns the rnn descriptor and
// should outlive the returned pointer.
template <typename T>
@@ -796,7 +900,8 @@ class CudnnRNNKernelCommon : public OpKernel {
const AlgorithmConfig& algo_config,
RnnStateCache* cache,
RnnDescriptor** rnn_desc) {
- RnnScratchSpace& rnn_state = (*cache)[model_shapes];
+ auto key = std::make_pair(model_shapes, algo_config.algorithm());
+ RnnScratchSpace& rnn_state = (*cache)[key];
if (rnn_state.rnn_desc == nullptr || ResetRndGenState()) {
CudnnRNNPersistentSpaceAllocator* dropout_state_allocator =
new CudnnRNNPersistentSpaceAllocator(context);
@@ -825,7 +930,6 @@ class CudnnRNNKernelCommon : public OpKernel {
template <typename T, typename Index>
class CudnnRNNParamsSizeOp<GPUDevice, T, Index> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNParamsSizeOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -864,7 +968,6 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNParamsToCanonical<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNParamsToCanonical(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {
OP_REQUIRES_OK(context, context->GetAttr("num_params", &num_params_));
@@ -999,7 +1102,6 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNCanonicalToParams<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNCanonicalToParams(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -1045,13 +1147,26 @@ TF_CALL_double(REGISTER_GPU);
template <typename T>
class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
explicit CudnnRNNForwardOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {
OP_REQUIRES_OK(context, context->GetAttr("is_training", &is_training_));
+
+ // Read debug env variables.
+ is_debug_mode_ = DebugCudnnRnn();
+ debug_cudnn_rnn_algo_ = DebugCudnnRnnAlgo();
+ debug_use_tensor_ops_ = DebugCudnnRnnUseTensorOps();
}
void Compute(OpKernelContext* context) override {
+ AlgorithmConfig algo_config;
+ ComputeAndReturnAlgorithm(context, &algo_config);
+ }
+
+ protected:
+ virtual void ComputeAndReturnAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* output_algo_config) {
+ CHECK_NE(output_algo_config, nullptr);
+
const Tensor* input = nullptr;
const Tensor* input_h = nullptr;
const Tensor* input_c = nullptr;
@@ -1071,7 +1186,6 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, AllocateOutputs(context, model_shapes, &output,
&output_h, &output_c));
- AlgorithmConfig algo_config;
// Creates a memory callback for the reserve_space. The memory lives in the
// output of this kernel. And it will be fed into the backward pass when
// needed.
@@ -1079,14 +1193,25 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
// Creates a memory callback for the workspace. The memory lives to the end
// of this kernel calls.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
+
+ if (is_debug_mode_) {
+ AlgorithmDesc algo_desc(debug_cudnn_rnn_algo_, debug_use_tensor_ops_);
+ output_algo_config->set_algorithm(algo_desc);
+ } else {
+ OP_REQUIRES_OK(context,
+ MaybeAutoTune(context, model_shapes, input_mode, input,
+ input_h, input_c, params, output, output_h,
+ output_c, output_algo_config));
+ }
+
Status launch_status;
{
mutex_lock l(mu_);
RnnDescriptor* rnn_desc_ptr = nullptr;
OP_REQUIRES_OK(
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
- algo_config, &rnn_state_cache_,
- &rnn_desc_ptr));
+ *output_algo_config,
+ &rnn_state_cache_, &rnn_desc_ptr));
launch_status = DoForward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, is_training_, output, output_h, output_c,
@@ -1096,6 +1221,25 @@ class CudnnRNNForwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, launch_status);
}
+ protected:
+ virtual Status MaybeAutoTune(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode,
+ const Tensor* input, const Tensor* input_h,
+ const Tensor* input_c, const Tensor* params,
+ Tensor* output, Tensor* output_h,
+ Tensor* output_c,
+ AlgorithmConfig* best_algo_config) {
+ CHECK_NE(best_algo_config, nullptr);
+ *best_algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
+ bool is_training() const { return is_training_; }
+ bool is_debug_mode_;
+ bool debug_use_tensor_ops_;
+ int64 debug_cudnn_rnn_algo_;
+
private:
Status AllocateOutputs(OpKernelContext* context,
const CudnnRnnModelShapes& model_shapes,
@@ -1137,12 +1281,197 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
+template <typename T>
+class CudnnRNNForwardOpV2<GPUDevice, T>
+ : public CudnnRNNForwardOp<GPUDevice, T> {
+ private:
+ using CudnnRNNForwardOp<GPUDevice, T>::is_training;
+ using CudnnRNNKernelCommon::CreateRnnDescriptor;
+ using CudnnRNNKernelCommon::dropout;
+ using CudnnRNNKernelCommon::HasInputC;
+ using CudnnRNNKernelCommon::model_types;
+
+ public:
+ explicit CudnnRNNForwardOpV2(OpKernelConstruction* context)
+ : CudnnRNNForwardOp<GPUDevice, T>(context) {}
+
+ void Compute(OpKernelContext* context) override {
+ AlgorithmConfig best_algo_config;
+ CudnnRNNForwardOp<GPUDevice, T>::ComputeAndReturnAlgorithm(
+ context, &best_algo_config);
+ if (!context->status().ok()) {
+ return;
+ }
+
+ Tensor* output_host_reserved = nullptr;
+ // output_host_reserved stores opaque info used for backprop when running
+ // in training mode. At present, it includes a serialization of the best
+ // AlgorithmDesc picked during rnn forward pass autotune.
+ // int8 algorithm_id
+ // int8 use_tensor_op
+ // If autotune is not enabled, the algorithm_id is
+ // stream_executor::dnn::kDefaultAlgorithm and use_tensor_op is false. If
+ // running in inference mode, the output_host_reserved is currently not
+ // populated.
+ if (is_training()) {
+ OP_REQUIRES_OK(context, context->allocate_output(4, TensorShape({2}),
+ &output_host_reserved));
+ auto output_host_reserved_int8 = output_host_reserved->vec<int8>();
+ output_host_reserved_int8(0) = best_algo_config.algorithm().algo_id();
+ output_host_reserved_int8(1) =
+ best_algo_config.algorithm().tensor_ops_enabled();
+ } else {
+ OP_REQUIRES_OK(context,
+ context->allocate_output(4, {}, &output_host_reserved));
+ }
+ }
+
+ protected:
+ Status MaybeAutoTune(OpKernelContext* context,
+ const CudnnRnnModelShapes& model_shapes,
+ const RnnInputMode& input_mode, const Tensor* input,
+ const Tensor* input_h, const Tensor* input_c,
+ const Tensor* params, Tensor* output, Tensor* output_h,
+ Tensor* output_c,
+ AlgorithmConfig* algo_config) override {
+ CHECK_NE(algo_config, nullptr);
+ if (!CudnnRnnUseAutotune() || this->is_debug_mode_) {
+ *algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
+ std::vector<AlgorithmDesc> algorithms;
+ auto* stream = context->op_device_context()->stream();
+ CHECK(stream->parent()->GetRnnAlgorithms(&algorithms));
+ if (algorithms.empty()) {
+ LOG(WARNING) << "No Rnn algorithm found";
+ return Status::OK();
+ }
+
+ const auto& modeltypes = model_types();
+ CudnnRnnParameters rnn_params(
+ model_shapes.num_layers, model_shapes.input_size,
+ model_shapes.num_units, model_shapes.seq_length,
+ model_shapes.batch_size, model_shapes.dir_count,
+ /*has_dropout=*/std::abs(dropout()) > 1e-8, is_training(),
+ modeltypes.rnn_mode, modeltypes.rnn_input_mode, input->dtype());
+
+ if (AutoTuneRnnConfigMap::GetInstance()->Find(rnn_params, algo_config)) {
+ return Status::OK();
+ }
+
+ // Create temp tensors when profiling backprop pass.
+ auto data_type = input->dtype();
+ Tensor output_backprop;
+ Tensor output_h_backprop;
+ Tensor output_c_backprop;
+ Tensor input_backprop;
+ Tensor input_h_backprop;
+ Tensor input_c_backprop;
+ Tensor params_backprop;
+ if (is_training()) {
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.output_shape, &output_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &output_h_backprop));
+
+ TF_RETURN_IF_ERROR(
+ context->allocate_temp(data_type, params->shape(), &params_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.input_shape, &input_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &input_h_backprop));
+ if (HasInputC()) {
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &output_c_backprop));
+ TF_RETURN_IF_ERROR(context->allocate_temp(
+ data_type, model_shapes.hidden_state_shape, &input_c_backprop));
+ }
+ }
+ ProfileResult best_result;
+ for (auto& algo : algorithms) {
+ Status status;
+ ProfileResult final_profile_result;
+
+ ProfileResult fwd_profile_result;
+ ProfileResult bak_profile_result;
+
+ // RnnDescriptor is algorithm-dependent, thus not reusable.
+ std::unique_ptr<RnnDescriptor> rnn_desc;
+ // Use a temp scratch allocator for the random num generator.
+ CudnnRnnAllocatorInTemp<uint8> dropout_state_allocator(context);
+ if (!this->template CreateRnnDescriptor<T>(
+ context, model_shapes, input_mode, AlgorithmConfig(algo),
+ &dropout_state_allocator, &rnn_desc)
+ .ok()) {
+ continue;
+ }
+
+ // Again use temp scratch allocator during profiling.
+ CudnnRnnAllocatorInTemp<T> reserve_space_allocator(context);
+ CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
+ status = DoForward<T>(
+ context, *rnn_desc.get(), model_types(), model_shapes, input, input_h,
+ input_c, params, is_training(), output, output_h, output_c,
+ &reserve_space_allocator, &workspace_allocator, &fwd_profile_result);
+ if (!status.ok()) {
+ continue;
+ }
+
+ if (is_training()) {
+ // Get reserve space from the forward pass.
+ Tensor reserve_space = reserve_space_allocator.get_allocated_tensor(0);
+ status = DoBackward<T>(
+ context, *rnn_desc.get(), model_types(), model_shapes, input,
+ input_h, input_c, params, output, output_h, output_c,
+ &output_backprop, &output_h_backprop, &output_c_backprop,
+ &reserve_space, &input_backprop, &input_h_backprop,
+ &input_c_backprop, &params_backprop, &workspace_allocator,
+ &bak_profile_result);
+ if (!status.ok()) {
+ continue;
+ }
+ final_profile_result.set_elapsed_time_in_ms(
+ fwd_profile_result.elapsed_time_in_ms() +
+ bak_profile_result.elapsed_time_in_ms());
+ } else {
+ final_profile_result = fwd_profile_result;
+ }
+
+ auto total_time = final_profile_result.elapsed_time_in_ms();
+ VLOG(1) << "Profile Cudnn RNN algo " << algo.algo_id()
+ << " run time: " << total_time << " ms";
+ if (total_time < best_result.elapsed_time_in_ms()) {
+ best_result.set_elapsed_time_in_ms(total_time);
+ best_result.set_algorithm(algo);
+ }
+ }
+
+ if (!best_result.is_valid()) {
+ return Status(error::Code::INTERNAL, "No algorithm worked!");
+ }
+ algo_config->set_algorithm(best_result.algorithm());
+ AutoTuneRnnConfigMap::GetInstance()->Insert(rnn_params, *algo_config);
+ return Status::OK();
+ }
+};
+
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("CudnnRNNV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("host_reserved") \
+ .TypeConstraint<T>("T"), \
+ CudnnRNNForwardOpV2<GPUDevice, T>);
+
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
// Run the backward operation of the RNN model.
template <typename T>
class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
public:
- typedef GPUDevice Device;
-
explicit CudnnRNNBackwardOp(OpKernelConstruction* context)
: CudnnRNNKernelCommon(context) {}
@@ -1185,15 +1514,16 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
// Creates a memory callback for the workspace. The memory lives to the end
// of this kernel calls.
CudnnRnnAllocatorInTemp<uint8> workspace_allocator(context);
- const AlgorithmConfig default_algo_config;
+ AlgorithmConfig algo_config;
+ OP_REQUIRES_OK(context, GetAlgorithm(context, &algo_config));
Status launch_status;
{
mutex_lock l(mu_);
RnnDescriptor* rnn_desc_ptr = nullptr;
OP_REQUIRES_OK(
context, GetCachedRnnDescriptor<T>(context, model_shapes, input_mode,
- default_algo_config,
- &rnn_state_cache_, &rnn_desc_ptr));
+ algo_config, &rnn_state_cache_,
+ &rnn_desc_ptr));
launch_status = DoBackward<T>(
context, *rnn_desc_ptr, model_types(), model_shapes, input, input_h,
input_c, params, output, output_h, output_c, output_backprop,
@@ -1204,6 +1534,14 @@ class CudnnRNNBackwardOp<GPUDevice, T> : public CudnnRNNKernelCommon {
OP_REQUIRES_OK(context, launch_status);
}
+ protected:
+ virtual Status GetAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* algo_config) {
+ CHECK_NE(algo_config, nullptr);
+ *algo_config = AlgorithmConfig();
+ return Status::OK();
+ }
+
private:
mutex mu_;
RnnStateCache rnn_state_cache_ GUARDED_BY(mu_);
@@ -1302,6 +1640,39 @@ TF_CALL_float(REGISTER_GPU);
TF_CALL_double(REGISTER_GPU);
#undef REGISTER_GPU
+template <typename T>
+class CudnnRNNBackwardOpV2<GPUDevice, T>
+ : public CudnnRNNBackwardOp<GPUDevice, T> {
+ public:
+ explicit CudnnRNNBackwardOpV2(OpKernelConstruction* context)
+ : CudnnRNNBackwardOp<GPUDevice, T>(context) {}
+
+ protected:
+ Status GetAlgorithm(OpKernelContext* context,
+ AlgorithmConfig* algo_config) override {
+ CHECK_NE(algo_config, nullptr);
+ const Tensor* host_reserved = nullptr;
+ TF_RETURN_IF_ERROR(context->input("host_reserved", &host_reserved));
+
+ auto host_reserved_int8 = host_reserved->vec<int8>();
+ const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
+ algo_config->set_algorithm(algo_desc);
+ return Status::OK();
+ }
+};
+
+#define REGISTER_GPU(T) \
+ REGISTER_KERNEL_BUILDER(Name("CudnnRNNBackpropV2") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("host_reserved") \
+ .TypeConstraint<T>("T"), \
+ CudnnRNNBackwardOpV2<GPUDevice, T>);
+
+TF_CALL_half(REGISTER_GPU);
+TF_CALL_float(REGISTER_GPU);
+TF_CALL_double(REGISTER_GPU);
+#undef REGISTER_GPU
+
// TODO(zhengxq): Add the conversion of Cudnn RNN Params from and to
// its canonical form.
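Stepping back from the hunks above: the V2 forward kernel serializes the autotuned algorithm choice into the small host_reserved output (two int8 values), and the V2 backward kernel reconstructs an AlgorithmDesc from it so both passes use the same cuDNN RNN algorithm. A condensed restatement of that handshake, using the names from the hunks:

// Forward (CudnnRNNForwardOpV2): persist the winning algorithm.
auto reserved = output_host_reserved->vec<int8>();
reserved(0) = best_algo_config.algorithm().algo_id();             // algorithm id
reserved(1) = best_algo_config.algorithm().tensor_ops_enabled();  // tensor-op flag

// Backward (CudnnRNNBackwardOpV2): rebuild the descriptor from host_reserved.
auto host_reserved_int8 = host_reserved->vec<int8>();
const AlgorithmDesc algo_desc(host_reserved_int8(0), host_reserved_int8(1));
algo_config->set_algorithm(algo_desc);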
diff --git a/tensorflow/core/kernels/decode_raw_op.cc b/tensorflow/core/kernels/decode_raw_op.cc
index bacacb94ae..eaef5a6097 100644
--- a/tensorflow/core/kernels/decode_raw_op.cc
+++ b/tensorflow/core/kernels/decode_raw_op.cc
@@ -21,7 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
index 94989089ec..0abd64030f 100644
--- a/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op_gpu.cu.cc
@@ -1708,8 +1708,7 @@ void LaunchDepthwiseConvBackpropFilterOp<GpuDevice, T>::operator()(
// Initialize the results to 0.
int num_filter_backprop =
args.filter_rows * args.filter_cols * args.out_depth;
- perftools::gputools::DeviceMemoryBase filter_bp_ptr(filter_backprop,
- num_filter_backprop);
+ se::DeviceMemoryBase filter_bp_ptr(filter_backprop, num_filter_backprop);
stream->ThenMemset32(&filter_bp_ptr, 0, num_filter_backprop * sizeof(T));
if (args.filter_rows == 3 && args.filter_cols == 3) {
diff --git a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
index 9dfeccff0e..862a97723f 100644
--- a/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/dynamic_partition_op_gpu.cu.cc
@@ -285,8 +285,8 @@ class DynamicPartitionOpGPU : public AsyncOpKernel {
c->allocate_temp(partition_count.dtype(), partition_count.shape(),
&cpu_tensor, alloc_attr),
done);
- perftools::gputools::DeviceMemoryBase wrapped(
- partition_count.flat<int32>().data(), num_partitions_ * sizeof(int32));
+ se::DeviceMemoryBase wrapped(partition_count.flat<int32>().data(),
+ num_partitions_ * sizeof(int32));
const bool status =
stream
->ThenMemcpy(cpu_tensor.flat<int32>().data(), wrapped,
diff --git a/tensorflow/core/kernels/fft_ops.cc b/tensorflow/core/kernels/fft_ops.cc
index ab5af8caad..661bf5fc5f 100644
--- a/tensorflow/core/kernels/fft_ops.cc
+++ b/tensorflow/core/kernels/fft_ops.cc
@@ -277,20 +277,19 @@ REGISTER_KERNEL_BUILDER(Name("IRFFT3D").Device(DEVICE_CPU).Label(FFT_LABEL),
#undef FFT_LABEL
#if GOOGLE_CUDA
-namespace gpu = ::perftools::gputools;
namespace {
template <typename T>
-gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
- gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
- gpu::DeviceMemory<T> typed(wrapped);
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
template <typename T>
-gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
- gpu::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
- gpu::DeviceMemory<T> typed(wrapped);
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
@@ -299,19 +298,19 @@ gpu::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
// the kernel finishes.
// TODO(yangzihao): Refactor redundant code in subclasses of ScratchAllocator
// into base class.
-class CufftScratchAllocator : public gpu::ScratchAllocator {
+class CufftScratchAllocator : public se::ScratchAllocator {
public:
~CufftScratchAllocator() override {}
CufftScratchAllocator(int64 memory_limit, OpKernelContext* context)
: memory_limit_(memory_limit), total_byte_size_(0), context_(context) {}
- int64 GetMemoryLimitInBytes(gpu::Stream* stream) override {
+ int64 GetMemoryLimitInBytes(se::Stream* stream) override {
return memory_limit_;
}
- gpu::port::StatusOr<gpu::DeviceMemory<uint8>> AllocateBytes(
- gpu::Stream* stream, int64 byte_size) override {
+ se::port::StatusOr<se::DeviceMemory<uint8>> AllocateBytes(
+ se::Stream* stream, int64 byte_size) override {
Tensor temporary_memory;
if (byte_size > memory_limit_) {
- return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>();
+ return se::port::StatusOr<se::DeviceMemory<uint8>>();
}
AllocationAttributes allocation_attr;
allocation_attr.no_retry_on_failure = true;
@@ -319,13 +318,13 @@ class CufftScratchAllocator : public gpu::ScratchAllocator {
DT_UINT8, TensorShape({byte_size}), &temporary_memory,
AllocatorAttributes(), allocation_attr));
if (!allocation_status.ok()) {
- return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>();
+ return se::port::StatusOr<se::DeviceMemory<uint8>>();
}
// Hold the reference of the allocated tensors until the end of the
// allocator.
allocated_tensors_.push_back(temporary_memory);
total_byte_size_ += byte_size;
- return gpu::port::StatusOr<gpu::DeviceMemory<uint8>>(
+ return se::port::StatusOr<se::DeviceMemory<uint8>>(
AsDeviceMemory(temporary_memory.flat<uint8>().data(),
temporary_memory.flat<uint8>().size()));
}
@@ -394,9 +393,9 @@ class FFTGPUBase : public FFTBase {
constexpr bool kInPlaceFft = false;
const auto kFftType =
- IsReal() ? (IsForward() ? gpu::fft::Type::kR2C : gpu::fft::Type::kC2R)
- : (IsForward() ? gpu::fft::Type::kC2CForward
- : gpu::fft::Type::kC2CInverse);
+ IsReal() ? (IsForward() ? se::fft::Type::kR2C : se::fft::Type::kC2R)
+ : (IsForward() ? se::fft::Type::kC2CForward
+ : se::fft::Type::kC2CInverse);
CufftScratchAllocator scratch_allocator(CufftScratchSize, ctx);
auto plan =
diff --git a/tensorflow/core/kernels/fused_batch_norm_op.cc b/tensorflow/core/kernels/fused_batch_norm_op.cc
index 9b4dca8511..f99dd643f7 100644
--- a/tensorflow/core/kernels/fused_batch_norm_op.cc
+++ b/tensorflow/core/kernels/fused_batch_norm_op.cc
@@ -251,7 +251,7 @@ struct FusedBatchNorm<GPUDevice, T, U> {
Tensor x_maybe_transformed = x;
Tensor x_transformed;
Tensor y_transformed;
- perftools::gputools::DeviceMemory<T> y_ptr;
+ se::DeviceMemory<T> y_ptr;
if (tensor_format == FORMAT_NCHW) {
y_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*y);
@@ -279,19 +279,19 @@ struct FusedBatchNorm<GPUDevice, T, U> {
return;
}
- perftools::gputools::dnn::BatchDescriptor x_desc;
+ se::dnn::BatchDescriptor x_desc;
x_desc.set_count(batch_size)
.set_feature_map_count(channels)
.set_height(height)
.set_width(width)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
+ se::dnn::BatchDescriptor scale_offset_desc;
scale_offset_desc.set_count(1)
.set_feature_map_count(channels)
.set_height(1)
.set_width(1)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
auto x_ptr = StreamExecutorUtil::AsDeviceMemory<T>(x_maybe_transformed);
auto scale_ptr = StreamExecutorUtil::AsDeviceMemory<U>(scale);
@@ -308,7 +308,7 @@ struct FusedBatchNorm<GPUDevice, T, U> {
StreamExecutorUtil::AsDeviceMemory<U>(*saved_inv_var);
GPUDevice d = context->eigen_device<GPUDevice>();
- using perftools::gputools::DeviceMemory;
+ using se::DeviceMemory;
Tensor inv_var;
OP_REQUIRES_OK(
context, context->allocate_temp(DataTypeToEnum<U>::value,
@@ -390,7 +390,7 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
// Outputs
Tensor x_backprop_transformed;
- perftools::gputools::DeviceMemory<T> x_backprop_ptr;
+ se::DeviceMemory<T> x_backprop_ptr;
if (tensor_format == FORMAT_NCHW) {
x_backprop_ptr = StreamExecutorUtil::AsDeviceMemory<T>(*x_backprop);
@@ -433,19 +433,19 @@ struct FusedBatchNormGrad<GPUDevice, T, U> {
return;
}
- perftools::gputools::dnn::BatchDescriptor x_desc;
+ se::dnn::BatchDescriptor x_desc;
x_desc.set_count(batch_size)
.set_feature_map_count(channels)
.set_height(height)
.set_width(width)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor scale_offset_desc;
+ se::dnn::BatchDescriptor scale_offset_desc;
scale_offset_desc.set_count(1)
.set_feature_map_count(channels)
.set_height(1)
.set_width(1)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
auto y_backprop_ptr =
StreamExecutorUtil::AsDeviceMemory<T>(y_backprop_maybe_transformed);
diff --git a/tensorflow/core/kernels/fuzzing/BUILD b/tensorflow/core/kernels/fuzzing/BUILD
index aab4b009b5..8bfa40304e 100644
--- a/tensorflow/core/kernels/fuzzing/BUILD
+++ b/tensorflow/core/kernels/fuzzing/BUILD
@@ -37,6 +37,8 @@ tf_ops_fuzz_target_lib("decode_png")
tf_ops_fuzz_target_lib("decode_jpeg")
+tf_ops_fuzz_target_lib("decode_wav")
+
tf_ops_fuzz_target_lib("example_proto_fast_parsing")
tf_ops_fuzz_target_lib("parse_tensor_op")
diff --git a/tensorflow/core/kernels/fuzzing/decode_wav_fuzz.cc b/tensorflow/core/kernels/fuzzing/decode_wav_fuzz.cc
new file mode 100644
index 0000000000..33a11d8e13
--- /dev/null
+++ b/tensorflow/core/kernels/fuzzing/decode_wav_fuzz.cc
@@ -0,0 +1,30 @@
+/* Copyright 2018 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/cc/ops/audio_ops.h"
+#include "tensorflow/cc/ops/standard_ops.h"
+#include "tensorflow/core/kernels/fuzzing/fuzz_session.h"
+
+namespace tensorflow {
+namespace fuzzing {
+
+class FuzzDecodeWav : public FuzzStringInputOp {
+ SINGLE_INPUT_OP_BUILDER(DT_STRING, DecodeWav);
+};
+
+STANDARD_TF_FUZZ_FUNCTION(FuzzDecodeWav);
+
+} // namespace fuzzing
+} // namespace tensorflow
diff --git a/tensorflow/core/kernels/gpu_utils.h b/tensorflow/core/kernels/gpu_utils.h
index ffc733e6bb..2f64619afc 100644
--- a/tensorflow/core/kernels/gpu_utils.h
+++ b/tensorflow/core/kernels/gpu_utils.h
@@ -29,11 +29,9 @@ limitations under the License.
namespace tensorflow {
template <typename T>
-inline perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
- uint64 size) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
- size * sizeof(T));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
+inline se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
diff --git a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
index 66d24d171d..3810cbe5b5 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_control_wrapper.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
diff --git a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
index 5fb6b9247f..d53977703e 100644
--- a/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
+++ b/tensorflow/core/kernels/hexagon/hexagon_graph_execution_test.cc
@@ -30,6 +30,7 @@ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
#include <memory>
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
diff --git a/tensorflow/core/kernels/i_remote_fused_graph_executor.h b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
index eb6b64da58..6072412689 100644
--- a/tensorflow/core/kernels/i_remote_fused_graph_executor.h
+++ b/tensorflow/core/kernels/i_remote_fused_graph_executor.h
@@ -16,13 +16,15 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
#define TENSORFLOW_CORE_KERNELS_I_REMOTE_GRAPH_EXECUTOR_H_
-#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
+class GraphDef;
+class RemoteFusedGraphExecuteInfo;
+
class IRemoteFusedGraphExecutor {
public:
using TensorAllocatorFunc = std::function<Tensor*(const TensorShape& shape)>;
diff --git a/tensorflow/core/kernels/initializable_lookup_table.h b/tensorflow/core/kernels/initializable_lookup_table.h
index edb779540f..990cbceac2 100644
--- a/tensorflow/core/kernels/initializable_lookup_table.h
+++ b/tensorflow/core/kernels/initializable_lookup_table.h
@@ -51,7 +51,7 @@ class InitializableLookupTable : public LookupInterface {
"Insert not supported by InitializableLookupTable implementations");
}
- Status ExportValues(OpKernelContext* context) final {
+ Status ExportValues(OpKernelContext* context) {
return errors::Unimplemented(
"ExportValues not supported by InitializableLookupTable "
"implementations");
diff --git a/tensorflow/core/kernels/lookup_table_op.h b/tensorflow/core/kernels/lookup_table_op.h
index 29a0cc91fe..3977f16299 100644
--- a/tensorflow/core/kernels/lookup_table_op.h
+++ b/tensorflow/core/kernels/lookup_table_op.h
@@ -177,6 +177,30 @@ class HashTable : public InitializableLookupTable {
return table_ ? table_->size() : 0;
}
+ Status ExportValues(OpKernelContext* context) override {
+ if (!is_initialized_) {
+ return errors::Aborted("HashTable is not initialized.");
+ }
+
+ const int64 size = table_->size();
+
+ Tensor* keys;
+ Tensor* values;
+ TF_RETURN_IF_ERROR(
+ context->allocate_output("keys", TensorShape({size}), &keys));
+ TF_RETURN_IF_ERROR(
+ context->allocate_output("values", TensorShape({size}), &values));
+
+ auto keys_data = keys->flat<K>();
+ auto values_data = values->flat<V>();
+ int64 i = 0;
+ for (auto it = table_->begin(); it != table_->end(); ++it, ++i) {
+ keys_data(i) = it->first;
+ values_data(i) = it->second;
+ }
+ return Status::OK();
+ }
+
DataType key_dtype() const override { return DataTypeToEnum<K>::v(); }
DataType value_dtype() const override { return DataTypeToEnum<V>::v(); }
diff --git a/tensorflow/core/kernels/lrn_op.cc b/tensorflow/core/kernels/lrn_op.cc
index c3a59c9576..b4252eb044 100644
--- a/tensorflow/core/kernels/lrn_op.cc
+++ b/tensorflow/core/kernels/lrn_op.cc
@@ -187,14 +187,14 @@ struct LaunchLRN<GPUDevice, T> {
const int cols = static_cast<int>(in.dim_size(2));
const int depth = static_cast<int>(in.dim_size(3));
- perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+ se::dnn::BatchDescriptor dimensions_desc;
dimensions_desc.set_count(batch)
.set_height(rows)
.set_width(cols)
.set_feature_map_count(depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+ .set_layout(se::dnn::DataLayout::kBatchYXDepth);
- perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+ se::dnn::NormalizeDescriptor normalize_desc;
normalize_desc.set_bias(bias_)
.set_range(depth_radius_)
.set_alpha(alpha_)
@@ -404,14 +404,14 @@ struct LaunchLRNGrad<GPUDevice, T> {
const int64 cols = in_grads.dim_size(2);
const int64 depth = in_grads.dim_size(3);
- perftools::gputools::dnn::BatchDescriptor dimensions_desc;
+ se::dnn::BatchDescriptor dimensions_desc;
dimensions_desc.set_count(batch)
.set_height(rows)
.set_width(cols)
.set_feature_map_count(depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+ .set_layout(se::dnn::DataLayout::kBatchYXDepth);
- perftools::gputools::dnn::NormalizeDescriptor normalize_desc;
+ se::dnn::NormalizeDescriptor normalize_desc;
normalize_desc.set_bias(bias_)
.set_range(depth_radius_)
.set_alpha(alpha_)
diff --git a/tensorflow/core/kernels/matmul_op.cc b/tensorflow/core/kernels/matmul_op.cc
index f499ce6519..3664f95c3b 100644
--- a/tensorflow/core/kernels/matmul_op.cc
+++ b/tensorflow/core/kernels/matmul_op.cc
@@ -112,7 +112,7 @@ bool ExplicitVectorMatrixOptimization<Eigen::half>(
template <typename Device, typename T>
struct LaunchMatMulBase {
#if GOOGLE_CUDA
- typedef perftools::gputools::blas::AlgorithmType AlgorithmType;
+ typedef se::blas::AlgorithmType AlgorithmType;
#else
typedef int64 AlgorithmType;
#endif // GOOGLE_CUDA
@@ -160,15 +160,12 @@ namespace {
template <typename T>
struct LaunchBlasGemv {
- static void Compute(
- OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
- uint64 m, uint64 n, const perftools::gputools::DeviceMemory<T>& a,
- const perftools::gputools::DeviceMemory<T>& b,
- perftools::gputools::DeviceMemory<T>* c,
- perftools::gputools::blas::ProfileResult* output_profile) {
- const auto blas_trans =
- trans ? perftools::gputools::blas::Transpose::kTranspose
- : perftools::gputools::blas::Transpose::kNoTranspose;
+ static void Compute(OpKernelContext* ctx, se::Stream* stream, bool trans,
+ uint64 m, uint64 n, const se::DeviceMemory<T>& a,
+ const se::DeviceMemory<T>& b, se::DeviceMemory<T>* c,
+ se::blas::ProfileResult* output_profile) {
+ const auto blas_trans = trans ? se::blas::Transpose::kTranspose
+ : se::blas::Transpose::kNoTranspose;
if (output_profile == nullptr) {
bool blas_launch_status =
stream
@@ -198,11 +195,10 @@ struct LaunchBlasGemv {
template <>
void LaunchBlasGemv<Eigen::half>::Compute(
- OpKernelContext* ctx, perftools::gputools::Stream* stream, bool trans,
- uint64 m, uint64 n, const perftools::gputools::DeviceMemory<Eigen::half>& a,
- const perftools::gputools::DeviceMemory<Eigen::half>& b,
- perftools::gputools::DeviceMemory<Eigen::half>* c,
- perftools::gputools::blas::ProfileResult* output_profile) {
+ OpKernelContext* ctx, se::Stream* stream, bool trans, uint64 m, uint64 n,
+ const se::DeviceMemory<Eigen::half>& a,
+ const se::DeviceMemory<Eigen::half>& b, se::DeviceMemory<Eigen::half>* c,
+ se::blas::ProfileResult* output_profile) {
ctx->SetStatus(errors::Internal(
"Blas GEMV launch failed: GEMV is not implemented for float16."));
}
@@ -219,10 +215,9 @@ bool ShouldUseGemv(uint64 n) {
} // namespace
-bool GetCublasAutotuneComputationType(
- const DataType& dtype,
- perftools::gputools::blas::ComputationType* compute_type) {
- using perftools::gputools::blas::ComputationType;
+bool GetCublasAutotuneComputationType(const DataType& dtype,
+ se::blas::ComputationType* compute_type) {
+ using se::blas::ComputationType;
bool use_f32_for_f16_computation = MatmulDoFP32ComputationFP16Input();
switch (dtype) {
case DT_HALF:
@@ -250,7 +245,7 @@ struct MatmulAutoTuneGroup {
static string name() { return "Matmul"; }
};
typedef AutoTuneSingleton<MatmulAutoTuneGroup, MatmulParameters,
- perftools::gputools::blas::AlgorithmConfig>
+ se::blas::AlgorithmConfig>
AutoTuneMatmul;
template <typename T>
@@ -259,14 +254,14 @@ struct LaunchMatMul<GPUDevice, T, true /* USE_CUBLAS */> {
OpKernelContext* ctx, const Tensor& a, const Tensor& b,
const Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1>& dim_pair,
std::vector<int64>* algorithms, bool use_autotune, Tensor* out) {
- using perftools::gputools::blas::AlgorithmConfig;
- using perftools::gputools::blas::ComputationType;
- using perftools::gputools::blas::kDefaultAlgorithm;
- using perftools::gputools::blas::kDefaultBlasGemm;
- using perftools::gputools::blas::kDefaultBlasGemv;
- using perftools::gputools::blas::kNoAlgorithm;
- using perftools::gputools::blas::ProfileResult;
- using perftools::gputools::blas::Transpose;
+ using se::blas::AlgorithmConfig;
+ using se::blas::ComputationType;
+ using se::blas::kDefaultAlgorithm;
+ using se::blas::kDefaultBlasGemm;
+ using se::blas::kDefaultBlasGemv;
+ using se::blas::kNoAlgorithm;
+ using se::blas::ProfileResult;
+ using se::blas::Transpose;
Transpose trans[] = {Transpose::kNoTranspose, Transpose::kTranspose};
const uint64 m = a.dim_size(1 - dim_pair[0].first);
const uint64 k = a.dim_size(dim_pair[0].first);
diff --git a/tensorflow/core/kernels/matrix_triangular_solve_op.cc b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
index 6f7e6a7496..5de0d1118a 100644
--- a/tensorflow/core/kernels/matrix_triangular_solve_op.cc
+++ b/tensorflow/core/kernels/matrix_triangular_solve_op.cc
@@ -34,11 +34,9 @@ namespace tensorflow {
#if GOOGLE_CUDA
namespace {
template <typename Scalar>
-perftools::gputools::DeviceMemory<Scalar> AsDeviceMemory(
- const Scalar* cuda_memory) {
- perftools::gputools::DeviceMemoryBase wrapped(
- const_cast<Scalar*>(cuda_memory));
- perftools::gputools::DeviceMemory<Scalar> typed(wrapped);
+se::DeviceMemory<Scalar> AsDeviceMemory(const Scalar* cuda_memory) {
+ se::DeviceMemoryBase wrapped(const_cast<Scalar*>(cuda_memory));
+ se::DeviceMemory<Scalar> typed(wrapped);
return typed;
}
} // namespace
@@ -204,18 +202,17 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
// output' = rhs' / matrix' (' stands for transpose)
// Upper/lower needs to be swapped for this.
- perftools::gputools::blas::UpperLower upper_lower_matrix;
- perftools::gputools::blas::Transpose transpose_matrix;
+ se::blas::UpperLower upper_lower_matrix;
+ se::blas::Transpose transpose_matrix;
if (lower_) {
- upper_lower_matrix = perftools::gputools::blas::UpperLower::kUpper;
+ upper_lower_matrix = se::blas::UpperLower::kUpper;
} else {
- upper_lower_matrix = perftools::gputools::blas::UpperLower::kLower;
+ upper_lower_matrix = se::blas::UpperLower::kLower;
}
if (adjoint_) {
- transpose_matrix =
- perftools::gputools::blas::Transpose::kConjugateTranspose;
+ transpose_matrix = se::blas::Transpose::kConjugateTranspose;
} else {
- transpose_matrix = perftools::gputools::blas::Transpose::kNoTranspose;
+ transpose_matrix = se::blas::Transpose::kNoTranspose;
}
uint64 leading_dim_matrix = matrix.cols();
uint64 leading_dim_output = output.cols();
@@ -224,11 +221,11 @@ class MatrixTriangularSolveOpGPU : public LinearAlgebraOp<Scalar> {
bool blas_launch_status =
stream
->ThenBlasTrsm(
- perftools::gputools::blas::Side::kRight /*side*/,
- upper_lower_matrix /*uplo*/, transpose_matrix /*trans*/,
- perftools::gputools::blas::Diagonal::kNonUnit /*diag*/,
- colmajor_rows /*m*/, colmajor_cols /*n*/, Scalar(1.0) /*alpha*/,
- matrix_ptr, leading_dim_matrix /*lda*/, &out_ptr,
+ se::blas::Side::kRight /*side*/, upper_lower_matrix /*uplo*/,
+ transpose_matrix /*trans*/,
+ se::blas::Diagonal::kNonUnit /*diag*/, colmajor_rows /*m*/,
+ colmajor_cols /*n*/, Scalar(1.0) /*alpha*/, matrix_ptr,
+ leading_dim_matrix /*lda*/, &out_ptr,
leading_dim_output /*ldb*/)
.ok();
if (!blas_launch_status) {
diff --git a/tensorflow/core/kernels/maxpooling_op.cc b/tensorflow/core/kernels/maxpooling_op.cc
index aaaf45d3e7..507fc99837 100644
--- a/tensorflow/core/kernels/maxpooling_op.cc
+++ b/tensorflow/core/kernels/maxpooling_op.cc
@@ -404,10 +404,10 @@ class MaxPoolingGradOp<Eigen::GpuDevice, T> : public OpKernel {
"Pooling is not yet supported on the batch dimension."));
if (use_dnn_) {
- DnnPoolingGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kMaximum, ksize,
- stride, padding_, data_format_, &tensor_in, &tensor_out, out_backprop,
- output_shape, propagate_nans_);
+ DnnPoolingGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum,
+ ksize, stride, padding_, data_format_,
+ &tensor_in, &tensor_out, out_backprop,
+ output_shape, propagate_nans_);
} else {
CHECK(data_format_ == FORMAT_NHWC)
<< "Non-Cudnn MaxPoolGrad only supports NHWC format";
@@ -1136,10 +1136,9 @@ class MaxPoolingNoMaskOp<GPUDevice, T> : public OpKernel {
// These is_int8x4 checks avoid linker errors for missing qint8 kernels.
if (!is_int8x4 && use_dnn_ && data_format_ == FORMAT_NCHW) {
- DnnPoolingOp<T>::Compute(context,
- perftools::gputools::dnn::PoolingMode::kMaximum,
- ksize_, stride_, padding_, data_format_,
- tensor_in, out_shape, propagate_nans_);
+ DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize_,
+ stride_, padding_, data_format_, tensor_in,
+ out_shape, propagate_nans_);
} else {
Tensor* output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(0, out_shape, &output));
@@ -1240,9 +1239,8 @@ class MaxPoolingNoMaskV2Op<GPUDevice, T> : public OpKernel {
ShapeFromFormat(data_format_, params.tensor_in_batch, params.out_height,
params.out_width, params.depth);
if (use_dnn_ && data_format_ == FORMAT_NCHW) {
- DnnPoolingOp<T>::Compute(context,
- perftools::gputools::dnn::PoolingMode::kMaximum,
- ksize, stride, padding_, data_format_, tensor_in,
+ DnnPoolingOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, ksize,
+ stride, padding_, data_format_, tensor_in,
out_shape, propagate_nans_);
} else {
CHECK(data_format_ == FORMAT_NHWC)
diff --git a/tensorflow/core/kernels/mkl_input_conversion_op.cc b/tensorflow/core/kernels/mkl_input_conversion_op.cc
index 96d5fbc408..cda1402b03 100644
--- a/tensorflow/core/kernels/mkl_input_conversion_op.cc
+++ b/tensorflow/core/kernels/mkl_input_conversion_op.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
diff --git a/tensorflow/core/kernels/mkl_tfconv_op.h b/tensorflow/core/kernels/mkl_tfconv_op.h
index ddea9e281b..4120f013ac 100644
--- a/tensorflow/core/kernels/mkl_tfconv_op.h
+++ b/tensorflow/core/kernels/mkl_tfconv_op.h
@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/ops_util.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/tensor_format.h"
diff --git a/tensorflow/core/kernels/partitioned_function_ops.cc b/tensorflow/core/kernels/partitioned_function_ops.cc
new file mode 100644
index 0000000000..d66b1ba663
--- /dev/null
+++ b/tensorflow/core/kernels/partitioned_function_ops.cc
@@ -0,0 +1,279 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/common_runtime/placer.h"
+#include "tensorflow/core/common_runtime/rendezvous_mgr.h"
+#include "tensorflow/core/framework/function.h"
+#include "tensorflow/core/framework/graph_to_functiondef.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/graph/graph_constructor.h"
+#include "tensorflow/core/graph/graph_partition.h"
+#include "tensorflow/core/util/reffed_status_callback.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/stream_executor/stream.h"
+#endif // GOOGLE_CUDA
+
+namespace tensorflow {
+typedef FunctionLibraryRuntime::Handle FHandle;
+
+namespace {
+
+// A `PartitionedCallOp` asynchronously executes a function, potentially across
+// multiple devices but within a single process. The kernel places and
+// partitions a given function's underlying graph, and executes each of the
+// partitioned subgraphs as a function.
+//
+// TODO(akshayka): Support distributed execution.
+class PartitionedCallOp : public AsyncOpKernel {
+ public:
+ explicit PartitionedCallOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("f", &func_));
+ }
+
+ ~PartitionedCallOp() override {}
+
+ void ComputeAsync(OpKernelContext* ctx, DoneCallback done) override {
+ FunctionLibraryRuntime* lib = ctx->function_library();
+ OP_REQUIRES_ASYNC(ctx, lib != nullptr,
+ errors::Internal("No function library is provided."),
+ done);
+
+ // The function body's graph is placed and partitioned the first time
+ // `ComputeAsync` is invoked; every subsequent invocation calls each
+ // of the function shards yielded by partitioning.
+ //
+ // The partitioning step yields a set of devices on which to run the
+ // function, and exactly one function shard is created for each device.
+ // Inputs and outputs are pinned to the local device, for simplicity.
+ //
+ // TODO(akshayka): Support re-sharding the function on subsequent calls,
+ // via, e.g., virtual device annotations and a list of device names supplied
+ // through an attribute.
+ //
+ // TODO(akshayka): Lift the constraint pinning inputs and outputs to the
+ // local device.
+ //
+ // TODO(akshayka): Add a fastpath for functions that execute on a single
+ // device.
+ {
+ mutex_lock l(mu_);
+ if (!partitioned_) {
+ // Instantiate the function to obtain its underlying graph, complete
+ // with nodes for arguments and return values.
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ FHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ lib->Instantiate(func_.name(), AttrSlice(&func_.attr()), opts,
+ &handle),
+ done);
+ Graph* graph = lib->GetFunctionBody(handle)->graph;
+
+ // Pin the inputs and outputs to the local device to simplify the
+ // function-dispatching logic.
+ local_device_name_ = lib->device()->name();
+ for (Node* node : graph->op_nodes()) {
+ string node_type = node->type_string();
+ if (node_type == FunctionLibraryDefinition::kArgOp ||
+ node_type == FunctionLibraryDefinition::kRetOp) {
+ node->set_assigned_device_name(local_device_name_);
+ }
+ }
+
+ // Place the graph, i.e., assign a device to every node in it.
+ DeviceSet device_set;
+ for (auto d : lib->device_mgr()->ListDevices()) {
+ device_set.AddDevice(d);
+ }
+ Placer placer(graph, &device_set);
+ OP_REQUIRES_OK_ASYNC(ctx, placer.Run(), done);
+
+ // Partition the graph into subgraphs: exactly one subgraph per device.
+ //
+ // TODO(akshayka): Let devices rewrite their graphs.
+ PartitionOptions partition_options;
+ partition_options.node_to_loc = [](const Node* node) {
+ // TODO(akshayka): To better support the distributed case, first split
+ // the graph by worker (e.g., using the master session's
+ // `SplitByWorker` policy), and then recursively partition the
+ // per-worker shards at the remote worker(s).
+ return node->assigned_device_name();
+ };
+ int64 edge_name_counter = 0;
+ partition_options.new_name =
+ [&edge_name_counter](const string& prefix) {
+ return strings::StrCat(prefix, "/_", ++edge_name_counter);
+ };
+ partition_options.get_incarnation =
+ [&device_set](const string& name) -> int64 {
+ const Device* d = device_set.FindDeviceByName(name);
+ if (d == nullptr) {
+ return PartitionOptions::kIllegalIncarnation;
+ } else {
+ return d->attributes().incarnation();
+ }
+ };
+ partition_options.control_flow_added = false;
+ std::unordered_map<string, GraphDef> partitions;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, Partition(partition_options, graph, &partitions), done);
+
+ VLOG(3) << "Partitioned function '" << func_.name() << "', yielding "
+ << partitions.size() << " shards.";
+
+ // `subgraphs` is a map from devices to their corresponding subgraphs.
+ gtl::FlatMap<string, std::unique_ptr<Graph>> subgraphs;
+ const FunctionLibraryDefinition* flib_def = &graph->flib_def();
+ for (const auto& partition : partitions) {
+ std::unique_ptr<Graph> subgraph(new Graph(flib_def));
+ GraphConstructorOptions opts;
+ opts.allow_internal_ops = true;
+ opts.expect_device_spec = true;
+ const string& device = partition.first;
+ const GraphDef& graph_def = partition.second;
+ OP_REQUIRES_OK_ASYNC(
+ ctx, ConvertGraphDefToGraph(opts, graph_def, subgraph.get()),
+ done);
+ subgraphs.emplace(device, std::move(subgraph));
+ }
+
+ // The FunctionLibraryRuntime's library cannot be mutated from within
+ // an OpKernel, so the functions are instantiated in an overlay library.
+ overlay_lib_.reset(new FunctionLibraryDefinition(
+ *lib->GetFunctionLibraryDefinition()));
+ for (const auto& pair : subgraphs) {
+ const string& target = pair.first;
+ Graph* subgraph = pair.second.get();
+ FunctionDef shard;
+ string unique_name = UniquifyFunctionName(func_.name());
+ OP_REQUIRES_OK_ASYNC(
+ ctx, GraphToFunctionDef(*subgraph, unique_name, &shard), done);
+ OP_REQUIRES_OK_ASYNC(ctx, overlay_lib_->AddFunctionDef(shard), done);
+ FunctionLibraryRuntime::InstantiateOptions opts;
+ opts.target = target;
+ opts.overlay_lib = overlay_lib_.get();
+ FHandle handle;
+ OP_REQUIRES_OK_ASYNC(
+ ctx,
+ lib->Instantiate(unique_name, AttrSlice(&shard.attr()), opts,
+ &handle),
+ done);
+ device_handle_map_.emplace(target, handle);
+ }
+ partitioned_ = true;
+ }
+ }
+
+ FunctionLibraryRuntime::Options opts;
+ opts.step_id = ctx->step_id();
+ opts.step_container = ctx->step_container();
+ opts.cancellation_manager = ctx->cancellation_manager();
+ opts.stats_collector = ctx->stats_collector();
+ // TODO(akshayka): Consider selecting a runner on a per-device basis, i.e.,
+ // using device-specific threadpools when available.
+ opts.runner = ctx->runner();
+ opts.source_device = local_device_name_;
+ // TODO(akshayka): Accommodate the multiple-worker scenario by adding the
+ // constructed rendezvous to a rendezvous manager.
+ Rendezvous* rendez = new IntraProcessRendezvous(lib->device_mgr());
+ opts.rendezvous = rendez;
+
+ OpInputList arguments;
+ OP_REQUIRES_OK_ASYNC(ctx, ctx->input_list("args", &arguments), done);
+ // Dummy args vector for the remote shards, which do not have inputs.
+ std::vector<Tensor> dummy_args;
+
+ StatusCallback callback = std::bind(
+ [](Rendezvous* rendez, DoneCallback& done, const Status& status) {
+ rendez->Unref();
+ done();
+ },
+ rendez, std::move(done), std::placeholders::_1);
+ auto* refcounted_done = new ReffedStatusCallback(std::move(callback));
+ for (int i = 1; i < device_handle_map_.size(); ++i) {
+ refcounted_done->Ref();
+ }
+
+ for (const auto& pair : device_handle_map_) {
+ const string& target_device = pair.first;
+ FHandle handle = pair.second;
+ VLOG(3) << "Running function shard on device " << target_device;
+ if (target_device == local_device_name_) {
+ opts.remote_execution = false;
+ std::vector<Tensor> args;
+ args.reserve(arguments.size());
+ for (const Tensor& argument : arguments) {
+ args.push_back(argument);
+ }
+ auto* rets = new std::vector<Tensor>;
+ lib->Run(opts, handle, args, rets,
+ [rets, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ } else {
+ for (int i = 0; i < rets->size(); ++i) {
+ ctx->set_output(i, (*rets)[i]);
+ }
+ }
+ delete rets;
+ refcounted_done->Unref();
+ });
+ } else {
+ opts.remote_execution = true;
+ std::vector<Tensor>* dummy_rets = new std::vector<Tensor>;
+ lib->Run(opts, handle, dummy_args, dummy_rets,
+ [dummy_rets, refcounted_done, ctx](const Status& status) {
+ if (!status.ok()) {
+ ctx->SetStatus(status);
+ }
+ delete dummy_rets;
+ refcounted_done->Unref();
+ });
+ }
+ }
+ }
+
+ private:
+ string UniquifyFunctionName(const string& name) {
+ for (;; ++suffix_) {
+ const string candidate = strings::StrCat(name, "_", suffix_);
+ if (overlay_lib_->Find(candidate) == nullptr) {
+ return candidate;
+ }
+ }
+ }
+
+ // `func_` encapsulates the original, unsharded function.
+ NameAttrList func_;
+ string local_device_name_;
+ // Function shards are added to `overlay_lib_`.
+ std::unique_ptr<FunctionLibraryDefinition> overlay_lib_;
+ // A map from device names to handles of function shards.
+ gtl::FlatMap<string, FHandle> device_handle_map_;
+
+ mutex mu_;
+ bool partitioned_ GUARDED_BY(mu_) = false;
+
+ // Used to uniquify function names in `overlay_lib_`.
+ uint32 suffix_ = 0;
+};
+REGISTER_KERNEL_BUILDER(Name("PartitionedCall").Device(DEVICE_CPU),
+ PartitionedCallOp);
+
+} // namespace
+} // namespace tensorflow
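The new PartitionedCallOp fans one logical call out to one function shard per device and signals completion only after every shard finishes; it does so by wrapping done in a ReffedStatusCallback and adding one reference per extra shard. A self-contained sketch of that completion pattern (plain C++, not the TF class; error-status aggregation omitted), assuming the object is heap-allocated and owns itself:

    #include <atomic>
    #include <functional>
    #include <utility>

    // Starts with one reference; each additional in-flight shard calls Ref().
    // Whichever shard drops the last reference runs the wrapped callback and
    // frees the object.
    class RefCountedDone {
     public:
      explicit RefCountedDone(std::function<void()> done)
          : refs_(1), done_(std::move(done)) {}
      void Ref() { refs_.fetch_add(1, std::memory_order_relaxed); }
      void Unref() {
        if (refs_.fetch_sub(1, std::memory_order_acq_rel) == 1) {
          done_();
          delete this;
        }
      }
     private:
      ~RefCountedDone() = default;
      std::atomic<int> refs_;
      std::function<void()> done_;
    };

    // Usage mirroring the kernel: for N shards, Ref() is called N-1 extra
    // times and each shard's completion callback calls Unref(); done_ fires
    // exactly once, after the last shard.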
diff --git a/tensorflow/core/kernels/pooling_ops_3d.cc b/tensorflow/core/kernels/pooling_ops_3d.cc
index 01bcfede1e..2180c4eb97 100644
--- a/tensorflow/core/kernels/pooling_ops_3d.cc
+++ b/tensorflow/core/kernels/pooling_ops_3d.cc
@@ -748,9 +748,8 @@ struct LaunchPoolingOp<GPUDevice, T, AVG> {
const std::array<int64, 3>& padding,
TensorFormat data_format, Padding padding_type,
Tensor* output) {
- DnnPooling3dOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kAverage, window,
- stride, padding, data_format, tensor_in, output);
+ DnnPooling3dOp<T>::Compute(context, se::dnn::PoolingMode::kAverage, window,
+ stride, padding, data_format, tensor_in, output);
}
};
@@ -762,9 +761,8 @@ struct LaunchPoolingOp<GPUDevice, T, MAX> {
const std::array<int64, 3>& padding,
TensorFormat data_format, Padding padding_type,
Tensor* output) {
- DnnPooling3dOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kMaximum, window,
- stride, padding, data_format, tensor_in, output);
+ DnnPooling3dOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum, window,
+ stride, padding, data_format, tensor_in, output);
}
};
@@ -778,10 +776,10 @@ struct LaunchMaxPooling3dGradOp<GPUDevice, T> {
const std::array<int64, 3>& padding,
TensorFormat data_format, Tensor* input_backprop) {
const TensorShape output_shape = tensor_in.shape();
- DnnPooling3dGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kMaximum, window,
- stride, padding, out, data_format, out_backprop, output_shape,
- &tensor_in, &tensor_out, input_backprop);
+ DnnPooling3dGradOp<T>::Compute(context, se::dnn::PoolingMode::kMaximum,
+ window, stride, padding, out, data_format,
+ out_backprop, output_shape, &tensor_in,
+ &tensor_out, input_backprop);
}
};
@@ -796,9 +794,8 @@ struct LaunchAvgPooling3dGradOp<GPUDevice, T> {
const std::array<int64, 3>& padding,
TensorFormat data_format, Tensor* output) {
DnnPooling3dGradOp<T>::Compute(
- context, perftools::gputools::dnn::PoolingMode::kAverage, window,
- stride, padding, out, data_format, out_backprop, tensor_in_shape,
- nullptr, nullptr, output);
+ context, se::dnn::PoolingMode::kAverage, window, stride, padding, out,
+ data_format, out_backprop, tensor_in_shape, nullptr, nullptr, output);
}
};
diff --git a/tensorflow/core/kernels/pooling_ops_common.cc b/tensorflow/core/kernels/pooling_ops_common.cc
index d4241b5809..e583f7feb4 100644
--- a/tensorflow/core/kernels/pooling_ops_common.cc
+++ b/tensorflow/core/kernels/pooling_ops_common.cc
@@ -114,11 +114,9 @@ TensorShape PoolParameters::forward_output_shape() {
namespace {
template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
- uint64 size) {
- perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
- size * sizeof(T));
- perftools::gputools::DeviceMemory<T> typed(wrapped);
+se::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory, uint64 size) {
+ se::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory), size * sizeof(T));
+ se::DeviceMemory<T> typed(wrapped);
return typed;
}
} // namespace
@@ -138,12 +136,13 @@ TF_CALL_GPU_NUMBER_TYPES(DECLARE_GPU_SPEC)
} // namespace functor
template <typename T>
-void DnnPoolingOp<T>::Compute(
- OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
- const std::vector<int32>& size, const std::vector<int32>& stride,
- Padding padding, TensorFormat data_format, const Tensor& tensor_in,
- const TensorShape& tensor_out_shape, bool propagate_nans) {
+void DnnPoolingOp<T>::Compute(OpKernelContext* context,
+ se::dnn::PoolingMode pooling_mode,
+ const std::vector<int32>& size,
+ const std::vector<int32>& stride, Padding padding,
+ TensorFormat data_format, const Tensor& tensor_in,
+ const TensorShape& tensor_out_shape,
+ bool propagate_nans) {
Tensor* tensor_out = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, tensor_out_shape, &tensor_out));
@@ -184,7 +183,7 @@ void DnnPoolingOp<T>::Compute(
}
/// Get ready to call cudnn
- perftools::gputools::dnn::PoolingDescriptor pooling_desc;
+ se::dnn::PoolingDescriptor pooling_desc;
pooling_desc.set_pooling_mode(pooling_mode)
.set_window_height(params.window_rows)
.set_window_width(params.window_cols)
@@ -194,19 +193,19 @@ void DnnPoolingOp<T>::Compute(
.set_horizontal_padding(params.pad_cols)
.set_propagate_nans(propagate_nans);
- perftools::gputools::dnn::BatchDescriptor input_desc;
+ se::dnn::BatchDescriptor input_desc;
input_desc.set_count(params.tensor_in_batch)
.set_height(params.tensor_in_rows)
.set_width(params.tensor_in_cols)
.set_feature_map_count(params.depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor output_desc;
+ se::dnn::BatchDescriptor output_desc;
output_desc.set_count(params.tensor_in_batch)
.set_height(params.out_height)
.set_width(params.out_width)
.set_feature_map_count(params.depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
auto input_data = AsDeviceMemory(transformed_input.template flat<T>().data(),
transformed_input.template flat<T>().size());
@@ -236,13 +235,12 @@ void DnnPoolingOp<T>::Compute(
template <typename T>
void DnnPoolingGradOp<T>::Compute(
- OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ OpKernelContext* context, se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size, const std::vector<int32>& stride,
Padding padding, TensorFormat data_format, const Tensor* tensor_in,
const Tensor* tensor_out, const Tensor& out_backprop,
const TensorShape& tensor_in_shape, bool propagate_nans) {
- CHECK((pooling_mode != perftools::gputools::dnn::PoolingMode::kMaximum) ||
+ CHECK((pooling_mode != se::dnn::PoolingMode::kMaximum) ||
(tensor_in && tensor_out))
<< "For MaxPoolGrad, both tensor_in and tensor_out needs to be "
"specified";
@@ -327,7 +325,7 @@ void DnnPoolingGradOp<T>::Compute(
}
/// Get ready to call cudnn
- perftools::gputools::dnn::PoolingDescriptor pooling_desc;
+ se::dnn::PoolingDescriptor pooling_desc;
pooling_desc.set_pooling_mode(pooling_mode)
.set_window_height(params.window_rows)
.set_window_width(params.window_cols)
@@ -337,19 +335,19 @@ void DnnPoolingGradOp<T>::Compute(
.set_horizontal_padding(params.pad_cols)
.set_propagate_nans(propagate_nans);
- perftools::gputools::dnn::BatchDescriptor orig_output_desc;
+ se::dnn::BatchDescriptor orig_output_desc;
orig_output_desc.set_count(params.tensor_in_batch)
.set_height(params.out_height)
.set_width(params.out_width)
.set_feature_map_count(params.depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
- perftools::gputools::dnn::BatchDescriptor orig_input_desc;
+ se::dnn::BatchDescriptor orig_input_desc;
orig_input_desc.set_count(params.tensor_in_batch)
.set_height(params.tensor_in_rows)
.set_width(params.tensor_in_cols)
.set_feature_map_count(params.depth)
- .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
+ .set_layout(se::dnn::DataLayout::kBatchDepthYX);
auto orig_output_data =
AsDeviceMemory(transformed_output.template flat<T>().data(),
diff --git a/tensorflow/core/kernels/pooling_ops_common_gpu.h b/tensorflow/core/kernels/pooling_ops_common_gpu.h
index 1458456585..7362c5275f 100644
--- a/tensorflow/core/kernels/pooling_ops_common_gpu.h
+++ b/tensorflow/core/kernels/pooling_ops_common_gpu.h
@@ -40,7 +40,7 @@ class DnnPoolingOp {
public:
typedef GPUDevice Device;
static void Compute(OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const Tensor& tensor_in,
@@ -55,7 +55,7 @@ class DnnPoolingGradOp {
public:
typedef GPUDevice Device;
static void Compute(OpKernelContext* context,
- perftools::gputools::dnn::PoolingMode pooling_mode,
+ se::dnn::PoolingMode pooling_mode,
const std::vector<int32>& size,
const std::vector<int32>& stride, Padding padding,
TensorFormat data_format, const Tensor* tensor_in,
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
index e2709c117d..cc4d9a49a0 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.cc
@@ -20,7 +20,9 @@ limitations under the License.
#include <utility>
#include "tensorflow/core/common_runtime/shape_refiner.h"
+#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/algorithm.h"
@@ -1125,46 +1127,43 @@ RemoteFusedGraphExecuteUtils::BuildRemoteFusedGraphExecuteOpNode(
for (size_t i = 0; i < inputs.size(); ++i) {
if (IsSameNodeName(node_def, inputs.at(i), &tid)) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_INPUT,
- tid.second, i, remote_graph_executor_name,
+ attr_str += BuildNodeTypeAttr(GRAPH_INPUT, tid.second, i,
+ remote_graph_executor_name,
remote_fused_graph_node_name);
}
}
for (size_t i = 0; i < outputs.size(); ++i) {
if (IsSameNodeName(node_def, outputs.at(i), &tid)) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT,
- tid.second, i);
+ attr_str += BuildNodeTypeAttr(GRAPH_OUTPUT, tid.second, i);
}
}
for (const string& fused_node_name : fused_node_names) {
if (fused_node_name == node_def.name()) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
+ attr_str += BuildNodeTypeAttr(FUSED_NODE);
}
}
for (const string& fused_node_name : fused_nodes_filtered_by_op_types) {
if (fused_node_name == node_def.name()) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::FUSED_NODE);
+ attr_str += BuildNodeTypeAttr(FUSED_NODE);
}
}
for (size_t i = 0; i < border_inputs.size(); ++i) {
if (IsSameNodeName(node_def, border_inputs.at(i), &tid)) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::BORDER_INPUT,
- tid.second, i);
+ attr_str += BuildNodeTypeAttr(BORDER_INPUT, tid.second, i);
}
}
for (size_t i = 0; i < border_outputs.size(); ++i) {
if (IsSameNodeName(node_def, border_outputs.at(i), &tid)) {
AppendDeliminator(&attr_str);
- attr_str += BuildNodeTypeAttr(
- RemoteFusedGraphExecuteInfo::BORDER_OUTPUT, tid.second, i);
+ attr_str += BuildNodeTypeAttr(BORDER_OUTPUT, tid.second, i);
}
}
if (attr_str.empty()) {
- attr_str += BuildNodeTypeAttr(RemoteFusedGraphExecuteInfo::UNUSED);
+ attr_str += BuildNodeTypeAttr(UNUSED);
}
AddNodeAttr(ATTR_NODE_TYPE, attr_str, &node_def);
}
@@ -1200,14 +1199,14 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
}
int node_type_int;
CHECK(strings::safe_strto32(attr.at(0), &node_type_int)) << attr.at(0);
- const RemoteFusedGraphExecuteInfo::NodeType node_type =
- static_cast<RemoteFusedGraphExecuteInfo::NodeType>(node_type_int);
+ const RemoteFusedGraphNodeType node_type =
+ static_cast<RemoteFusedGraphNodeType>(node_type_int);
const string& name = node_def.name();
int port;
int index;
switch (node_type) {
- case RemoteFusedGraphExecuteInfo::GRAPH_INPUT:
+ case GRAPH_INPUT:
VLOG(2) << "Graph input: " << name;
CHECK_EQ(5, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
@@ -1224,33 +1223,33 @@ RemoteFusedGraphExecuteUtils::FuseRemoteGraphByPlacedArguments(
return Status::OK();
}
break;
- case RemoteFusedGraphExecuteInfo::GRAPH_OUTPUT:
+ case GRAPH_OUTPUT:
VLOG(2) << "Graph output: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
output_map.emplace(index, strings::StrCat(name, ":", port));
break;
- case RemoteFusedGraphExecuteInfo::FUSED_NODE:
+ case FUSED_NODE:
VLOG(2) << "Fused node: " << name;
CHECK_EQ(1, attr.size());
fused_node_names.emplace(name);
break;
- case RemoteFusedGraphExecuteInfo::BORDER_INPUT:
+ case BORDER_INPUT:
VLOG(2) << "Border input: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
border_input_map.emplace(index, strings::StrCat(name, ":", port));
break;
- case RemoteFusedGraphExecuteInfo::BORDER_OUTPUT:
+ case BORDER_OUTPUT:
VLOG(2) << "Border output: " << name;
CHECK_EQ(3, attr.size());
CHECK(strings::safe_strto32(attr.at(1), &port));
CHECK(strings::safe_strto32(attr.at(2), &index));
border_output_map.emplace(index, strings::StrCat(name, ":", port));
break;
- case RemoteFusedGraphExecuteInfo::UNUSED:
+ case UNUSED:
// do nothing
break;
default:
@@ -1461,20 +1460,19 @@ RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
}
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
- const int index, const string& executor_name, const string& node_name) {
+ const RemoteFusedGraphNodeType node_type, const int port, const int index,
+ const string& executor_name, const string& node_name) {
return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index,
",", executor_name, ",", node_name);
}
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
- const int index) {
+ const RemoteFusedGraphNodeType node_type, const int port, const int index) {
return strings::StrCat(static_cast<int>(node_type), ",", port, ",", index);
}
/* static */ string RemoteFusedGraphExecuteUtils::BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type) {
+ const RemoteFusedGraphNodeType node_type) {
return strings::StrCat(static_cast<int>(node_type));
}
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
index f047144278..ea6b6a1015 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils.h
@@ -19,8 +19,6 @@ limitations under the License.
#include <unordered_map>
#include <unordered_set>
-#include "tensorflow/core/framework/graph.pb.h"
-#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
@@ -30,6 +28,17 @@ limitations under the License.
namespace tensorflow {
+enum RemoteFusedGraphNodeType {
+ UNUSED = 0,
+ GRAPH_INPUT = 1,
+ GRAPH_OUTPUT = 2,
+ FUSED_NODE = 3,
+ BORDER_INPUT = 4,
+ BORDER_OUTPUT = 5,
+};
+
+class RemoteFusedGraphExecuteInfo;
+
// RemoteFusedGraphExecuteUtils provides APIs to register and get builder
// functions for IRemoteFusedGraphExecutor.
class RemoteFusedGraphExecuteUtils {
@@ -297,16 +306,15 @@ class RemoteFusedGraphExecuteUtils {
static ExecutorBuildRegistry* GetExecutorBuildRegistry();
- static string BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
- const int index, const string& executor_name, const string& node_name);
+ static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,
+ const int port, const int index,
+ const string& executor_name,
+ const string& node_name);
- static string BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type, const int port,
- const int index);
+ static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type,
+ const int port, const int index);
- static string BuildNodeTypeAttr(
- const RemoteFusedGraphExecuteInfo::NodeType node_type);
+ static string BuildNodeTypeAttr(const RemoteFusedGraphNodeType node_type);
TF_DISALLOW_COPY_AND_ASSIGN(RemoteFusedGraphExecuteUtils);
};
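To make the refactored BuildNodeTypeAttr overloads concrete: they serialize the new plain enum (plus optional port/index/executor/node fields) into a comma-separated string whose first field is the integer enum value, which is what FuseRemoteGraphByPlacedArguments parses back with safe_strto32. A dependency-free stand-in for the three-argument overload:

    #include <string>

    // Mirrors the enum added to remote_fused_graph_execute_utils.h above.
    enum RemoteFusedGraphNodeType {
      UNUSED = 0, GRAPH_INPUT = 1, GRAPH_OUTPUT = 2,
      FUSED_NODE = 3, BORDER_INPUT = 4, BORDER_OUTPUT = 5,
    };

    // Equivalent of the three-argument overload, using std::to_string in
    // place of strings::StrCat.
    std::string BuildNodeTypeAttr(RemoteFusedGraphNodeType node_type, int port,
                                  int index) {
      return std::to_string(static_cast<int>(node_type)) + "," +
             std::to_string(port) + "," + std::to_string(index);
    }

    // Example: BuildNodeTypeAttr(GRAPH_OUTPUT, 1, 0) yields "2,1,0".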
diff --git a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
index aca8ddfae9..44251e6ff8 100644
--- a/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_execute_utils_test.cc
@@ -17,6 +17,7 @@ limitations under the License.
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/node_def.pb.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_op_test_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
diff --git a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
index 9217c25978..1e0731e540 100644
--- a/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
+++ b/tensorflow/core/kernels/remote_fused_graph_rewriter_transform_test.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include "tensorflow/cc/ops/nn_ops.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/common_runtime/function.h"
+#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/default_device.h"
diff --git a/tensorflow/core/kernels/segment_reduction_ops.cc b/tensorflow/core/kernels/segment_reduction_ops.cc
index 2fc73a3309..c87ce78e05 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.cc
+++ b/tensorflow/core/kernels/segment_reduction_ops.cc
@@ -40,7 +40,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"
-using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+using stream_executor::cuda::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -242,7 +242,7 @@ class SegmentSumGPUOp : public AsyncOpKernel {
return;
}
- perftools::gputools::DeviceMemoryBase output_rows_device(
+ se::DeviceMemoryBase output_rows_device(
const_cast<Tensor&>(segment_ids).template flat<Index>().data() +
(num_indices - 1));
ScratchSpace<Index> output_rows_host(context, 1, /* on_host */ true);
diff --git a/tensorflow/core/kernels/segment_reduction_ops.h b/tensorflow/core/kernels/segment_reduction_ops.h
index 0a0f8d4dcf..2ad9fa265e 100644
--- a/tensorflow/core/kernels/segment_reduction_ops.h
+++ b/tensorflow/core/kernels/segment_reduction_ops.h
@@ -24,6 +24,14 @@ limitations under the License.
// non-GPU targets. This only breaks in clang, because it's more strict for
// template code and CudaAtomicMax is used in template context.
+
+// This file requires the following include because it uses CudaAtomicMax:
+// #include "tensorflow/core/util/cuda_kernel_helper.h"
+
+// Unfortunately we can't add the #include, since it breaks compilation for
+// non-GPU targets. This only breaks in clang, because it's more strict for
+// template code and CudaAtomicMax is used in template context.
+
// This file requires the following include because it uses CudaAtomicMax:
// #include "tensorflow/core/util/cuda_kernel_helper.h"
diff --git a/tensorflow/core/kernels/sparse_matmul_op.h b/tensorflow/core/kernels/sparse_matmul_op.h
index 14ef2ed704..e89280724e 100644
--- a/tensorflow/core/kernels/sparse_matmul_op.h
+++ b/tensorflow/core/kernels/sparse_matmul_op.h
@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_KERNELS_SPARSE_MATMUL_OP_H_
#include "third_party/eigen3/Eigen/Core"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
#if defined(PLATFORM_WINDOWS)
diff --git a/tensorflow/core/kernels/string_split_op.cc b/tensorflow/core/kernels/string_split_op.cc
index 9efbd66ef7..4c2b312c34 100644
--- a/tensorflow/core/kernels/string_split_op.cc
+++ b/tensorflow/core/kernels/string_split_op.cc
@@ -71,7 +71,7 @@ class StringSplitOp : public OpKernel {
OP_REQUIRES_OK(ctx, ctx->input("delimiter", &delimiter_tensor));
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(delimiter_tensor->shape()),
- errors::InvalidArgument("delimiter must scalar, got shape: ",
+ errors::InvalidArgument("delimiter must be a scalar, got shape: ",
delimiter_tensor->shape().DebugString()));
const auto delimiter_vec = delimiter_tensor->flat<string>();
const string& delimiter = delimiter_vec(0);
diff --git a/tensorflow/core/kernels/summary_interface.h b/tensorflow/core/kernels/summary_interface.h
index 02391e967a..1854fe5526 100644
--- a/tensorflow/core/kernels/summary_interface.h
+++ b/tensorflow/core/kernels/summary_interface.h
@@ -17,14 +17,15 @@ limitations under the License.
#include <memory>
-#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"
-#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
+class Event;
+class GraphDef;
+
// Main interface for the summary writer resource.
class SummaryWriterInterface : public ResourceBase {
public:
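
The hunk above drops the graph.pb.h and event.pb.h includes in favor of forward declarations, which is enough because the interface only refers to Event and GraphDef through pointers. A minimal sketch of the pattern, with illustrative names rather than the real interface:

// sketch_interface.h -- illustrative only.
class Event;  // forward declaration stands in for including event.pb.h

class WriterLike {
 public:
  virtual ~WriterLike() {}
  // A parameter of type Event* needs only the forward declaration above;
  // the full definition is required only where Event's members are touched,
  // i.e. in the .cc file that implements WriteEvent.
  virtual void WriteEvent(Event* event) = 0;
};

The matching source-side change is visible just below in summary_kernels.cc, which now includes event.pb.h directly.
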
diff --git a/tensorflow/core/kernels/summary_kernels.cc b/tensorflow/core/kernels/summary_kernels.cc
index d317a8d33d..b287f0cc2f 100644
--- a/tensorflow/core/kernels/summary_kernels.cc
+++ b/tensorflow/core/kernels/summary_kernels.cc
@@ -21,6 +21,7 @@ limitations under the License.
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/db/sqlite.h"
#include "tensorflow/core/platform/protobuf.h"
+#include "tensorflow/core/util/event.pb.h"
namespace tensorflow {
diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc
index 5b13b10937..271329599f 100644
--- a/tensorflow/core/kernels/training_ops.cc
+++ b/tensorflow/core/kernels/training_ops.cc
@@ -153,8 +153,10 @@ struct ApplyAdagrad<CPUDevice, T> {
void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad) {
- accum.device(d) += grad.square();
+ typename TTypes<T>::ConstFlat grad, bool update_slots) {
+ if (update_slots) {
+ accum.device(d) += grad.square();
+ }
var.device(d) -= grad * lr() * accum.rsqrt();
}
};
@@ -1074,6 +1076,7 @@ class ApplyAdagradOp : public OpKernel {
public:
explicit ApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compute(OpKernelContext* ctx) override {
@@ -1111,13 +1114,15 @@ class ApplyAdagradOp : public OpKernel {
const Device& device = ctx->template eigen_device<Device>();
functor::ApplyAdagrad<Device, T>()(device, var.flat<T>(), accum.flat<T>(),
- lr.scalar<T>(), grad.flat<T>());
+ lr.scalar<T>(), grad.flat<T>(),
+ update_slots_);
MaybeForwardRefInputToRefOutput(ctx, 0, 0);
}
private:
bool use_exclusive_lock_;
+ bool update_slots_;
};
#define REGISTER_KERNELS(D, T) \
@@ -1145,7 +1150,7 @@ namespace functor {
void ApplyAdagrad<GPUDevice, T>::operator()( \
const GPUDevice& d, typename TTypes<T>::Flat var, \
typename TTypes<T>::Flat accum, typename TTypes<T>::ConstScalar lr, \
- typename TTypes<T>::ConstFlat grad); \
+ typename TTypes<T>::ConstFlat grad, bool update_slots); \
extern template struct ApplyAdagrad<GPUDevice, T>;
DECLARE_GPU_SPEC(Eigen::half);
DECLARE_GPU_SPEC(float);
@@ -1266,6 +1271,7 @@ class SparseApplyAdagradOp : public OpKernel {
public:
explicit SparseApplyAdagradOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_locking", &use_exclusive_lock_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("update_slots", &update_slots_));
}
void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS {
@@ -1339,7 +1345,9 @@ class SparseApplyAdagradOp : public OpKernel {
auto a = accum_flat.template chip<0>(index);
auto g = grad_flat.template chip<0>(i);
auto v = var_flat.template chip<0>(index);
- a += g.square();
+ if (update_slots_) {
+ a += g.square();
+ }
v -= g.constant(lr_scalar) * g * a.rsqrt();
}
} else {
@@ -1358,7 +1366,9 @@ class SparseApplyAdagradOp : public OpKernel {
" in indices is out of range")));
T& a = accum_flat(index);
const T& g = grad_flat(i);
- a += g * g;
+ if (update_slots_) {
+ a += g * g;
+ }
var_flat(index) -= lr_scalar * g / Eigen::numext::sqrt(a);
}
}
@@ -1369,6 +1379,7 @@ class SparseApplyAdagradOp : public OpKernel {
private:
bool use_exclusive_lock_;
+ bool update_slots_;
};
#define REGISTER_KERNELS(T, Tindices) \
diff --git a/tensorflow/core/kernels/training_ops.h b/tensorflow/core/kernels/training_ops.h
index f536a61eb0..495a94f1a1 100644
--- a/tensorflow/core/kernels/training_ops.h
+++ b/tensorflow/core/kernels/training_ops.h
@@ -68,7 +68,7 @@ struct ApplyAdagrad {
void operator()(const Device& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad);
+ typename TTypes<T>::ConstFlat grad, bool update_slots);
};
template <typename Device, typename T>
diff --git a/tensorflow/core/kernels/training_ops_gpu.cu.cc b/tensorflow/core/kernels/training_ops_gpu.cu.cc
index 2aa17f2a0f..4bd32592db 100644
--- a/tensorflow/core/kernels/training_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/training_ops_gpu.cu.cc
@@ -42,8 +42,10 @@ struct ApplyAdagrad<GPUDevice, T> {
void operator()(const GPUDevice& d, typename TTypes<T>::Flat var,
typename TTypes<T>::Flat accum,
typename TTypes<T>::ConstScalar lr,
- typename TTypes<T>::ConstFlat grad) {
- accum.device(d) += grad.square();
+ typename TTypes<T>::ConstFlat grad, bool update_slots) {
+ if (update_slots) {
+ accum.device(d) += grad.square();
+ }
Eigen::array<typename TTypes<T>::Tensor::Index, 1> bcast;
bcast[0] = grad.dimension(0);
Eigen::Sizes<1> single;
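
Both the CPU functor in training_ops.cc and this GPU functor now take an update_slots flag: when it is false the squared-gradient accumulator is left untouched and the step reduces to gradient descent with a fixed per-element preconditioner. A minimal scalar sketch of the two behaviors, separate from the Eigen functors themselves:

#include <cmath>
#include <vector>

// Illustrative scalar Adagrad step mirroring the branching added above.
void ApplyAdagradScalar(std::vector<float>& var, std::vector<float>& accum,
                        const std::vector<float>& grad, float lr,
                        bool update_slots) {
  for (size_t i = 0; i < var.size(); ++i) {
    if (update_slots) {
      accum[i] += grad[i] * grad[i];  // accumulate squared gradients
    }
    var[i] -= lr * grad[i] / std::sqrt(accum[i]);  // preconditioned update
  }
}
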
diff --git a/tensorflow/core/kernels/variable_ops.h b/tensorflow/core/kernels/variable_ops.h
index 8b406e5311..f27dab4ddd 100644
--- a/tensorflow/core/kernels/variable_ops.h
+++ b/tensorflow/core/kernels/variable_ops.h
@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/resource_mgr.h"
+#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
@@ -27,39 +28,6 @@ limitations under the License.
namespace tensorflow {
-// Resource stored by variables in the resource manager
-// (new, resource-style version).
-class Var : public ResourceBase {
- public:
- explicit Var(DataType dtype) : tensor_(dtype) {}
- // Not copyable or movable.
- Var(const Var&) = delete;
- Var& operator=(const Var&) = delete;
-
- // TODO(ebrevdo): Use LockSet instead of exposing mu.
- mutex* mu() { return &mu_; }
- Tensor* tensor() { return &tensor_; }
-
- string DebugString() override {
- return strings::StrCat(DataTypeString(tensor_.dtype()), "/",
- tensor_.shape().DebugString());
- }
-
- // Only used in the resource variable path. In resource variables,
- // tensor.IsInitialized() can be true (i.e. have memory allocated to it) while
- // there is not a good value there due to a race condition, and it's possible
- // to stumble upon this during variable.initialized_value(). So it's best to
- // just store directly whether the variable is initialized.
- bool is_initialized = false; // GUARDED_BY(mu_) but annotalysis doesn't like
- // it.
-
- private:
- mutex mu_;
- Tensor tensor_;
-
- ~Var() override {}
-};
-
class VariableOp : public OpKernel {
public:
explicit VariableOp(OpKernelConstruction* context);
diff --git a/tensorflow/core/kernels/where_op.cc b/tensorflow/core/kernels/where_op.cc
index f92c4ed17a..3330442ffd 100644
--- a/tensorflow/core/kernels/where_op.cc
+++ b/tensorflow/core/kernels/where_op.cc
@@ -42,7 +42,7 @@ limitations under the License.
#include "tensorflow/core/kernels/cuda_solvers.h"
#include "tensorflow/core/platform/cuda.h"
-using ::perftools::gputools::cuda::ScopedActivateExecutorContext;
+using stream_executor::cuda::ScopedActivateExecutorContext;
#endif // GOOGLE_CUDA
namespace tensorflow {
@@ -278,8 +278,7 @@ class WhereGPUOp : public AsyncOpKernel {
auto num_true_t = num_true.scalar<Tindex>();
- perftools::gputools::DeviceMemoryBase num_true_ptr(
- static_cast<void*>(num_true_t.data()));
+ se::DeviceMemoryBase num_true_ptr(static_cast<void*>(num_true_t.data()));
// Push kernel to stream to get number of true elements.
const GPUDevice& d = context->eigen_device<GPUDevice>();
Status s = functor::NumTrue<GPUDevice, T, Tindex>::Compute(
diff --git a/tensorflow/core/lib/bfloat16/bfloat16.h b/tensorflow/core/lib/bfloat16/bfloat16.h
index 1a822d441d..2c0576ff10 100644
--- a/tensorflow/core/lib/bfloat16/bfloat16.h
+++ b/tensorflow/core/lib/bfloat16/bfloat16.h
@@ -19,8 +19,7 @@ limitations under the License.
#include <cmath>
#include <complex>
-// We need cpu_info.h here in order to pick up __BYTE_ORDER__.
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#ifdef __CUDACC__
// All functions callable from CUDA code must be qualified with __device__
diff --git a/tensorflow/core/lib/core/coding.cc b/tensorflow/core/lib/core/coding.cc
index bb95c27410..50872eef83 100644
--- a/tensorflow/core/lib/core/coding.cc
+++ b/tensorflow/core/lib/core/coding.cc
@@ -15,7 +15,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/coding.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
namespace tensorflow {
namespace core {
diff --git a/tensorflow/core/lib/core/raw_coding.h b/tensorflow/core/lib/core/raw_coding.h
index bbfd33d303..37201b755d 100644
--- a/tensorflow/core/lib/core/raw_coding.h
+++ b/tensorflow/core/lib/core/raw_coding.h
@@ -17,7 +17,7 @@ limitations under the License.
#define TENSORFLOW_LIB_CORE_RAW_CODING_H_
#include <string.h>
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
diff --git a/tensorflow/core/lib/gtl/inlined_vector.h b/tensorflow/core/lib/gtl/inlined_vector.h
index 6e3cb2206d..2011f7d4a1 100644
--- a/tensorflow/core/lib/gtl/inlined_vector.h
+++ b/tensorflow/core/lib/gtl/inlined_vector.h
@@ -43,7 +43,7 @@ limitations under the License.
#include <vector>
#include "tensorflow/core/lib/gtl/manual_constructor.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/types.h"
diff --git a/tensorflow/core/lib/png/png_io.cc b/tensorflow/core/lib/png/png_io.cc
index cba473927d..62c803afb2 100644
--- a/tensorflow/core/lib/png/png_io.cc
+++ b/tensorflow/core/lib/png/png_io.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/lib/png/png_io.h"
-#include "tensorflow/core/platform/cpu_info.h" // endian
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/png.h"
diff --git a/tensorflow/core/lib/wav/wav_io.cc b/tensorflow/core/lib/wav/wav_io.cc
index 51b9c6cd82..36d939e061 100644
--- a/tensorflow/core/lib/wav/wav_io.cc
+++ b/tensorflow/core/lib/wav/wav_io.cc
@@ -23,7 +23,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/wav/wav_io.h"
-#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
@@ -285,6 +285,12 @@ Status DecodeLin16WaveAsFloatVector(const string& wav_string,
was_data_found = true;
*sample_count = chunk_size / bytes_per_sample;
const uint32 data_count = *sample_count * *channel_count;
+ int unused_new_offset = 0;
+ // Validate that the data exists before allocating space for it
+ // (prevent easy OOM errors).
+ TF_RETURN_IF_ERROR(IncrementOffset(offset, sizeof(int16) * data_count,
+ wav_string.size(),
+ &unused_new_offset));
float_values->resize(data_count);
for (int i = 0; i < data_count; ++i) {
int16 single_channel_value = 0;
diff --git a/tensorflow/core/ops/compat/ops_history.v1.pbtxt b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
index 247f9edf5b..71ba5f016a 100644
--- a/tensorflow/core/ops/compat/ops_history.v1.pbtxt
+++ b/tensorflow/core/ops/compat/ops_history.v1.pbtxt
@@ -1535,6 +1535,85 @@ op {
}
}
op {
+ name: "ApplyAdaMax"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "m"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "beta1_power"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta2"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "epsilon"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ApplyAdadelta"
input_arg {
name: "var"
@@ -2043,6 +2122,71 @@ op {
}
}
op {
+ name: "ApplyAdagrad"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "accum"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "ApplyAdagradDA"
input_arg {
name: "var"
@@ -11235,6 +11379,38 @@ op {
}
}
op {
+ name: "BroadcastTo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "shape"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Bucketize"
input_arg {
name: "input"
@@ -15006,6 +15182,148 @@ op {
is_stateful: true
}
op {
+ name: "CudnnRNNBackpropV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "params"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_h_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_c_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "host_reserved"
+ type: DT_INT8
+ }
+ output_arg {
+ name: "input_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "input_h_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "input_c_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "params_backprop"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "rnn_mode"
+ type: "string"
+ default_value {
+ s: "lstm"
+ }
+ allowed_values {
+ list {
+ s: "rnn_relu"
+ s: "rnn_tanh"
+ s: "lstm"
+ s: "gru"
+ }
+ }
+ }
+ attr {
+ name: "input_mode"
+ type: "string"
+ default_value {
+ s: "linear_input"
+ }
+ allowed_values {
+ list {
+ s: "linear_input"
+ s: "skip_input"
+ s: "auto_select"
+ }
+ }
+ }
+ attr {
+ name: "direction"
+ type: "string"
+ default_value {
+ s: "unidirectional"
+ }
+ allowed_values {
+ list {
+ s: "unidirectional"
+ s: "bidirectional"
+ }
+ }
+ }
+ attr {
+ name: "dropout"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "seed"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "seed2"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "CudnnRNNCanonicalToParams"
input_arg {
name: "num_layers"
@@ -15327,6 +15645,127 @@ op {
}
}
op {
+ name: "CudnnRNNV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "params"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_h"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_c"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "host_reserved"
+ type: DT_INT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "rnn_mode"
+ type: "string"
+ default_value {
+ s: "lstm"
+ }
+ allowed_values {
+ list {
+ s: "rnn_relu"
+ s: "rnn_tanh"
+ s: "lstm"
+ s: "gru"
+ }
+ }
+ }
+ attr {
+ name: "input_mode"
+ type: "string"
+ default_value {
+ s: "linear_input"
+ }
+ allowed_values {
+ list {
+ s: "linear_input"
+ s: "skip_input"
+ s: "auto_select"
+ }
+ }
+ }
+ attr {
+ name: "direction"
+ type: "string"
+ default_value {
+ s: "unidirectional"
+ }
+ allowed_values {
+ list {
+ s: "unidirectional"
+ s: "bidirectional"
+ }
+ }
+ }
+ attr {
+ name: "dropout"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "seed"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "seed2"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Cumprod"
input_arg {
name: "x"
@@ -35449,6 +35888,31 @@ op {
}
}
op {
+ name: "PartitionedCall"
+ input_arg {
+ name: "args"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "Placeholder"
output_arg {
name: "output"
@@ -42886,6 +43350,78 @@ op {
}
}
op {
+ name: "ResourceApplyAdaMax"
+ input_arg {
+ name: "var"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "m"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "v"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "beta1_power"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta2"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "epsilon"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceApplyAdadelta"
input_arg {
name: "var"
@@ -43342,6 +43878,65 @@ op {
is_stateful: true
}
op {
+ name: "ResourceApplyAdagrad"
+ input_arg {
+ name: "var"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "accum"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceApplyAdagradDA"
input_arg {
name: "var"
@@ -47694,6 +48289,79 @@ op {
is_stateful: true
}
op {
+ name: "ResourceSparseApplyAdagrad"
+ input_arg {
+ name: "var"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "accum"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceSparseApplyAdagradDA"
input_arg {
name: "var"
@@ -58440,6 +59108,85 @@ op {
}
}
op {
+ name: "SparseApplyAdagrad"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "accum"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "indices"
+ type_attr: "Tindices"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tindices"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+}
+op {
name: "SparseApplyAdagradDA"
input_arg {
name: "var"
@@ -66435,6 +67182,17 @@ op {
}
}
op {
+ name: "StringStrip"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+}
+op {
name: "StringToHashBucket"
input_arg {
name: "string_tensor"
diff --git a/tensorflow/core/ops/cudnn_rnn_ops.cc b/tensorflow/core/ops/cudnn_rnn_ops.cc
index 37d70a22ef..f78f7a897a 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops.cc
@@ -99,6 +99,49 @@ REGISTER_OP("CudnnRNN")
return Status::OK();
});
+REGISTER_OP("CudnnRNNV2")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .SetIsStateful()
+ .Output("output: T")
+ .Output("output_h: T")
+ .Output("output_c: T")
+ .Output("reserve_space: T")
+ .Output("host_reserved: int8")
+ .Attr("T: {float16, float32, float64}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Attr("dropout: float = 0.0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .Attr("is_training: bool = true")
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto seq_length = c->Dim(input_shape, 0);
+ auto batch_size = c->Dim(input_shape, 1);
+ auto num_units = c->Dim(input_h_shape, 2);
+ string direction;
+ TF_RETURN_IF_ERROR(c->GetAttr("direction", &direction));
+ string rnn_mode;
+ TF_RETURN_IF_ERROR(c->GetAttr("rnn_mode", &rnn_mode));
+ int dir_count = (direction == "bidirectional") ? 2 : 1;
+ DimensionHandle output_size;
+ TF_RETURN_IF_ERROR(c->Multiply(num_units, dir_count, &output_size));
+ auto output_shape = c->MakeShape({seq_length, batch_size, output_size});
+ auto output_h_shape = input_h_shape;
+ auto output_c_shape TF_ATTRIBUTE_UNUSED =
+ (rnn_mode == "lstm") ? output_h_shape : c->MakeShape({});
+ c->set_output(0, output_shape);
+ c->set_output(1, output_h_shape);
+ c->set_output(2, output_c_shape);
+ c->set_output(3, c->UnknownShape());
+ c->set_output(4, c->UnknownShape());
+ return Status::OK();
+ });
REGISTER_OP("CudnnRNNBackprop")
.Input("input: T")
@@ -136,6 +179,42 @@ REGISTER_OP("CudnnRNNBackprop")
return Status::OK();
});
+REGISTER_OP("CudnnRNNBackpropV2")
+ .Input("input: T")
+ .Input("input_h: T")
+ .Input("input_c: T")
+ .Input("params: T")
+ .Input("output: T")
+ .Input("output_h: T")
+ .Input("output_c: T")
+ .Input("output_backprop: T")
+ .Input("output_h_backprop: T")
+ .Input("output_c_backprop: T")
+ .Input("reserve_space: T")
+ .Input("host_reserved: int8")
+ .SetIsStateful()
+ .Output("input_backprop: T")
+ .Output("input_h_backprop: T")
+ .Output("input_c_backprop: T")
+ .Output("params_backprop: T")
+ .Attr("T: {float16, float32, float64}")
+ .Attr(kRNNModeAttrs)
+ .Attr(kRNNInputModeAttrs)
+ .Attr(kRNNDirectionAttrs)
+ .Attr("dropout: float = 0.0")
+ .Attr("seed: int = 0")
+ .Attr("seed2: int = 0")
+ .SetShapeFn([](InferenceContext* c) {
+ auto input_shape = c->input(0);
+ auto input_h_shape = c->input(1);
+ auto input_c_shape = c->input(2);
+ auto params_shape = c->input(3);
+ c->set_output(0, input_shape);
+ c->set_output(1, input_h_shape);
+ c->set_output(2, input_c_shape);
+ c->set_output(3, params_shape);
+ return Status::OK();
+ });
REGISTER_OP("CudnnRNNParamsToCanonical")
.Input("num_layers: int32")
diff --git a/tensorflow/core/ops/cudnn_rnn_ops_test.cc b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
index 95d45c0bb8..2dd867561b 100644
--- a/tensorflow/core/ops/cudnn_rnn_ops_test.cc
+++ b/tensorflow/core/ops/cudnn_rnn_ops_test.cc
@@ -30,6 +30,24 @@ TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
+ int seq_length = 2;
+ int batch_size = 3;
+ int num_units = 4;
+ int num_layers = 5;
+ int dir_count = 1;
+ std::vector<int> input_shape = {seq_length, batch_size, num_units};
+ std::vector<int> input_h_shape = {num_layers * dir_count, batch_size,
+ num_units};
+ std::vector<int> output_shape = {seq_length, batch_size,
+ num_units * dir_count};
+ auto shape_to_str = [](const std::vector<int>& v) {
+ return strings::StrCat("[", str_util::Join(v, ","), "]");
+ };
+ string input_shapes_desc = strings::StrCat(
+ shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
+ shape_to_str(input_h_shape), ";", "[?]");
+ string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
+
ShapeInferenceTestOp op("CudnnRNN");
TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNN")
.Input({"input", 0, DT_FLOAT})
@@ -40,6 +58,10 @@ TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
.Attr("input_mode", "auto_select")
.Attr("direction", "unidirectional")
.Finalize(&op.node_def));
+ INFER_OK(op, input_shapes_desc, output_shapes_desc);
+}
+
+TEST(CudnnRNNOpsTest, ForwardV2Lstm_ShapeFn) {
int seq_length = 2;
int batch_size = 3;
int num_units = 4;
@@ -56,7 +78,18 @@ TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
string input_shapes_desc = strings::StrCat(
shape_to_str(input_shape), ";", shape_to_str(input_h_shape), ";",
shape_to_str(input_h_shape), ";", "[?]");
- string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?";
+ string output_shapes_desc = "[d0_0,d0_1,d1_2];in1;in1;?;?";
+
+ ShapeInferenceTestOp op("CudnnRNNV2");
+ TF_ASSERT_OK(NodeDefBuilder("test", "CudnnRNNV2")
+ .Input({"input", 0, DT_FLOAT})
+ .Input({"input_h", 0, DT_FLOAT})
+ .Input({"input_c", 0, DT_FLOAT})
+ .Input({"params", 0, DT_FLOAT})
+ .Attr("rnn_mode", "lstm")
+ .Attr("input_mode", "auto_select")
+ .Attr("direction", "unidirectional")
+ .Finalize(&op.node_def));
INFER_OK(op, input_shapes_desc, output_shapes_desc);
}
diff --git a/tensorflow/core/ops/functional_ops.cc b/tensorflow/core/ops/functional_ops.cc
index 792686cae1..4d4a370478 100644
--- a/tensorflow/core/ops/functional_ops.cc
+++ b/tensorflow/core/ops/functional_ops.cc
@@ -145,4 +145,13 @@ REGISTER_OP("For")
.Attr("body: func")
.SetShapeFn(shape_inference::UnknownShape);
+// TODO(b/73826847, b/37549631) Mark as stateful.
+REGISTER_OP("PartitionedCall")
+ .Input("args: Tin")
+ .Output("output: Tout")
+ .Attr("Tin: list(type) >= 0")
+ .Attr("Tout: list(type) >= 0")
+ .Attr("f: func")
+ .SetShapeFn(shape_inference::UnknownShape);
+
} // end namespace tensorflow
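
PartitionedCall, registered above, takes a list of input tensors, forwards them to the function attr `f`, and returns `Tout` outputs with unknown shapes. A hedged sketch of how a NodeDef for it could be assembled, in the style of the NodeDefBuilder calls used in cudnn_rnn_ops_test.cc below; the function name "MyFn" and the helper are illustrative, not part of this change:

#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/types.h"

namespace tensorflow {

// Hypothetical helper: build a PartitionedCall NodeDef that invokes a
// function named "MyFn" (assumed to exist in the function library) on two
// float arguments and returns one float. Tin is inferred from the inputs;
// Tout and f are set explicitly.
Status BuildPartitionedCall(NodeDef* def) {
  NameAttrList f;
  f.set_name("MyFn");
  return NodeDefBuilder("call", "PartitionedCall")
      .Input(FakeInput(DataTypeSlice{DT_FLOAT, DT_FLOAT}))
      .Attr("Tout", DataTypeSlice{DT_FLOAT})
      .Attr("f", f)
      .Finalize(def);
}

}  // namespace tensorflow
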
diff --git a/tensorflow/core/ops/ops.pbtxt b/tensorflow/core/ops/ops.pbtxt
index d1773daebe..90368fe614 100644
--- a/tensorflow/core/ops/ops.pbtxt
+++ b/tensorflow/core/ops/ops.pbtxt
@@ -685,6 +685,85 @@ op {
}
}
op {
+ name: "ApplyAdaMax"
+ input_arg {
+ name: "var"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "m"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "v"
+ type_attr: "T"
+ is_ref: true
+ }
+ input_arg {
+ name: "beta1_power"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta2"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "epsilon"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "out"
+ type_attr: "T"
+ is_ref: true
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+}
+op {
name: "ApplyAdadelta"
input_arg {
name: "var"
@@ -812,6 +891,13 @@ op {
b: false
}
}
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "ApplyAdagradDA"
@@ -4389,6 +4475,38 @@ op {
}
}
op {
+ name: "BroadcastTo"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "shape"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+}
+op {
name: "Bucketize"
input_arg {
name: "input"
@@ -6524,6 +6642,148 @@ op {
is_stateful: true
}
op {
+ name: "CudnnRNNBackpropV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "params"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_h_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "output_c_backprop"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reserve_space"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "host_reserved"
+ type: DT_INT8
+ }
+ output_arg {
+ name: "input_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "input_h_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "input_c_backprop"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "params_backprop"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "rnn_mode"
+ type: "string"
+ default_value {
+ s: "lstm"
+ }
+ allowed_values {
+ list {
+ s: "rnn_relu"
+ s: "rnn_tanh"
+ s: "lstm"
+ s: "gru"
+ }
+ }
+ }
+ attr {
+ name: "input_mode"
+ type: "string"
+ default_value {
+ s: "linear_input"
+ }
+ allowed_values {
+ list {
+ s: "linear_input"
+ s: "skip_input"
+ s: "auto_select"
+ }
+ }
+ }
+ attr {
+ name: "direction"
+ type: "string"
+ default_value {
+ s: "unidirectional"
+ }
+ allowed_values {
+ list {
+ s: "unidirectional"
+ s: "bidirectional"
+ }
+ }
+ }
+ attr {
+ name: "dropout"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "seed"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "seed2"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ is_stateful: true
+}
+op {
name: "CudnnRNNCanonicalToParams"
input_arg {
name: "num_layers"
@@ -6845,6 +7105,127 @@ op {
}
}
op {
+ name: "CudnnRNNV2"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_h"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "input_c"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "params"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_h"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output_c"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "reserve_space"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "host_reserved"
+ type: DT_INT8
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ }
+ }
+ }
+ attr {
+ name: "rnn_mode"
+ type: "string"
+ default_value {
+ s: "lstm"
+ }
+ allowed_values {
+ list {
+ s: "rnn_relu"
+ s: "rnn_tanh"
+ s: "lstm"
+ s: "gru"
+ }
+ }
+ }
+ attr {
+ name: "input_mode"
+ type: "string"
+ default_value {
+ s: "linear_input"
+ }
+ allowed_values {
+ list {
+ s: "linear_input"
+ s: "skip_input"
+ s: "auto_select"
+ }
+ }
+ }
+ attr {
+ name: "direction"
+ type: "string"
+ default_value {
+ s: "unidirectional"
+ }
+ allowed_values {
+ list {
+ s: "unidirectional"
+ s: "bidirectional"
+ }
+ }
+ }
+ attr {
+ name: "dropout"
+ type: "float"
+ default_value {
+ f: 0
+ }
+ }
+ attr {
+ name: "seed"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "seed2"
+ type: "int"
+ default_value {
+ i: 0
+ }
+ }
+ attr {
+ name: "is_training"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
+ is_stateful: true
+}
+op {
name: "Cumprod"
input_arg {
name: "x"
@@ -17319,6 +17700,31 @@ op {
}
}
op {
+ name: "PartitionedCall"
+ input_arg {
+ name: "args"
+ type_list_attr: "Tin"
+ }
+ output_arg {
+ name: "output"
+ type_list_attr: "Tout"
+ }
+ attr {
+ name: "Tin"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "Tout"
+ type: "list(type)"
+ has_minimum: true
+ }
+ attr {
+ name: "f"
+ type: "func"
+ }
+}
+op {
name: "Placeholder"
output_arg {
name: "output"
@@ -21488,6 +21894,78 @@ op {
}
}
op {
+ name: "ResourceApplyAdaMax"
+ input_arg {
+ name: "var"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "m"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "v"
+ type: DT_RESOURCE
+ }
+ input_arg {
+ name: "beta1_power"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "lr"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta1"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "beta2"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "epsilon"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "grad"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "use_locking"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ is_stateful: true
+}
+op {
name: "ResourceApplyAdadelta"
input_arg {
name: "var"
@@ -21601,6 +22079,13 @@ op {
b: false
}
}
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
is_stateful: true
}
op {
@@ -22967,6 +23452,13 @@ op {
b: false
}
}
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
is_stateful: true
}
op {
@@ -27004,6 +27496,13 @@ op {
b: false
}
}
+ attr {
+ name: "update_slots"
+ type: "bool"
+ default_value {
+ b: true
+ }
+ }
}
op {
name: "SparseApplyAdagradDA"
@@ -30484,6 +30983,17 @@ op {
}
}
op {
+ name: "StringStrip"
+ input_arg {
+ name: "input"
+ type: DT_STRING
+ }
+ output_arg {
+ name: "output"
+ type: DT_STRING
+ }
+}
+op {
name: "StringToHashBucket"
input_arg {
name: "string_tensor"
diff --git a/tensorflow/core/ops/training_ops.cc b/tensorflow/core/ops/training_ops.cc
index dc7b588898..94ff092a85 100644
--- a/tensorflow/core/ops/training_ops.cc
+++ b/tensorflow/core/ops/training_ops.cc
@@ -253,6 +253,7 @@ REGISTER_OP("ApplyAdagrad")
.Output("out: Ref(T)")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
@@ -264,6 +265,7 @@ REGISTER_OP("ResourceApplyAdagrad")
.Input("grad: T")
.Attr("T: numbertype")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, false /* sparse */);
});
@@ -320,6 +322,7 @@ REGISTER_OP("SparseApplyAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
@@ -333,6 +336,7 @@ REGISTER_OP("ResourceSparseApplyAdagrad")
.Attr("T: numbertype")
.Attr("Tindices: {int32, int64}")
.Attr("use_locking: bool = false")
+ .Attr("update_slots: bool = true")
.SetShapeFn([](InferenceContext* c) {
return ApplyAdagradShapeFn(c, true /* sparse */);
});
diff --git a/tensorflow/core/platform/byte_order.h b/tensorflow/core/platform/byte_order.h
new file mode 100644
index 0000000000..aab6535e4b
--- /dev/null
+++ b/tensorflow/core/platform/byte_order.h
@@ -0,0 +1,37 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_
+#define TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_
+
+// Byte order defines provided by gcc. MSVC doesn't define those so
+// we define them here.
+// We assume that all Windows platforms out there are little endian.
+#if defined(_MSC_VER) && !defined(__clang__)
+#define __ORDER_LITTLE_ENDIAN__ 0x4d2
+#define __ORDER_BIG_ENDIAN__ 0x10e1
+#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
+#endif
+
+namespace tensorflow {
+namespace port {
+
+// TODO(jeff,sanjay): Make portable
+constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
+
+} // namespace port
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_BYTE_ORDER_H_
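
The new byte_order.h carries just the MSVC endianness defines and port::kLittleEndian, so low-level decoding code can depend on it without dragging in all of cpu_info.h. A small sketch of the kind of code this serves, mirroring the pattern of core::DecodeFixed32 in lib/core/raw_coding.h (the helper name here is illustrative):

#include <cstring>

#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {

// Illustrative little-endian 32-bit decode: a plain memcpy suffices on
// little-endian hosts, while big-endian hosts assemble the bytes explicitly.
inline uint32 DecodeU32LittleEndian(const char* ptr) {
  if (port::kLittleEndian) {
    uint32 result;
    memcpy(&result, ptr, sizeof(result));
    return result;
  }
  return (static_cast<uint32>(static_cast<unsigned char>(ptr[0]))) |
         (static_cast<uint32>(static_cast<unsigned char>(ptr[1])) << 8) |
         (static_cast<uint32>(static_cast<unsigned char>(ptr[2])) << 16) |
         (static_cast<uint32>(static_cast<unsigned char>(ptr[3])) << 24);
}

}  // namespace tensorflow
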
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache.h b/tensorflow/core/platform/cloud/expiring_lru_cache.h
index c738497ddd..e2d048f141 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache.h
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache.h
@@ -51,6 +51,14 @@ class ExpiringLRUCache {
InsertLocked(key, value);
}
+  // Delete the entry with key `key`. Returns true if an entry was found for
+  // `key`, false otherwise. In either case, no entry with key `key` exists
+  // after the call.
+ bool Delete(const string& key) {
+ mutex_lock lock(mu_);
+ return DeleteLocked(key);
+ }
+
/// Look up the entry with key `key` and copy it to `value` if found. Returns
/// true if an entry was found for `key`, and its timestamp is not more than
/// max_age_ seconds in the past.
@@ -141,6 +149,16 @@ class ExpiringLRUCache {
}
}
+ bool DeleteLocked(const string& key) EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ auto it = cache_.find(key);
+ if (it == cache_.end()) {
+ return false;
+ }
+ lru_list_.erase(it->second.lru_iterator);
+ cache_.erase(it);
+ return true;
+ }
+
/// The maximum age of entries in the cache, in seconds. A value of 0 means
/// that no entry is ever placed in the cache.
const uint64 max_age_;
diff --git a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
index 3bc6db3842..42879e80a9 100644
--- a/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
+++ b/tensorflow/core/platform/cloud/expiring_lru_cache_test.cc
@@ -174,5 +174,22 @@ TEST(ExpiringLRUCacheTest, Clear) {
EXPECT_FALSE(cache.Lookup("d", &value));
}
+TEST(ExpiringLRUCacheTest, Delete) {
+ // Insert an entry.
+ ExpiringLRUCache<int> cache(1, 4);
+ cache.Insert("a", 1);
+ int value = 0;
+ EXPECT_TRUE(cache.Lookup("a", &value));
+ EXPECT_EQ(value, 1);
+
+ // Delete the entry.
+ EXPECT_TRUE(cache.Delete("a"));
+ EXPECT_FALSE(cache.Lookup("a", &value));
+
+ // Try deleting the entry again.
+ EXPECT_FALSE(cache.Delete("a"));
+ EXPECT_FALSE(cache.Lookup("a", &value));
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index f0003fa784..2d9c99c124 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -857,14 +857,20 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
return Status::OK();
}
+void GcsFileSystem::ClearFileCaches(const string& fname) {
+ file_block_cache_->RemoveFile(fname);
+ stat_cache_->Delete(fname);
+  // TODO(rxsang): Remove the patterns that match the file in
+ // MatchingPathsCache as well.
+}
+
Status GcsFileSystem::NewWritableFile(const string& fname,
std::unique_ptr<WritableFile>* result) {
string bucket, object;
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
- result->reset(new GcsWritableFile(
- bucket, object, this, &timeouts_,
- [this, fname]() { file_block_cache_->RemoveFile(fname); },
- initial_retry_delay_usec_));
+ result->reset(new GcsWritableFile(bucket, object, this, &timeouts_,
+ [this, fname]() { ClearFileCaches(fname); },
+ initial_retry_delay_usec_));
return Status::OK();
}
@@ -904,8 +910,7 @@ Status GcsFileSystem::NewAppendableFile(const string& fname,
TF_RETURN_IF_ERROR(ParseGcsPath(fname, false, &bucket, &object));
result->reset(new GcsWritableFile(
bucket, object, this, old_content_filename, &timeouts_,
- [this, fname]() { file_block_cache_->RemoveFile(fname); },
- initial_retry_delay_usec_));
+ [this, fname]() { ClearFileCaches(fname); }, initial_retry_delay_usec_));
return Status::OK();
}
@@ -1277,7 +1282,7 @@ Status GcsFileSystem::DeleteFile(const string& fname) {
request->SetDeleteRequest();
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when deleting ", fname);
- file_block_cache_->RemoveFile(fname);
+ ClearFileCaches(fname);
return Status::OK();
}
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.h b/tensorflow/core/platform/cloud/gcs_file_system.h
index 703c8d5778..99c94c1751 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.h
+++ b/tensorflow/core/platform/cloud/gcs_file_system.h
@@ -227,6 +227,9 @@ class GcsFileSystem : public FileSystem {
Status LoadBufferFromGCS(const string& filename, size_t offset, size_t n,
char* buffer, size_t* bytes_transferred);
+  // Clear all the caches related to the file with name `fname`.
+ void ClearFileCaches(const string& fname);
+
std::unique_ptr<AuthProvider> auth_provider_;
std::unique_ptr<HttpRequest::Factory> http_request_factory_;
std::unique_ptr<FileBlockCache> file_block_cache_;
diff --git a/tensorflow/core/platform/cloud/gcs_file_system_test.cc b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
index ca4b7722b6..c639299954 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system_test.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system_test.cc
@@ -1551,6 +1551,56 @@ TEST(GcsFileSystemTest, DeleteFile_NoObjectName) {
fs.DeleteFile("gs://bucket/").code());
}
+TEST(GcsFileSystemTest, DeleteFile_StatCacheRemoved) {
+ std::vector<HttpRequest*> requests(
+ {new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "file.txt?fields=size%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ strings::StrCat("{\"size\": \"1010\","
+ "\"updated\": \"2016-04-29T23:15:24.896Z\"}")),
+ new FakeHttpRequest("Uri: https://www.googleapis.com/storage/v1/b"
+ "/bucket/o/file.txt\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n"
+ "Delete: yes\n",
+ ""),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o/"
+ "file.txt?fields=size%2Cupdated\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ "", errors::NotFound("404"), 404),
+ new FakeHttpRequest(
+ "Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
+ "fields=items%2Fname%2CnextPageToken&prefix=file.txt%2F"
+ "&maxResults=1\n"
+ "Auth Token: fake_token\n"
+ "Timeouts: 5 1 10\n",
+ "{}")});
+ GcsFileSystem fs(
+ std::unique_ptr<AuthProvider>(new FakeAuthProvider),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ 16 /* block size */, 16 /* max bytes */, 0 /* max staleness */,
+ 3600 /* stat cache max age */, 0 /* stat cache max entries */,
+ 0 /* matching paths cache max age */,
+ 0 /* matching paths cache max entries */, 0 /* initial retry delay*/,
+ kTestTimeoutConfig, nullptr /* gcs additional header */);
+
+ // Stats the file first so the stat is cached.
+ FileStatistics stat_before_deletion;
+ TF_EXPECT_OK(fs.Stat("gs://bucket/file.txt", &stat_before_deletion));
+ EXPECT_EQ(1010, stat_before_deletion.length);
+
+ TF_EXPECT_OK(fs.DeleteFile("gs://bucket/file.txt"));
+
+ FileStatistics stat_after_deletion;
+ EXPECT_EQ(error::Code::NOT_FOUND,
+ fs.Stat("gs://bucket/file.txt", &stat_after_deletion).code());
+}
+
TEST(GcsFileSystemTest, DeleteDir_Empty) {
std::vector<HttpRequest*> requests({new FakeHttpRequest(
"Uri: https://www.googleapis.com/storage/v1/b/bucket/o?"
diff --git a/tensorflow/core/platform/cpu_feature_guard.cc b/tensorflow/core/platform/cpu_feature_guard.cc
index b570658158..9d00aa7b7f 100644
--- a/tensorflow/core/platform/cpu_feature_guard.cc
+++ b/tensorflow/core/platform/cpu_feature_guard.cc
@@ -18,6 +18,7 @@ limitations under the License.
#include <mutex>
#include <string>
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/logging.h"
diff --git a/tensorflow/core/platform/cpu_info.h b/tensorflow/core/platform/cpu_info.h
index bb77650e26..b5be7e8b54 100644
--- a/tensorflow/core/platform/cpu_info.h
+++ b/tensorflow/core/platform/cpu_info.h
@@ -18,6 +18,10 @@ limitations under the License.
#include <string>
+// TODO(ahentz): This is not strictly required here but, for historical
+// reasons, many people depend on cpu_info.h in order to use kLittleEndian.
+#include "tensorflow/core/platform/byte_order.h"
+
#if defined(_MSC_VER)
#include "tensorflow/core/platform/windows/cpu_info.h"
#endif
@@ -25,9 +29,6 @@ limitations under the License.
namespace tensorflow {
namespace port {
-// TODO(jeff,sanjay): Make portable
-constexpr bool kLittleEndian = __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__;
-
// Returns an estimate of the number of schedulable CPUs for this
// process. Usually, it's constant throughout the lifetime of a
// process, but it might change if the underlying cluster management
diff --git a/tensorflow/core/platform/default/build_config.bzl b/tensorflow/core/platform/default/build_config.bzl
index 44356e3438..ca0587e277 100644
--- a/tensorflow/core/platform/default/build_config.bzl
+++ b/tensorflow/core/platform/default/build_config.bzl
@@ -319,10 +319,34 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
use_grpc_plugin = None
if cc_grpc_version:
use_grpc_plugin = True
+
+ cc_deps = tf_deps(protodeps, "_cc")
+ cc_name = name + "_cc"
+ if not srcs:
+ # This is a collection of sub-libraries. Build header-only and impl
+ # libraries containing all the sources.
+ proto_gen(
+ name = cc_name + "_genproto",
+ deps = [s + "_genproto" for s in cc_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility=["//visibility:public"],
+ )
+ native.cc_library(
+ name = cc_name,
+ deps = cc_deps + ["@protobuf_archive//:protobuf_headers"] +
+ if_static([name + "_cc_impl"]),
+ )
+ native.cc_library(
+ name = cc_name + "_impl",
+ deps = [s + "_impl" for s in cc_deps] + ["@protobuf_archive//:cc_wkt_protos"],
+ )
+
+ return
+
cc_proto_library(
- name = name + "_cc",
+ name = cc_name,
srcs = srcs,
- deps = tf_deps(protodeps, "_cc") + ["@protobuf_archive//:cc_wkt_protos"],
+ deps = cc_deps + ["@protobuf_archive//:cc_wkt_protos"],
cc_libs = cc_libs + if_static(
["@protobuf_archive//:protobuf"],
["@protobuf_archive//:protobuf_headers"]
@@ -341,11 +365,28 @@ def tf_proto_library_cc(name, srcs = [], has_services = None,
def tf_proto_library_py(name, srcs=[], protodeps=[], deps=[], visibility=[],
testonly=0, srcs_version="PY2AND3", use_grpc_plugin=False):
+ py_deps = tf_deps(protodeps, "_py")
+ py_name = name + "_py"
+ if not srcs:
+    # This is a collection of sub-libraries. Build a py_library that
+    # aggregates all of them.
+ proto_gen(
+ name = py_name + "_genproto",
+ deps = [s + "_genproto" for s in py_deps],
+ protoc = "@protobuf_archive//:protoc",
+ visibility=["//visibility:public"],
+ )
+ native.py_library(
+ name = py_name,
+ deps = py_deps + ["@protobuf_archive//:protobuf_python"])
+
+ return
+
py_proto_library(
- name = name + "_py",
+ name = py_name,
srcs = srcs,
srcs_version = srcs_version,
- deps = deps + tf_deps(protodeps, "_py") + ["@protobuf_archive//:protobuf_python"],
+ deps = deps + py_deps + ["@protobuf_archive//:protobuf_python"],
protoc = "@protobuf_archive//:protoc",
default_runtime = "@protobuf_archive//:protobuf_python",
visibility = visibility,
diff --git a/tensorflow/core/platform/default/from_stream_executor_status.h b/tensorflow/core/platform/default/from_stream_executor_status.h
deleted file mode 100644
index 36a67a3648..0000000000
--- a/tensorflow/core/platform/default/from_stream_executor_status.h
+++ /dev/null
@@ -1,35 +0,0 @@
-/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef TENSORFLOW_PLATFORM_DEFAULT_FROM_STREAM_EXECUTOR_STATUS_H_
-#define TENSORFLOW_PLATFORM_DEFAULT_FROM_STREAM_EXECUTOR_STATUS_H_
-
-// IWYU pragma: private, include "third_party/tensorflow/core/platform/from_stream_executor_status.h"
-// IWYU pragma: friend third_party/tensorflow/core/platform/from_stream_executor_status.h
-
-#include "tensorflow/stream_executor/lib/status.h"
-
-namespace tensorflow {
-
-// On the open-source platform, stream_executor currently uses
-// tensorflow::Status
-inline Status FromStreamExecutorStatus(
- const perftools::gputools::port::Status& s) {
- return s;
-}
-
-} // namespace tensorflow
-
-#endif // TENSORFLOW_PLATFORM_DEFAULT_FROM_STREAM_EXECUTOR_STATUS_H_
diff --git a/tensorflow/core/platform/default/gpu/cupti_wrapper.cc b/tensorflow/core/platform/default/gpu/cupti_wrapper.cc
index 580db4844f..7ac5e5c445 100644
--- a/tensorflow/core/platform/default/gpu/cupti_wrapper.cc
+++ b/tensorflow/core/platform/default/gpu/cupti_wrapper.cc
@@ -28,27 +28,27 @@ namespace profiler {
namespace dynload {
-#define LIBCUPTI_WRAP(__name) \
- struct DynLoadShim__##__name { \
- static const char* kName; \
- using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
- static void* GetDsoHandle() { \
- static auto status = perftools::gputools::internal::CachedDsoLoader:: \
- GetLibcuptiDsoHandle(); \
- return status.ValueOrDie(); \
- } \
- static FuncPointerT DynLoad() { \
- static void* f; \
- TF_CHECK_OK(::tensorflow::Env::Default()->GetSymbolFromLibrary( \
- GetDsoHandle(), kName, &f)) \
- << "could not find " << kName << "in libcupti DSO"; \
- return reinterpret_cast<FuncPointerT>(f); \
- } \
- template <typename... Args> \
- CUptiResult operator()(Args... args) { \
- return DynLoad()(args...); \
- } \
- } __name; \
+#define LIBCUPTI_WRAP(__name) \
+ struct DynLoadShim__##__name { \
+ static const char* kName; \
+ using FuncPointerT = std::add_pointer<decltype(::__name)>::type; \
+ static void* GetDsoHandle() { \
+ static auto status = \
+ stream_executor::internal::CachedDsoLoader::GetLibcuptiDsoHandle(); \
+ return status.ValueOrDie(); \
+ } \
+ static FuncPointerT DynLoad() { \
+ static void* f; \
+ TF_CHECK_OK(::tensorflow::Env::Default()->GetSymbolFromLibrary( \
+ GetDsoHandle(), kName, &f)) \
+ << "could not find " << kName << "in libcupti DSO"; \
+ return reinterpret_cast<FuncPointerT>(f); \
+ } \
+ template <typename... Args> \
+ CUptiResult operator()(Args... args) { \
+ return DynLoad()(args...); \
+ } \
+ } __name; \
const char* DynLoadShim__##__name::kName = #__name;
LIBCUPTI_WRAP(cuptiActivityDisable);
diff --git a/tensorflow/core/platform/denormal.cc b/tensorflow/core/platform/denormal.cc
index 82cbc43b4f..c510dc204f 100644
--- a/tensorflow/core/platform/denormal.cc
+++ b/tensorflow/core/platform/denormal.cc
@@ -15,8 +15,9 @@ limitations under the License.
#include <tuple>
-#include "tensorflow/core/platform/denormal.h"
+#include "tensorflow/core/platform/byte_order.h"
#include "tensorflow/core/platform/cpu_info.h"
+#include "tensorflow/core/platform/denormal.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/platform.h"
// If we're on gcc 4.8 or older, there's a known bug that prevents the use of
diff --git a/tensorflow/core/platform/stream_executor.h b/tensorflow/core/platform/stream_executor.h
index 006184ddef..0a590b3d40 100644
--- a/tensorflow/core/platform/stream_executor.h
+++ b/tensorflow/core/platform/stream_executor.h
@@ -19,10 +19,8 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
-#include "tensorflow/core/platform/google/from_stream_executor_status.h"
#include "tensorflow/stream_executor/platform/google/dso_loader.h"
#else
-#include "tensorflow/core/platform/default/from_stream_executor_status.h"
#include "tensorflow/stream_executor/dso_loader.h"
#endif
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
diff --git a/tensorflow/core/platform/stream_executor_no_cuda.h b/tensorflow/core/platform/stream_executor_no_cuda.h
index 4a41d7adf5..50a5e732c0 100644
--- a/tensorflow/core/platform/stream_executor_no_cuda.h
+++ b/tensorflow/core/platform/stream_executor_no_cuda.h
@@ -19,10 +19,8 @@ limitations under the License.
#include "tensorflow/core/platform/platform.h"
#if defined(PLATFORM_GOOGLE)
-#include "tensorflow/core/platform/google/from_stream_executor_status.h"
#include "tensorflow/stream_executor/platform/google/dso_loader.h"
#else
-#include "tensorflow/core/platform/default/from_stream_executor_status.h"
#include "tensorflow/stream_executor/dso_loader.h"
#endif
#include "tensorflow/stream_executor/cuda/cuda_platform_id.h"
diff --git a/tensorflow/core/platform/types.h b/tensorflow/core/platform/types.h
index f2471712cc..68897ac423 100644
--- a/tensorflow/core/platform/types.h
+++ b/tensorflow/core/platform/types.h
@@ -63,9 +63,7 @@ typedef uint64 Fprint;
// Alias namespace ::stream_executor as ::tensorflow::se.
namespace stream_executor {}
namespace tensorflow {
-// TODO(b/77980417): Uncomment this once all namespace aliases named 'se' are
-// removed in ::xla.
-// namespace se = ::stream_executor;
+namespace se = ::stream_executor;
} // namespace tensorflow
#endif // TENSORFLOW_PLATFORM_TYPES_H_
diff --git a/tensorflow/core/platform/windows/cpu_info.h b/tensorflow/core/platform/windows/cpu_info.h
index f20939d3c0..ba2126abcf 100644
--- a/tensorflow/core/platform/windows/cpu_info.h
+++ b/tensorflow/core/platform/windows/cpu_info.h
@@ -19,13 +19,4 @@ limitations under the License.
// included so __cpuidex function is available for GETCPUID on Windows
#include <intrin.h>
-// Byte order defines provided by gcc. MSVC doesn't define those so
-// we define them here.
-// We assume that all windows platform out there are little endian.
-#if defined(_MSC_VER) && !defined(__clang__)
-#define __ORDER_LITTLE_ENDIAN__ 0x4d2
-#define __ORDER_BIG_ENDIAN__ 0x10e1
-#define __BYTE_ORDER__ __ORDER_LITTLE_ENDIAN__
-#endif
-
#endif // TENSORFLOW_PLATFORM_WINDOWS_CPU_INFO_H_
diff --git a/tensorflow/core/protobuf/eager_service.proto b/tensorflow/core/protobuf/eager_service.proto
new file mode 100644
index 0000000000..c2325cc803
--- /dev/null
+++ b/tensorflow/core/protobuf/eager_service.proto
@@ -0,0 +1,158 @@
+syntax = "proto3";
+
+package tensorflow.eager;
+
+import "tensorflow/core/framework/attr_value.proto";
+import "tensorflow/core/framework/device_attributes.proto";
+import "tensorflow/core/framework/function.proto";
+import "tensorflow/core/framework/versions.proto";
+import "tensorflow/core/protobuf/tensorflow_server.proto";
+
+message RemoteTensorHandle {
+ // The ID of the operation that produced this tensor.
+ int64 op_id = 1;
+ // The index into the outputs of the operation that produced this tensor.
+ int32 output_num = 2;
+}
+
+// A proto representation of an eager operation.
+message Operation {
+ // A unique identifier for the operation. Set by the client so that the client
+ // can uniquely identify the outputs of the scheduled operation.
+ //
+ // In the initial implementation, sending duplicate IDs has undefined
+ // behaviour, but additional constraints may be placed upon this in the
+ // future.
+ int64 id = 1;
+ string name = 2;
+ repeated RemoteTensorHandle inputs = 3;
+
+ // Control Operation IDs that will be respected when ops are re-ordered by
+ // async execution. If async execution (+ op re-ordering) is not enabled, this
+ // should have no effect.
+ repeated int64 control_op_ids = 4;
+ map<string, AttrValue> attrs = 5;
+ string device = 6;
+}
+
+message QueueItem {
+ // The remote executor should be able to handle either executing ops directly,
+ // or releasing any unused tensor handles, since the tensor lifetime is
+ // maintained by the client.
+ oneof item {
+ RemoteTensorHandle handle_to_decref = 1;
+ Operation operation = 2;
+ }
+}
+
+message CreateContextRequest {
+  // Identifies the full cluster, and this particular worker's position within it.
+ ServerDef server_def = 1;
+
+ // Whether the ops on the worker should be executed synchronously or
+ // asynchronously. By default, ops are executed synchronously.
+ bool async = 2;
+
+ // Number of seconds to keep the context alive. If more than keep_alive_secs
+ // has passed since a particular context has been communicated with, it will
+ // be garbage collected.
+ int64 keep_alive_secs = 3;
+
+ // This is the version for all the ops that will be enqueued by the client.
+ VersionDef version_def = 4;
+}
+
+message CreateContextResponse {
+  // The ID of the created context. This is usually a randomly generated number
+ // that will be used to identify the context in future requests to the
+ // service. Contexts are not persisted through server restarts.
+ fixed64 context_id = 1;
+
+ // List of devices that are locally accessible to the worker.
+ repeated DeviceAttributes device_attributes = 2;
+}
+
+message EnqueueRequest {
+ fixed64 context_id = 1;
+
+ repeated QueueItem queue = 3;
+}
+
+message EnqueueResponse {
+}
+
+message WaitQueueDoneRequest {
+ fixed64 context_id = 1;
+
+ // Ids to wait on. If empty, wait on everything currently pending.
+ repeated int64 op_id = 2;
+}
+
+message WaitQueueDoneResponse {
+ // TODO(nareshmodi): Consider adding NodeExecStats here to be able to
+ // propagate some stats.
+}
+
+message KeepAliveRequest {
+ fixed64 context_id = 1;
+}
+
+message KeepAliveResponse {
+}
+
+message CloseContextRequest {
+ fixed64 context_id = 1;
+}
+
+message CloseContextResponse {
+}
+
+message RegisterFunctionRequest {
+ fixed64 context_id = 1;
+
+ FunctionDef function_def = 2;
+}
+
+message RegisterFunctionResponse {
+}
+
+////////////////////////////////////////////////////////////////////////////////
+//
+// Eager Service defines a TensorFlow service that executes operations eagerly
+// on a set of local devices, on behalf of a remote Eager executor.
+//
+// The service impl will keep track of the various peers and devices it has
+// access to, allowing the client to enqueue ops on any devices that it can
+// access and to schedule data transfers from/to any of the peers.
+//
+////////////////////////////////////////////////////////////////////////////////
+service EagerService {
+ // This initializes the worker, informing it about the other workers in the
+ // cluster and exchanging authentication tokens which will be used in all
+ // other RPCs to detect whether the worker has restarted.
+ rpc CreateContext(CreateContextRequest) returns (CreateContextResponse);
+
+ // This takes a list of Execute and DeleteTensorHandle operations and enqueues
+ // (in async mode) or executes (in sync mode) them on the remote server.
+ // All outputs of ops which were not explicitly deleted with
+ // DeleteTensorHandle entries will be assumed to be alive and are usable by
+ // future calls to Enqueue.
+ rpc Enqueue(EnqueueRequest) returns (EnqueueResponse);
+
+ // Takes a set of op IDs and waits until those ops are done. Returns any error
+ // in the stream so far.
+ rpc WaitQueueDone(WaitQueueDoneRequest) returns (WaitQueueDoneResponse);
+
+  // Contexts are always created with a deadline; if no RPCs arrive within the
+  // deadline, the context is garbage collected. KeepAlive calls can be used to
+  // delay this.
+ rpc KeepAlive(KeepAliveRequest) returns (KeepAliveResponse);
+
+ // Closes the context. No calls to other methods using the existing context ID
+ // are valid after this.
+ rpc CloseContext(CloseContextRequest) returns (CloseContextResponse);
+
+  // Takes a FunctionDef and makes it enqueueable on the remote worker.
+ rpc RegisterFunction(RegisterFunctionRequest)
+ returns (RegisterFunctionResponse);
+}
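To make the RPC flow above concrete, here is a minimal client sketch assuming Python gRPC stubs generated from this proto with standard protoc naming. The module names, worker address, and the enqueued op are illustrative only; TensorFlow's actual remote eager client is implemented in C++ and does not use these stubs.

```python
import grpc

# Hypothetical modules from standard protoc/gRPC codegen of eager_service.proto.
import eager_service_pb2 as pb
import eager_service_pb2_grpc as pb_grpc

channel = grpc.insecure_channel("worker:7000")          # placeholder address
stub = pb_grpc.EagerServiceStub(channel)

# 1. Create a remote context (server_def/version_def elided for brevity).
ctx = stub.CreateContext(pb.CreateContextRequest(keep_alive_secs=600))

# 2. Enqueue an op; the client chooses the op id, so it can later refer to the
#    op's outputs as RemoteTensorHandle(op_id=1, output_num=0).
op = pb.Operation(id=1, name="Identity",                # inputs/attrs elided
                  device="/job:worker/task:0/device:CPU:0")
stub.Enqueue(pb.EnqueueRequest(context_id=ctx.context_id,
                               queue=[pb.QueueItem(operation=op)]))

# 3. Wait for completion, release the produced handle, and close the context.
stub.WaitQueueDone(pb.WaitQueueDoneRequest(context_id=ctx.context_id, op_id=[1]))
stub.Enqueue(pb.EnqueueRequest(
    context_id=ctx.context_id,
    queue=[pb.QueueItem(handle_to_decref=pb.RemoteTensorHandle(op_id=1))]))
stub.CloseContext(pb.CloseContextRequest(context_id=ctx.context_id))
```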
diff --git a/tensorflow/core/util/rpc/call_container.h b/tensorflow/core/util/rpc/call_container.h
index 7f36056797..e1226a7f16 100644
--- a/tensorflow/core/util/rpc/call_container.h
+++ b/tensorflow/core/util/rpc/call_container.h
@@ -26,53 +26,60 @@ limitations under the License.
namespace tensorflow {
-template <typename Call>
+namespace internal {
+// The following class is used for coordination between a `CallContainer`
+// instance and a cancellation callback to make sure that the `CallContainer`
+// instance waits for the cancellation callback to be destroyed (either because
+// a cancellation occurred or because the callback was deregistered) before
+// deleting itself. Without this coordination the cancellation callback could
+// attempt to access a `CallContainer` instance that is no longer valid.
+class NotifyWhenDestroyed {
+ public:
+ explicit NotifyWhenDestroyed(std::shared_ptr<Notification> notification)
+ : notification_(std::move(notification)) {}
+
+ ~NotifyWhenDestroyed() { notification_->Notify(); }
+
+ private:
+ std::shared_ptr<Notification> notification_;
+};
+} // namespace internal
+
+// The following class is responsible for the life cycle management of a set of
+// RPC calls. The calls are started when an instance of the class is created, and
+// the class guarantees that the "done" callback provided by the caller is
+// invoked once all RPC calls have either completed or been cancelled.
+//
+// The caller should not make any assumptions about the validity of an instance
+// of this class after the provided callback has been invoked, which may be
+// immediately after the instance was created.
+template <class Call>
class CallContainer {
public:
+ typedef std::function<void(CallContainer<Call>*, int)> CreateCallFn;
+ typedef std::function<void(Call*)> StartCallFn;
+
+ // Uses the provided `create_call_fn` and `start_call_fn` functions to create
+ // and start a set of RPC calls. When all RPC calls have either completed or
+ // been cancelled, the `done` callback is invoked. The caller should not make
+ // any assumptions about the validity of the created instance as the instance
+ // will delete itself after invoking the `done` callback.
explicit CallContainer(OpKernelContext* ctx, int num_calls, bool fail_fast,
bool try_rpc, AsyncOpKernel::DoneCallback done,
- CancellationToken token)
- : ctx_(ctx),
- done_(std::move(done)),
- token_(token),
- fail_fast_(fail_fast),
- try_rpc_(try_rpc) {
- CHECK_GT(num_calls, 0);
-
- // This will run when all RPCs are finished.
- reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
- ctx_->cancellation_manager()->DeregisterCallback(token_);
- ctx_->SetStatus(s);
- done_();
- delete this;
- });
-
- // Subtract reference count from the initial creation.
- core::ScopedUnref unref(reffed_status_callback_);
-
- for (int i = 0; i < num_calls; ++i) {
- // Increase the reference on the callback for each new RPC.
- reffed_status_callback_->Ref();
- }
- }
+ CreateCallFn create_call_fn,
+ StartCallFn start_call_fn);
- std::list<Call>* calls() { return &calls_; }
+ // Registers a call with this container. This method expects its arguments to
+ // match those of a `Call` constructor as it forwards them to an underlying
+ // collection, which creates a `Call` instance in place.
+ template <class... Args>
+ void RegisterCall(Args&&... args);
- void StartCancel() {
- // Once this loop is done, can no longer assume anything is valid
- // because "delete this" may have been immediately called.
- // Nothing should run after this loop.
- for (auto& call : calls_) {
- call.StartCancel();
- }
- }
+ // Starts the cancellation of all RPC calls managed by this container.
+ void StartCancel();
- void Done(const Status& s, int index) {
- if (!try_rpc_) {
- reffed_status_callback_->UpdateStatus(s);
- }
- reffed_status_callback_->Unref();
- }
+ // Indicates that the `index`-th RPC call has finished.
+ void Done(const Status& s, int index);
private:
OpKernelContext* ctx_;
@@ -81,10 +88,88 @@ class CallContainer {
const CancellationToken token_;
const bool fail_fast_;
const bool try_rpc_;
+ std::shared_ptr<Notification> callback_destroyed_;
// Performs its own reference counting.
ReffedStatusCallback* reffed_status_callback_;
};
+template <class Call>
+CallContainer<Call>::CallContainer(
+ OpKernelContext* ctx, int num_calls, bool fail_fast, bool try_rpc,
+ AsyncOpKernel::DoneCallback done,
+ typename CallContainer<Call>::CreateCallFn create_call_fn,
+ typename CallContainer<Call>::StartCallFn start_call_fn)
+ : ctx_(ctx),
+ done_(std::move(done)),
+ token_(ctx->cancellation_manager()->get_cancellation_token()),
+ fail_fast_(fail_fast),
+ try_rpc_(try_rpc),
+ callback_destroyed_(new Notification) {
+ CHECK_GT(num_calls, 0);
+
+ // This will run when all RPCs are finished.
+ reffed_status_callback_ = new ReffedStatusCallback([this](const Status& s) {
+ ctx_->cancellation_manager()->DeregisterCallback(token_);
+ ctx_->SetStatus(s);
+ done_();
+ callback_destroyed_->WaitForNotification();
+ delete this;
+ });
+
+ // The cancellation callback needs to be registered before the RPC calls are
+ // started to make sure that the callback is properly cleaned up by the
+ // `reffed_status_callback` when all calls complete. At the same time, the
+ // cancellation callback should wait for the RPC calls to be started for the
+ // cancellation to take effect.
+ std::shared_ptr<internal::NotifyWhenDestroyed> notify_when_destroyed(
+ new internal::NotifyWhenDestroyed(callback_destroyed_));
+ std::shared_ptr<Notification> calls_started(new Notification);
+ bool is_cancelled = !ctx_->cancellation_manager()->RegisterCallback(
+ token_, [this, calls_started, notify_when_destroyed]() {
+ calls_started->WaitForNotification();
+ StartCancel();
+ });
+
+ for (int i = 0; i < num_calls; ++i) {
+ create_call_fn(this, i);
+ // Increase the reference on the callback for each new RPC.
+ reffed_status_callback_->Ref();
+ }
+ for (Call& call : calls_) {
+ start_call_fn(&call);
+ }
+ calls_started->Notify();
+
+ if (is_cancelled) {
+ ctx_->SetStatus(errors::Cancelled("Operation has been cancelled."));
+ StartCancel();
+ }
+
+ // Subtract reference count from the initial creation.
+ reffed_status_callback_->Unref();
+}
+
+template <class Call>
+template <class... Args>
+void CallContainer<Call>::RegisterCall(Args&&... args) {
+ calls_.emplace_back(std::forward<Args>(args)...);
+}
+
+template <class Call>
+void CallContainer<Call>::StartCancel() {
+ for (auto& call : calls_) {
+ call.StartCancel();
+ }
+}
+
+template <class Call>
+void CallContainer<Call>::Done(const Status& s, int index) {
+ if (!try_rpc_) {
+ reffed_status_callback_->UpdateStatus(s);
+ }
+ reffed_status_callback_->Unref();
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_CORE_UTIL_RPC_CALL_CONTAINER_H_
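The life-cycle contract described in the new comments (start a fixed number of calls, invoke `done` exactly once after all of them have completed or been cancelled, then stop using the object) can be summarized with a small Python analogue. This is a sketch of the coordination pattern only, not of the C++ class or its cancellation-manager integration:

```python
import threading

class CallSet:
    """Invokes `done` once every registered call has reported completion."""

    def __init__(self, num_calls, done):
        self._pending = num_calls
        self._lock = threading.Lock()
        self._done = done
        self._calls = []

    def register_call(self, call):
        self._calls.append(call)

    def start_cancel(self):
        # Cancelled calls must still report through call_done() eventually.
        for call in self._calls:
            call.cancel()

    def call_done(self):
        with self._lock:
            self._pending -= 1
            finished = (self._pending == 0)
        if finished:
            self._done()  # after this point the instance must not be touched
```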
diff --git a/tensorflow/core/util/rpc/rpc_factory.h b/tensorflow/core/util/rpc/rpc_factory.h
index 9bf078c0f4..c4eaaf4457 100644
--- a/tensorflow/core/util/rpc/rpc_factory.h
+++ b/tensorflow/core/util/rpc/rpc_factory.h
@@ -32,10 +32,11 @@ class RPCFactory {
RPCFactory() {}
virtual ~RPCFactory() {}
- // Start a Call() to methods `method_t` at addresses `address_t` with
+ // Asynchronously invokes methods `method_t` at addresses `address_t` with
// request strings from `request_t`. Any of these may be scalar
// Tensors, in which case the operands are broadcasted.
- // Upon completion of all requests, `response_t` will be populated.
+ // Upon completion of all requests, `response_t` will be populated and the
+ // `done` callback will be invoked.
//
// If `try_rpc` is `true`, then `status_message_t` and
// `status_code_t` will be populated as well.
diff --git a/tensorflow/core/util/stream_executor_util.h b/tensorflow/core/util/stream_executor_util.h
index f7767ace71..4787bcf6de 100644
--- a/tensorflow/core/util/stream_executor_util.h
+++ b/tensorflow/core/util/stream_executor_util.h
@@ -30,21 +30,9 @@ class StreamExecutorUtil {
// Map a Tensor as a DeviceMemory object wrapping the given typed
// buffer.
template <typename T>
- static perftools::gputools::DeviceMemory<T> AsDeviceMemory(const Tensor& t) {
+ static se::DeviceMemory<T> AsDeviceMemory(const Tensor& t) {
T* ptr = reinterpret_cast<T*>(const_cast<char*>(t.tensor_data().data()));
- return perftools::gputools::DeviceMemory<T>(
- perftools::gputools::DeviceMemoryBase(ptr, t.TotalBytes()));
- }
-
- // Converts from a StreamExecutor Status to a TensorFlow Status.
- //
- // This assumes that the error codes between the two implementations
- // match.
- static Status ConvertStatus(const perftools::gputools::port::Status& s) {
- return s.ok() ? Status::OK()
- : Status(static_cast<tensorflow::error::Code>(
- static_cast<int>(s.code())),
- s.error_message());
+ return se::DeviceMemory<T>(se::DeviceMemoryBase(ptr, t.TotalBytes()));
}
};
diff --git a/tensorflow/core/util/use_cudnn.cc b/tensorflow/core/util/use_cudnn.cc
index d7d03f151e..c119df6419 100644
--- a/tensorflow/core/util/use_cudnn.cc
+++ b/tensorflow/core/util/use_cudnn.cc
@@ -22,9 +22,9 @@ limitations under the License.
namespace tensorflow {
-#define ADD_CUDNN_FLAG(func_name, flag_name, default_value) \
+#define ADD_BOOL_CUDNN_FLAG(func_name, flag_name, default_value) \
bool func_name() { \
- bool value; \
+ bool value = default_value; \
Status status = ReadBoolFromEnvVar(#flag_name, default_value, &value); \
if (!status.ok()) { \
LOG(ERROR) << status; \
@@ -32,12 +32,44 @@ namespace tensorflow {
return value; \
}
-ADD_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
-ADD_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
-ADD_CUDNN_FLAG(CudnnDisableConv1x1Optimization,
- TF_CUDNN_DISABLE_CONV_1X1_OPTIMIZATION, false);
+ADD_BOOL_CUDNN_FLAG(CanUseCudnn, TF_USE_CUDNN, true);
+ADD_BOOL_CUDNN_FLAG(CudnnUseAutotune, TF_CUDNN_USE_AUTOTUNE, true);
+// Whether to auto-tune the Cudnn RNN forward and backward passes to pick the
+// statistically best cudnnRNNAlgo_t and cudnnMathType_t.
+// The flag is disabled when TF_DEBUG_CUDNN_RNN is turned on.
+ADD_BOOL_CUDNN_FLAG(CudnnRnnUseAutotune, TF_CUDNN_RNN_USE_AUTOTUNE, true);
+ADD_BOOL_CUDNN_FLAG(CudnnDisableConv1x1Optimization,
+ TF_CUDNN_DISABLE_CONV_1X1_OPTIMIZATION, false);
-#undef ADD_CUDNN_FLAG
+// Whether to run Cudnn RNN forward and backward in debug mode, where users can
+// force a specified cudnnRNNAlgo_t and cudnnMathType_t, when used together with
+// the following two env vars:
+// TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS
+// TF_DEBUG_CUDNN_RNN_ALGO
+// By default it is disabled and only intended for testing and profiling.
+ADD_BOOL_CUDNN_FLAG(DebugCudnnRnn, TF_DEBUG_CUDNN_RNN, false);
+// Whether to use TENSOR_OP_MATH in Cudnn RNN for both the forward and backward
+// pass. Only effective when TF_DEBUG_CUDNN_RNN is true.
+// Note that none of the persistent RNN algorithms support TENSOR_OP_MATH before
+// Cudnn 7.1. See the Nvidia Cudnn manual for more details.
+ADD_BOOL_CUDNN_FLAG(DebugCudnnRnnUseTensorOps,
+ TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS, false);
+#undef ADD_BOOL_CUDNN_FLAG
+
+#define ADD_INT64_CUDNN_FLAG(func_name, flag_name, default_value) \
+ int64 func_name() { \
+ int64 value = default_value; \
+ Status status = ReadInt64FromEnvVar(#flag_name, default_value, &value); \
+ if (!status.ok()) { \
+ LOG(ERROR) << status; \
+ } \
+ return value; \
+ }
+// Cudnn RNN algorithm to use for both the forward and backward pass. Only
+// effective when TF_DEBUG_CUDNN_RNN is true. See the Nvidia Cudnn manual for
+// allowed cudnnRNNAlgo_t values.
+ADD_INT64_CUDNN_FLAG(DebugCudnnRnnAlgo, TF_DEBUG_CUDNN_RNN_ALGO, -1);
+#undef ADD_INT64_CUDNN_FLAG
FP16ConvMode CudnnConvComputeMode() {
string value;
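All of these flags are read from process environment variables, so they are typically set before TensorFlow initializes cuDNN. A short sketch; the variable names come from the definitions above, and the values are examples only:

```python
import os

os.environ["TF_CUDNN_RNN_USE_AUTOTUNE"] = "1"          # autotune RNN algo/math type
os.environ["TF_DEBUG_CUDNN_RNN"] = "1"                 # enable the debug mode
os.environ["TF_DEBUG_CUDNN_RNN_USE_TENSOR_OPS"] = "1"  # force TENSOR_OP_MATH
os.environ["TF_DEBUG_CUDNN_RNN_ALGO"] = "0"            # force a cudnnRNNAlgo_t

import tensorflow as tf  # import after the environment is configured
```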
diff --git a/tensorflow/core/util/use_cudnn.h b/tensorflow/core/util/use_cudnn.h
index a39a032e3f..f8cc5944d7 100644
--- a/tensorflow/core/util/use_cudnn.h
+++ b/tensorflow/core/util/use_cudnn.h
@@ -15,8 +15,10 @@ limitations under the License.
// The utility to check Cudnn dependency and set Cudnn-related flags.
-#ifndef TENSORFLOW_UTIL_USE_CUDNN_H_
-#define TENSORFLOW_UTIL_USE_CUDNN_H_
+#ifndef TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
+#define TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
+
+#include "tensorflow/core/platform/types.h"
namespace tensorflow {
@@ -31,9 +33,12 @@ enum class FP16ConvMode {
bool CanUseCudnn();
bool CudnnUseAutotune();
+bool CudnnRnnUseAutotune();
bool CudnnDisableConv1x1Optimization();
FP16ConvMode CudnnConvComputeMode();
-
+bool DebugCudnnRnn();
+bool DebugCudnnRnnUseTensorOps();
+int64 DebugCudnnRnnAlgo();
} // namespace tensorflow
-#endif // TENSORFLOW_UTIL_USE_CUDNN_H_
+#endif // TENSORFLOW_CORE_UTIL_USE_CUDNN_H_
diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/get_started/feature_columns.md
index d8e4bec863..9c777a0077 100644
--- a/tensorflow/docs_src/get_started/feature_columns.md
+++ b/tensorflow/docs_src/get_started/feature_columns.md
@@ -364,7 +364,7 @@ def make_dataset(latitude, longitude, labels):
return tf.data.Dataset.from_tensor_slices((features, labels))
-# Bucketize the latitude and longitude usig the `edges`
+# Bucketize the latitude and longitude using the `edges`
latitude_bucket_fc = tf.feature_column.bucketized_column(
tf.feature_column.numeric_column('latitude'),
list(atlanta.latitude.edges))
diff --git a/tensorflow/docs_src/install/install_java.md b/tensorflow/docs_src/install/install_java.md
index 05b2878701..6a4ac29088 100644
--- a/tensorflow/docs_src/install/install_java.md
+++ b/tensorflow/docs_src/install/install_java.md
@@ -65,7 +65,11 @@ As an example, these steps will create a Maven project that uses TensorFlow:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>tensorflow</artifactId>
+<<<<<<< HEAD
<version>1.8.0-rc1</version>
+=======
+ <version>1.8.0-rc0</version>
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
</dependency>
</dependencies>
</project>
@@ -124,12 +128,20 @@ instead:
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow</artifactId>
+<<<<<<< HEAD
<version>1.8.0-rc1</version>
+=======
+ <version>1.8.0-rc0</version>
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
</dependency>
<dependency>
<groupId>org.tensorflow</groupId>
<artifactId>libtensorflow_jni_gpu</artifactId>
+<<<<<<< HEAD
<version>1.8.0-rc1</version>
+=======
+ <version>1.8.0-rc0</version>
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
</dependency>
```
@@ -148,7 +160,11 @@ refer to the simpler instructions above instead.
Take the following steps to install TensorFlow for Java on Linux or macOS:
1. Download
+<<<<<<< HEAD
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0-rc1.jar),
+=======
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0-rc0.jar),
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
which is the TensorFlow Java Archive (JAR).
2. Decide whether you will run TensorFlow for Java on CPU(s) only or with
@@ -167,7 +183,11 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
OS=$(uname -s | tr '[:upper:]' '[:lower:]')
mkdir -p ./jni
curl -L \
+<<<<<<< HEAD
"https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.8.0-rc1.tar.gz" |
+=======
+ "https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-${TF_TYPE}-${OS}-x86_64-1.8.0-rc0.tar.gz" |
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
tar -xz -C ./jni
### Install on Windows
@@ -175,10 +195,17 @@ Take the following steps to install TensorFlow for Java on Linux or macOS:
Take the following steps to install TensorFlow for Java on Windows:
1. Download
+<<<<<<< HEAD
[libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0-rc1.jar),
which is the TensorFlow Java Archive (JAR).
2. Download the following Java Native Interface (JNI) file appropriate for
[TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0-rc1.zip).
+=======
+ [libtensorflow.jar](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow-1.8.0-rc0.jar),
+ which is the TensorFlow Java Archive (JAR).
+ 2. Download the following Java Native Interface (JNI) file appropriate for
+ [TensorFlow for Java on Windows](https://storage.googleapis.com/tensorflow/libtensorflow/libtensorflow_jni-cpu-windows-x86_64-1.8.0-rc0.zip).
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
3. Extract this .zip file.
@@ -227,7 +254,11 @@ must be part of your `classpath`. For example, you can include the
downloaded `.jar` in your `classpath` by using the `-cp` compilation flag
as follows:
+<<<<<<< HEAD
<pre><b>javac -cp libtensorflow-1.8.0-rc1.jar HelloTF.java</b></pre>
+=======
+<pre><b>javac -cp libtensorflow-1.8.0-rc0.jar HelloTF.java</b></pre>
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
### Running
@@ -241,11 +272,19 @@ two files are available to the JVM:
For example, the following command line executes the `HelloTF` program on Linux
and macOS X:
+<<<<<<< HEAD
<pre><b>java -cp libtensorflow-1.8.0-rc1.jar:. -Djava.library.path=./jni HelloTF</b></pre>
And the following command line executes the `HelloTF` program on Windows:
<pre><b>java -cp libtensorflow-1.8.0-rc1.jar;. -Djava.library.path=jni HelloTF</b></pre>
+=======
+<pre><b>java -cp libtensorflow-1.8.0-rc0.jar:. -Djava.library.path=./jni HelloTF</b></pre>
+
+And the following command line executes the `HelloTF` program on Windows:
+
+<pre><b>java -cp libtensorflow-1.8.0-rc0.jar;. -Djava.library.path=jni HelloTF</b></pre>
+>>>>>>> 43a7072882196c7ac2d9429050a3140b1ecb52db
If the program prints <tt>Hello from <i>version</i></tt>, you've successfully
installed TensorFlow for Java and are ready to use the API. If the program
diff --git a/tensorflow/docs_src/install/install_sources.md b/tensorflow/docs_src/install/install_sources.md
index a5b05491af..5c5c9e057b 100644
--- a/tensorflow/docs_src/install/install_sources.md
+++ b/tensorflow/docs_src/install/install_sources.md
@@ -393,9 +393,9 @@ If you are new to TensorFlow, see @{$get_started/premade_estimators$Getting Star
If the system outputs an error message instead of a greeting, see [Common
installation problems](#common_installation_problems).
-## Common installation problems
+## Common build and installation problems
-The installation problems you encounter typically depend on the
+The build and installation problems you encounter typically depend on the
operating system. See the "Common installation problems" section
of one of the following guides:
@@ -448,6 +448,11 @@ Stack Overflow and specify the `tensorflow` tag.
</td>
</tr>
+<tr>
+ <td><a href="https://stackoverflow.com/q/47080760">47080760</a></td>
+ <td><pre>undefined reference to `cublasGemmEx@libcublas.so.9.0'</pre></td>
+</tr>
+
</table>
## Tested source configurations
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 8373a1219d..f530fe1206 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -25,7 +25,7 @@ Calculates gradients of batch norm.
<b> `BatchNormGrad(operand, scale, mean, variance, grad_output, epsilon, feature_index)` </b>
| Arguments | Type | Semantics |
-| -------------- | ----------------------- | -------------------------------- |
+| --------------- | ----------------------- | -------------------------------- |
| `operand` | `ComputationDataHandle` | n dimensional array to be |
: : : normalized (x) :
| `scale` | `ComputationDataHandle` | 1 dimensional array |
@@ -45,31 +45,37 @@ feature dimension in `operand`), the operation calculates the gradients with
respect to `operand`, `offset` and `scale` across all the other dimensions. The
`feature_index` must be a valid index for the feature dimension in `operand`.
-The three gradients are defined by the following formulas (Assuming a
-4-dimensional tensor as `operand` and (l) is the index for feature dimension):
-
-\\( coef_l = \frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h (\nabla y_{ijkl} * (x_{ijkl} - \mu_l) / (\sigma^2_{l}+\epsilon)) \\)
-
-\\( \nabla x_{ijkl} = \gamma_{l} * (1/\sqrt{\sigma^2_{l}+\epsilon}) * [\nabla y_{ijkl} - mean(\nabla y) - (x_{ijkl} - \mu_{l}) * coef_l] \\)
-
-\\( \nabla \beta_l = \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} \\)
-
-\\( \nabla \gamma_l = \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl} * ((x_{ijkl} - \mu_l) / \sqrt{\sigma^2_{l}+\epsilon}) \\)
-
-The inputs `mean` and `variance` represents moments value
+The three gradients are defined by the following formulas (assuming a
+4-dimensional tensor as `operand` and with feature dimension index \\(l\\),
+batch size `m` and spatial sizes `w` and `h`):
+
+\\[ \begin{split} c_l&=
+\frac{1}{mwh}\sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h
+\left( \nabla y_{ijkl} \frac{x_{ijkl} - \mu_l}{\sigma^2_l+\epsilon} \right)
+\\\\
+\nabla x_{ijkl} &= \frac{\gamma_{l}}{\sqrt{\sigma^2_{l}+\epsilon}}
+\left( \nabla y_{ijkl} - \mathrm{mean}(\nabla y) - c_l (x_{ijkl} - \mu_{l})
+\right)
+\\\\
+\nabla \gamma_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \left( \nabla y_{ijkl}
+\frac{x_{ijkl} - \mu_l}{\sqrt{\sigma^2_{l}+\epsilon}} \right)
+\\\\
+\nabla \beta_l &= \sum_{i=1}^m\sum_{j=1}^w\sum_{k=1}^h \nabla y_{ijkl}
+\end{split} \\]
+
+The inputs `mean` and `variance` represent the moment values
across batch and spatial dimensions.
The output type is a tuple of three handles:
-|Outputs | Type | Semantics |
-|------------- | ----------------------- | ------------------------------------ |
-|`grad_operand`| `ComputationDataHandle` | gradient with respect to input |
-: : : `operand` (\\( \nabla x\\)) :
-|`grad_scale` | `ComputationDataHandle` | gradient with respect to input |
-: : : `scale` (\\( \nabla \gamma\\)) :
-|`grad_offset` | `ComputationDataHandle` | gradient with respect to input |
-: : : `offset`(\\( \nabla \beta\\)) :
-
+| Outputs | Type | Semantics |
+| ------------- | ----------------------- | --------------------------------- |
+| `grad_operand` | `ComputationDataHandle` | gradient with respect to input |
+: : : `operand` (\\( \nabla x\\)) :
+| `grad_scale` | `ComputationDataHandle` | gradient with respect to input |
+: : : `scale` (\\( \nabla \gamma\\)) :
+| `grad_offset` | `ComputationDataHandle` | gradient with respect to input |
+: : : `offset`(\\( \nabla \beta\\)) :
## BatchNormInference
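For reference, the BatchNormGrad formulas above transcribe directly to NumPy for a 4-D NHWC `operand` (batch `m`, spatial `w` and `h`, features on the last axis). This sketch is only a restatement of the math, not the XLA implementation:

```python
import numpy as np

def batch_norm_grad(x, dy, gamma, mu, var, eps):
    # x, dy: [m, w, h, c]; gamma, mu, var: [c]
    inv = 1.0 / np.sqrt(var + eps)
    c = np.mean(dy * (x - mu) / (var + eps), axis=(0, 1, 2))            # c_l
    dx = gamma * inv * (dy - np.mean(dy, axis=(0, 1, 2)) - c * (x - mu))
    dgamma = np.sum(dy * (x - mu) * inv, axis=(0, 1, 2))
    dbeta = np.sum(dy, axis=(0, 1, 2))
    return dx, dgamma, dbeta
```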
@@ -440,13 +446,11 @@ area and a computation is performed for each possible position of the window.
| `lhs` | `ComputationDataHandle` | rank n+2 array of inputs |
| `rhs` | `ComputationDataHandle` | rank n+2 array of kernel |
: : : weights :
-| `window_strides` | `ArraySlice<int64>` | size n array of kernel strides|
-| `padding` | `ArraySlice<pair<int64, | size n array of (low, high) |
+| `window_strides` | `ArraySlice<int64>` | n-d array of kernel strides |
+| `padding` | `ArraySlice<pair<int64, | n-d array of (low, high) |
: : int64>>` : padding :
-| `lhs_dilation` | `ArraySlice<int64>` | size n lhs dilation factor |
-: : : array |
-| `rhs_dilation` | `ArraySlice<int64>` | size n rhs dilation factor
-: : : array |
+| `lhs_dilation` | `ArraySlice<int64>` | n-d lhs dilation factor array |
+| `rhs_dilation` | `ArraySlice<int64>` | n-d rhs dilation factor array |
Let n be the number of spatial dimensions. The `lhs` argument is a rank n+2
array describing the base area. This is called the input, even though of course
diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go
index c31ca8b67a..9e87995441 100644
--- a/tensorflow/go/op/wrappers.go
+++ b/tensorflow/go/op/wrappers.go
@@ -2243,81 +2243,289 @@ func CheckNumerics(scope *Scope, tensor tf.Output, message string) (output tf.Ou
return op.Output(0)
}
-// Returns the complex conjugate of a complex number.
+// Gather slices from `params` into a Tensor with shape specified by `indices`.
//
-// Given a tensor `input` of complex numbers, this operation returns a tensor of
-// complex numbers that are the complex conjugate of each element in `input`. The
-// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
-// real part and *b* is the imaginary part.
+// `indices` is a K-dimensional integer tensor, best thought of as a
+// (K-1)-dimensional tensor of indices into `params`, where each element defines a
+// slice of `params`:
//
-// The complex conjugate returned by this operation is of the form \\(a - bj\\).
+// output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]]
//
-// For example:
+// Whereas in @{tf.gather} `indices` defines slices into the first
+// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
+// first `N` dimensions of `params`, where `N = indices.shape[-1]`.
//
+// The last dimension of `indices` can be at most the rank of
+// `params`:
+//
+// indices.shape[-1] <= params.rank
+//
+// The last dimension of `indices` corresponds to elements
+// (if `indices.shape[-1] == params.rank`) or slices
+// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]`
+// of `params`. The output tensor has shape
+//
+// indices.shape[:-1] + params.shape[indices.shape[-1]:]
+//
+// Note that on CPU, if an out of bound index is found, an error is returned.
+// On GPU, if an out of bound index is found, a 0 is stored in the
+// corresponding output value.
+//
+// Some examples below.
+//
+// Simple indexing into a matrix:
+//
+// ```python
+// indices = [[0, 0], [1, 1]]
+// params = [['a', 'b'], ['c', 'd']]
+// output = ['a', 'd']
// ```
-// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
-// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+//
+// Slice indexing into a matrix:
+//
+// ```python
+// indices = [[1], [0]]
+// params = [['a', 'b'], ['c', 'd']]
+// output = [['c', 'd'], ['a', 'b']]
// ```
-func Conj(scope *Scope, input tf.Output) (output tf.Output) {
+//
+// Indexing into a 3-tensor:
+//
+// ```python
+// indices = [[1]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = [[['a1', 'b1'], ['c1', 'd1']]]
+//
+//
+// indices = [[0, 1], [1, 0]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = [['c0', 'd0'], ['a1', 'b1']]
+//
+//
+// indices = [[0, 0, 1], [1, 0, 1]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = ['b0', 'b1']
+// ```
+//
+// Batched indexing into a matrix:
+//
+// ```python
+// indices = [[[0, 0]], [[0, 1]]]
+// params = [['a', 'b'], ['c', 'd']]
+// output = [['a'], ['b']]
+// ```
+//
+// Batched slice indexing into a matrix:
+//
+// ```python
+// indices = [[[1]], [[0]]]
+// params = [['a', 'b'], ['c', 'd']]
+// output = [[['c', 'd']], [['a', 'b']]]
+// ```
+//
+// Batched indexing into a 3-tensor:
+//
+// ```python
+// indices = [[[1]], [[0]]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = [[[['a1', 'b1'], ['c1', 'd1']]],
+// [[['a0', 'b0'], ['c0', 'd0']]]]
+//
+// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = [[['c0', 'd0'], ['a1', 'b1']],
+// [['a0', 'b0'], ['c1', 'd1']]]
+//
+//
+// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
+// params = [[['a0', 'b0'], ['c0', 'd0']],
+// [['a1', 'b1'], ['c1', 'd1']]]
+// output = [['b0', 'b1'], ['d0', 'c1']]
+// ```
+//
+// Arguments:
+// params: The tensor from which to gather values.
+// indices: Index tensor.
+//
+// Returns Values from `params` gathered from indices given by `indices`, with
+// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`.
+func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Conj",
+ Type: "GatherNd",
Input: []tf.Input{
- input,
+ params, indices,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum.
-type ResourceSparseApplyMomentumAttr func(optionalAttr)
+// GatherAttr is an optional argument to Gather.
+type GatherAttr func(optionalAttr)
-// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr {
+// GatherValidateIndices sets the optional validate_indices attribute to value.
+// If not specified, defaults to true
+func GatherValidateIndices(value bool) GatherAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["validate_indices"] = value
}
}
-// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
+// Gather slices from `params` according to `indices`.
//
-// value: If `True`, the tensor passed to compute grad will be
-// var - lr * momentum * accum, so in the end, the var you get is actually
-// var - lr * momentum * accum.
-// If not specified, defaults to false
-func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr {
- return func(m optionalAttr) {
- m["use_nesterov"] = value
+// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
+// Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
+//
+// ```python
+// # Scalar indices
+// output[:, ..., :] = params[indices, :, ... :]
+//
+// # Vector indices
+// output[i, :, ..., :] = params[indices[i], :, ... :]
+//
+// # Higher rank indices
+// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
+// ```
+//
+// If `indices` is a permutation and `len(indices) == params.shape[0]` then
+// this operation will permute `params` accordingly.
+//
+// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in
+// `indices` are always validated to be within range. If assigned to GPU,
+// out-of-bound indices result in safe but unspecified behavior, which may include
+// raising an error.
+//
+// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
+// <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt>
+// </div>
+func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Gather",
+ Input: []tf.Input{
+ params, indices,
+ },
+ Attrs: attrs,
}
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+// Creates a tensor filled with a scalar value.
//
-// Set use_nesterov = True if you want to use Nesterov momentum.
+// This operation creates a tensor of shape `dims` and fills it with `value`.
//
-// That is for rows we have grad for, we update var and accum as follows:
+// For example:
//
-// accum = accum * momentum + grad
-// var -= lr * accum
+// ```
+// # Output tensor has shape [2, 3].
+// fill([2, 3], 9) ==> [[9, 9, 9]
+// [9, 9, 9]]
+// ```
//
// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// lr: Learning rate. Must be a scalar.
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var and accum.
-// momentum: Momentum. Must be a scalar.
+// dims: 1-D. Represents the shape of the output tensor.
+// value: 0-D (scalar). Value to fill the returned tensor.
//
-// Returns the created operation.
-func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) {
+// @compatibility(numpy)
+// Equivalent to np.full
+// @end_compatibility
+func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Fill",
+ Input: []tf.Input{
+ dims, value,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
+// EditDistanceAttr is an optional argument to EditDistance.
+type EditDistanceAttr func(optionalAttr)
+
+// EditDistanceNormalize sets the optional normalize attribute to value.
+//
+// value: boolean (if true, edit distances are normalized by length of truth).
+//
+// If not specified, defaults to true
+func EditDistanceNormalize(value bool) EditDistanceAttr {
+ return func(m optionalAttr) {
+ m["normalize"] = value
+ }
+}
+
+// Computes the (possibly normalized) Levenshtein Edit Distance.
+//
+// The inputs are variable-length sequences provided by SparseTensors
+// (hypothesis_indices, hypothesis_values, hypothesis_shape)
+// and
+// (truth_indices, truth_values, truth_shape).
+//
+// The inputs are:
+//
+// Arguments:
+// hypothesis_indices: The indices of the hypothesis list SparseTensor.
+// This is an N x R int64 matrix.
+// hypothesis_values: The values of the hypothesis list SparseTensor.
+// This is an N-length vector.
+// hypothesis_shape: The shape of the hypothesis list SparseTensor.
+// This is an R-length vector.
+// truth_indices: The indices of the truth list SparseTensor.
+// This is an M x R int64 matrix.
+// truth_values: The values of the truth list SparseTensor.
+// This is an M-length vector.
+// truth_shape: The shape of the truth list SparseTensor. This is an R-length vector.
+//
+// Returns A dense float tensor with rank R - 1.
+//
+// For the example input:
+//
+// // hypothesis represents a 2x1 matrix with variable-length values:
+// // (0,0) = ["a"]
+// // (1,0) = ["b"]
+// hypothesis_indices = [[0, 0, 0],
+// [1, 0, 0]]
+// hypothesis_values = ["a", "b"]
+// hypothesis_shape = [2, 1, 1]
+//
+// // truth represents a 2x2 matrix with variable-length values:
+// // (0,0) = []
+// // (0,1) = ["a"]
+// // (1,0) = ["b", "c"]
+// // (1,1) = ["a"]
+// truth_indices = [[0, 1, 0],
+// [1, 0, 0],
+// [1, 0, 1],
+// [1, 1, 0]]
+// truth_values = ["a", "b", "c", "a"]
+// truth_shape = [2, 2, 2]
+// normalize = true
+//
+// The output will be:
+//
+// // output is a 2x2 matrix with edit distances normalized by truth lengths.
+// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
+// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
+func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
if scope.Err() != nil {
return
}
@@ -2326,13 +2534,14 @@ func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output,
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceSparseApplyMomentum",
+ Type: "EditDistance",
Input: []tf.Input{
- var_, accum, lr, grad, indices, momentum,
+ hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
// Clips tensor values to a specified min and max.
@@ -4548,62 +4757,6 @@ func QuantizedBatchNormWithGlobalNormalization(scope *Scope, t tf.Output, t_min
return op.Output(0), op.Output(1), op.Output(2)
}
-// HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth.
-type HistogramFixedWidthAttr func(optionalAttr)
-
-// HistogramFixedWidthDtype sets the optional dtype attribute to value.
-// If not specified, defaults to DT_INT32
-func HistogramFixedWidthDtype(value tf.DataType) HistogramFixedWidthAttr {
- return func(m optionalAttr) {
- m["dtype"] = value
- }
-}
-
-// Return histogram of values.
-//
-// Given the tensor `values`, this operation returns a rank 1 histogram counting
-// the number of entries in `values` that fall into every bin. The bins are
-// equal width and determined by the arguments `value_range` and `nbins`.
-//
-// ```python
-// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
-// nbins = 5
-// value_range = [0.0, 5.0]
-// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
-//
-// with tf.get_default_session() as sess:
-// hist = tf.histogram_fixed_width(new_values, value_range, nbins=5)
-// variables.global_variables_initializer().run()
-// sess.run(hist) => [2, 1, 1, 0, 2]
-// ```
-//
-// Arguments:
-// values: Numeric `Tensor`.
-// value_range: Shape [2] `Tensor` of same `dtype` as `values`.
-// values <= value_range[0] will be mapped to hist[0],
-// values >= value_range[1] will be mapped to hist[-1].
-// nbins: Scalar `int32 Tensor`. Number of histogram bins.
-//
-// Returns A 1-D `Tensor` holding histogram of values.
-func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, nbins tf.Output, optional ...HistogramFixedWidthAttr) (out tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "HistogramFixedWidth",
- Input: []tf.Input{
- values, value_range, nbins,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Adds Tensor 'bias' to Tensor 'input' for Quantized types.
//
// Broadcasts the values of bias on dimensions 0..N-2 of 'input'.
@@ -7020,38 +7173,107 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke
return sparse_indices, sparse_values, sparse_shapes, dense_values
}
-// Real-valued fast Fourier transform.
+// DecodeRawAttr is an optional argument to DecodeRaw.
+type DecodeRawAttr func(optionalAttr)
+
+// DecodeRawLittleEndian sets the optional little_endian attribute to value.
//
-// Computes the 1-dimensional discrete Fourier transform of a real-valued signal
-// over the inner-most dimension of `input`.
+// value: Whether the input `bytes` are in little-endian order.
+// Ignored for `out_type` values that are stored in a single byte like
+// `uint8`.
+// If not specified, defaults to true
+func DecodeRawLittleEndian(value bool) DecodeRawAttr {
+ return func(m optionalAttr) {
+ m["little_endian"] = value
+ }
+}
+
+// Reinterpret the bytes of a string as a vector of numbers.
//
-// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the
-// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
-// followed by the `fft_length / 2` positive-frequency terms.
+// Arguments:
+// bytes: All the elements must have the same length.
//
-// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
-// corresponding dimension of `input`, the dimension is cropped. If it is larger,
-// the dimension is padded with zeros.
//
-// Arguments:
-// input: A float32 tensor.
-// fft_length: An int32 tensor of shape [1]. The FFT length.
+// Returns A Tensor with one more dimension than the input `bytes`. The
+// added dimension will have size equal to the length of the elements
+// of `bytes` divided by the number of bytes to represent `out_type`.
+func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"out_type": out_type}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeRaw",
+ Input: []tf.Input{
+ bytes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
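The Go wrapper above mirrors the `DecodeRaw` op; for reference, the same op from the Python side (TF 1.x API) reinterprets each input string as a fixed-width numeric vector. A small sketch:

```python
import tensorflow as tf

raw = tf.constant([b"\x01\x00\x02\x00"])        # two little-endian uint16 values
values = tf.decode_raw(raw, tf.uint16, little_endian=True)
with tf.Session() as sess:
    print(sess.run(values))                     # [[1 2]]
```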
+// Copy a tensor setting everything outside a central band in each innermost matrix
//
-// Returns A complex64 tensor of the same rank as `input`. The inner-most
-// dimension of `input` is replaced with the `fft_length / 2 + 1` unique
-// frequency components of its 1D Fourier transform.
+// to zero.
//
-// @compatibility(numpy)
-// Equivalent to np.fft.rfft
-// @end_compatibility
-func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
+// The `band` part is computed as follows:
+// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
+// tensor with the same shape where
+//
+// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
+//
+// The indicator function
+//
+// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower)) &&
+// (num_upper < 0 || (n-m) <= num_upper)`.
+//
+// For example:
+//
+// ```
+// # if 'input' is [[ 0, 1, 2, 3]
+// [-1, 0, 1, 2]
+// [-2, -1, 0, 1]
+// [-3, -2, -1, 0]],
+//
+// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
+// [-1, 0, 1, 2]
+// [ 0, -1, 0, 1]
+// [ 0, 0, -1, 0]],
+//
+// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
+// [-1, 0, 1, 0]
+// [-2, -1, 0, 1]
+// [ 0, -2, -1, 0]]
+// ```
+//
+// Useful special cases:
+//
+// ```
+// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
+// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
+// tf.matrix_band_part(input, 0, 0) ==> Diagonal.
+// ```
+//
+// Arguments:
+// input: Rank `k` tensor.
+// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire
+// lower triangle.
+// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep
+// entire upper triangle.
+//
+// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor.
+func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "RFFT",
+ Type: "MatrixBandPart",
Input: []tf.Input{
- input, fft_length,
+ input, num_lower, num_upper,
},
}
op := scope.AddOperation(opspec)
@@ -8207,63 +8429,6 @@ func QuantizedReshape(scope *Scope, tensor tf.Output, shape tf.Output, input_min
return op.Output(0), op.Output(1), op.Output(2)
}
-// GatherAttr is an optional argument to Gather.
-type GatherAttr func(optionalAttr)
-
-// GatherValidateIndices sets the optional validate_indices attribute to value.
-// If not specified, defaults to true
-func GatherValidateIndices(value bool) GatherAttr {
- return func(m optionalAttr) {
- m["validate_indices"] = value
- }
-}
-
-// Gather slices from `params` according to `indices`.
-//
-// `indices` must be an integer tensor of any dimension (usually 0-D or 1-D).
-// Produces an output tensor with shape `indices.shape + params.shape[1:]` where:
-//
-// ```python
-// # Scalar indices
-// output[:, ..., :] = params[indices, :, ... :]
-//
-// # Vector indices
-// output[i, :, ..., :] = params[indices[i], :, ... :]
-//
-// # Higher rank indices
-// output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]
-// ```
-//
-// If `indices` is a permutation and `len(indices) == params.shape[0]` then
-// this operation will permute `params` accordingly.
-//
-// `validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in
-// `indices` are always validated to be within range. If assigned to GPU,
-// out-of-bound indices result in safe but unspecified behavior, which may include
-// raising an error.
-//
-// <div style="width:70%; margin:auto; margin-bottom:10px; margin-top:20px;">
-// <img style="width:100%" src="https://www.tensorflow.org/images/Gather.png" alt>
-// </div>
-func Gather(scope *Scope, params tf.Output, indices tf.Output, optional ...GatherAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Gather",
- Input: []tf.Input{
- params, indices,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// Returns the truth value of (x != y) element-wise.
//
// *NOTE*: `NotEqual` supports broadcasting. More about broadcasting
@@ -8386,6 +8551,98 @@ func StringSplit(scope *Scope, input tf.Output, delimiter tf.Output, optional ..
return op.Output(0), op.Output(1), op.Output(2)
}
+// ResourceSparseApplyMomentumAttr is an optional argument to ResourceSparseApplyMomentum.
+type ResourceSparseApplyMomentumAttr func(optionalAttr)
+
+// ResourceSparseApplyMomentumUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyMomentumUseLocking(value bool) ResourceSparseApplyMomentumAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// ResourceSparseApplyMomentumUseNesterov sets the optional use_nesterov attribute to value.
+//
+// value: If `True`, the tensor passed to compute grad will be
+// var - lr * momentum * accum, so in the end, the var you get is actually
+// var - lr * momentum * accum.
+// If not specified, defaults to false
+func ResourceSparseApplyMomentumUseNesterov(value bool) ResourceSparseApplyMomentumAttr {
+ return func(m optionalAttr) {
+ m["use_nesterov"] = value
+ }
+}
+
+// Update relevant entries in '*var' and '*accum' according to the momentum scheme.
+//
+// Set use_nesterov = True if you want to use Nesterov momentum.
+//
+// That is for rows we have grad for, we update var and accum as follows:
+//
+// accum = accum * momentum + grad
+// var -= lr * accum
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// lr: Learning rate. Must be a scalar.
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var and accum.
+// momentum: Momentum. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceSparseApplyMomentum(scope *Scope, var_ tf.Output, accum tf.Output, lr tf.Output, grad tf.Output, indices tf.Output, momentum tf.Output, optional ...ResourceSparseApplyMomentumAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyMomentum",
+ Input: []tf.Input{
+ var_, accum, lr, grad, indices, momentum,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
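The update scheme documented above can be spelled out in a few lines of NumPy. Names are illustrative, and the Nesterov branch follows the dense ApplyMomentum convention, which is an assumption here:

```python
import numpy as np

def sparse_apply_momentum(var, accum, lr, grad, indices, momentum,
                          use_nesterov=False):
    # Only the rows listed in `indices` are updated.
    for g, i in zip(grad, indices):
        accum[i] = accum[i] * momentum + g
        if use_nesterov:
            var[i] -= lr * (g + momentum * accum[i])  # assumption: dense-op convention
        else:
            var[i] -= lr * accum[i]
    return var, accum
```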
+// Returns the complex conjugate of a complex number.
+//
+// Given a tensor `input` of complex numbers, this operation returns a tensor of
+// complex numbers that are the complex conjugate of each element in `input`. The
+// complex numbers in `input` must be of the form \\(a + bj\\), where *a* is the
+// real part and *b* is the imaginary part.
+//
+// The complex conjugate returned by this operation is of the form \\(a - bj\\).
+//
+// For example:
+//
+// ```
+// # tensor 'input' is [-2.25 + 4.75j, 3.25 + 5.75j]
+// tf.conj(input) ==> [-2.25 - 4.75j, 3.25 - 5.75j]
+// ```
+func Conj(scope *Scope, input tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Conj",
+ Input: []tf.Input{
+ input,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
// ResizeBilinearAttr is an optional argument to ResizeBilinear.
type ResizeBilinearAttr func(optionalAttr)
@@ -9464,6 +9721,14 @@ func ResourceApplyAdagradUseLocking(value bool) ResourceApplyAdagradAttr {
}
}
+// ResourceApplyAdagradUpdateSlots sets the optional update_slots attribute to value.
+// If not specified, defaults to true
+func ResourceApplyAdagradUpdateSlots(value bool) ResourceApplyAdagradAttr {
+ return func(m optionalAttr) {
+ m["update_slots"] = value
+ }
+}
+
// Update '*var' according to the adagrad scheme.
//
// accum += grad * grad
@@ -9799,6 +10064,305 @@ func BatchDataset(scope *Scope, input_dataset tf.Output, batch_size tf.Output, o
return op.Output(0)
}
+// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
+type DecodeAndCropJpegAttr func(optionalAttr)
+
+// DecodeAndCropJpegChannels sets the optional channels attribute to value.
+//
+// value: Number of color channels for the decoded image.
+// If not specified, defaults to 0
+func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["channels"] = value
+ }
+}
+
+// DecodeAndCropJpegRatio sets the optional ratio attribute to value.
+//
+// value: Downscaling ratio.
+// If not specified, defaults to 1
+func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["ratio"] = value
+ }
+}
+
+// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
+//
+// value: If true use a slower but nicer upscaling of the
+// chroma planes (yuv420/422 only).
+// If not specified, defaults to true
+func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["fancy_upscaling"] = value
+ }
+}
+
+// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value.
+//
+// value: If true try to recover an image from truncated input.
+// If not specified, defaults to false
+func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["try_recover_truncated"] = value
+ }
+}
+
+// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
+//
+// value: The minimum required fraction of lines before a truncated
+// input is accepted.
+// If not specified, defaults to 1
+func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["acceptable_fraction"] = value
+ }
+}
+
+// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value.
+//
+// value: string specifying a hint about the algorithm used for
+// decompression. Defaults to "" which maps to a system-specific
+// default. Currently valid values are ["INTEGER_FAST",
+// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
+// jpeg library changes to a version that does not have that specific
+// option.)
+// If not specified, defaults to ""
+func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr {
+ return func(m optionalAttr) {
+ m["dct_method"] = value
+ }
+}
+
+// Decode and Crop a JPEG-encoded image to a uint8 tensor.
+//
+// The attr `channels` indicates the desired number of color channels for the
+// decoded image.
+//
+// Accepted values are:
+//
+// * 0: Use the number of channels in the JPEG-encoded image.
+// * 1: output a grayscale image.
+// * 3: output an RGB image.
+//
+// If needed, the JPEG-encoded image is transformed to match the requested number
+// of color channels.
+//
+// The attr `ratio` allows downscaling the image by an integer factor during
+// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
+// downscaling the image later.
+//
+//
+// It is equivalent to a combination of decode and crop, but much faster because
+// it decodes only the part of the JPEG image that is needed.
+//
+// Arguments:
+// contents: 0-D. The JPEG-encoded image.
+// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width].
+//
+// Returns 3-D with shape `[height, width, channels]`.
+func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "DecodeAndCropJpeg",
+ Input: []tf.Input{
+ contents, crop_window,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
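For reference, a minimal end-to-end sketch of calling this wrapper from the Go client, assuming the standard tensorflow/go (tf) and tensorflow/go/op packages; the file path, crop window, and attribute values are illustrative only.

package main

import (
	"fmt"
	"io/ioutil"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	contents := op.Placeholder(s, tf.String)
	// crop_window is [crop_y, crop_x, crop_height, crop_width].
	window := op.Const(s, []int32{0, 0, 224, 224})
	image := op.DecodeAndCropJpeg(s, contents, window, op.DecodeAndCropJpegChannels(3))

	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	raw, err := ioutil.ReadFile("input.jpg") // illustrative path
	if err != nil {
		log.Fatal(err)
	}
	feed, err := tf.NewTensor(string(raw))
	if err != nil {
		log.Fatal(err)
	}
	out, err := sess.Run(map[tf.Output]*tf.Tensor{contents: feed}, []tf.Output{image}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out[0].Shape()) // [224 224 3] for this crop window and channels=3
}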
+// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
+type AllCandidateSamplerAttr func(optionalAttr)
+
+// AllCandidateSamplerSeed sets the optional seed attribute to value.
+//
+// value: If either seed or seed2 are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Generates labels for candidate sampling with a learned unigram distribution.
+//
+// See explanations of candidate sampling and the data formats at
+// go/candidate-sampling.
+//
+// For each batch, this op picks a single set of sampled candidate labels.
+//
+// The advantages of sampling candidates per-batch are simplicity and the
+// possibility of efficient dense matrix multiplication. The disadvantage is that
+// the sampled candidates must be chosen independently of the context and of the
+// true labels.
+//
+// Arguments:
+// true_classes: A batch_size * num_true matrix, in which each row contains the
+// IDs of the num_true target_classes in the corresponding original label.
+// num_true: Number of true labels per context.
+// num_sampled: Number of candidates to produce.
+// unique: If unique is true, we sample with rejection, so that all sampled
+// candidates in a batch are unique. This requires some approximation to
+// estimate the post-rejection sampling probabilities.
+//
+// Returns A vector of length num_sampled, in which each element is
+// the ID of a sampled candidate. A batch_size * num_true matrix, representing
+// the number of times each candidate is expected to occur in a batch
+// of sampled candidates. If unique=true, then this is a probability. A vector
+// of length num_sampled, for each sampled candidate representing the number of
+// times the candidate is expected to occur in a batch of sampled candidates.
+// If unique=true, then this is a probability.
+func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "AllCandidateSampler",
+ Input: []tf.Input{
+ true_classes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
+// Adds two `SparseTensor` objects to produce another `SparseTensor`.
+//
+// The input `SparseTensor` objects' indices are assumed ordered in standard
+// lexicographic order. If this is not the case, before this step run
+// `SparseReorder` to restore index ordering.
+//
+// By default, if two values sum to zero at some index, the output `SparseTensor`
+// would still include that particular location in its index, storing a zero in the
+// corresponding value slot. To override this, callers can specify `thresh`,
+// indicating that if the sum has a magnitude strictly smaller than `thresh`, its
+// corresponding value and index would then not be included. In particular,
+// `thresh == 0` (default) means everything is kept and actual thresholding happens
+// only for a positive value.
+//
+// In the following shapes, `nnz` is the count after taking `thresh` into account.
+//
+// Arguments:
+// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix.
+// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector.
+// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector.
+// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix.
+// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector.
+// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector.
+// thresh: 0-D. The magnitude threshold that determines if an output value/index
+// pair takes space.
+func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "SparseAdd",
+ Input: []tf.Input{
+ a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
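As a rough sketch of how this wrapper composes with Const (same imports and session boilerplate as the DecodeAndCropJpeg example above; the concrete indices and values are illustrative):

// buildSparseAdd adds two 2x3 SparseTensors, keeping every summed entry (thresh = 0).
func buildSparseAdd(s *op.Scope) (sumIndices, sumValues, sumShape tf.Output) {
	aIndices := op.Const(s, [][]int64{{0, 0}, {1, 2}})
	aValues := op.Const(s, []float32{1, -2})
	aShape := op.Const(s, []int64{2, 3})
	bIndices := op.Const(s, [][]int64{{0, 0}, {1, 2}})
	bValues := op.Const(s, []float32{3, 2})
	bShape := op.Const(s, []int64{2, 3})
	thresh := op.Const(s, float32(0))
	return op.SparseAdd(s, aIndices, aValues, aShape, bIndices, bValues, bShape, thresh)
}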
+// OrderedMapPeekAttr is an optional argument to OrderedMapPeek.
+type OrderedMapPeekAttr func(optionalAttr)
+
+// OrderedMapPeekCapacity sets the optional capacity attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr {
+ return func(m optionalAttr) {
+ m["capacity"] = value
+ }
+}
+
+// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value.
+// If not specified, defaults to 0
+//
+// REQUIRES: value >= 0
+func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr {
+ return func(m optionalAttr) {
+ m["memory_limit"] = value
+ }
+}
+
+// OrderedMapPeekContainer sets the optional container attribute to value.
+// If not specified, defaults to ""
+func OrderedMapPeekContainer(value string) OrderedMapPeekAttr {
+ return func(m optionalAttr) {
+ m["container"] = value
+ }
+}
+
+// OrderedMapPeekSharedName sets the optional shared_name attribute to value.
+// If not specified, defaults to ""
+func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr {
+ return func(m optionalAttr) {
+ m["shared_name"] = value
+ }
+}
+
+// Op peeks at the values at the specified key.
+//
+// If the underlying container does not contain this key, this op will block
+// until it does. This Op is optimized for performance.
+func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtypes": dtypes}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "OrderedMapPeek",
+ Input: []tf.Input{
+ key, indices,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ if scope.Err() != nil {
+ return
+ }
+ var idx int
+ var err error
+ if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
+ scope.UpdateErr("OrderedMapPeek", err)
+ return
+ }
+ return values
+}
+
// Inverse fast Fourier transform.
//
// Computes the inverse 1-dimensional discrete Fourier transform over the
@@ -9900,6 +10464,235 @@ func DestroyResourceOp(scope *Scope, resource tf.Output, optional ...DestroyReso
return scope.AddOperation(opspec)
}
+// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
+type ResourceSparseApplyRMSPropAttr func(optionalAttr)
+
+// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var, ms, and mom tensors is protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the RMSProp algorithm.
+//
+// Note that in dense implementation of this algorithm, ms and mom will
+// update even if the grad is zero, but in this sparse implementation, ms
+// and mom will not update in iterations during which the grad is zero.
+//
+// mean_square = decay * mean_square + (1-decay) * gradient ** 2
+// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+//
+// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
+// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
+// var <- var - mom
+//
+// Arguments:
+// var_: Should be from a Variable().
+// ms: Should be from a Variable().
+// mom: Should be from a Variable().
+// lr: Scaling factor. Must be a scalar.
+// rho: Decay rate. Must be a scalar.
+//
+// epsilon: Ridge term. Must be a scalar.
+// grad: The gradient.
+// indices: A vector of indices into the first dimension of var, ms and mom.
+//
+// Returns the created operation.
+func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceSparseApplyRMSProp",
+ Input: []tf.Input{
+ var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// Returns the truth value of (x > y) element-wise.
+//
+// *NOTE*: `Greater` supports broadcasting. More about broadcasting
+// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
+func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "Greater",
+ Input: []tf.Input{
+ x, y,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
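A tiny sketch of the broadcasting behaviour, under the same assumptions as the earlier examples:

// buildGreater compares a vector against a broadcast scalar; when run, the
// result is a bool tensor with the value [false false true].
func buildGreater(s *op.Scope) tf.Output {
	x := op.Const(s, []float32{1, 2, 3})
	y := op.Const(s, float32(2))
	return op.Greater(s, x, y)
}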
+// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
+type SampleDistortedBoundingBoxAttr func(optionalAttr)
+
+// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to non-zero, the random number
+// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
+// seed.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
+//
+// value: The cropped area of the image must contain at least this
+// fraction of any bounding box supplied. The value of this parameter should be
+// non-negative. In the case of 0, the cropped area does not need to overlap
+// any of the bounding boxes supplied.
+// If not specified, defaults to 0.1
+func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["min_object_covered"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
+//
+// value: The cropped area of the image must have an aspect ratio =
+// width / height within this range.
+// If not specified, defaults to {0.75, 1.33}
+func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["aspect_ratio_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
+//
+// value: The cropped area of the image must contain a fraction of the
+// supplied image within this range.
+// If not specified, defaults to {0.05, 1}
+func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["area_range"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
+//
+// value: Number of attempts at generating a cropped region of the image
+// of the specified constraints. After `max_attempts` failures, return the entire
+// image.
+// If not specified, defaults to 100
+func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["max_attempts"] = value
+ }
+}
+
+// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
+//
+// value: Controls behavior if no bounding boxes supplied.
+// If true, assume an implicit bounding box covering the whole input. If false,
+// raise an error.
+// If not specified, defaults to false
+func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
+ return func(m optionalAttr) {
+ m["use_image_if_no_bounding_boxes"] = value
+ }
+}
+
+// Generate a single randomly distorted bounding box for an image.
+//
+// Bounding box annotations are often supplied in addition to ground-truth labels
+// in image recognition or object localization tasks. A common technique for
+// training such a system is to randomly distort an image while preserving
+// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
+// localization of an object, i.e. bounding box, given an `image_size`,
+// `bounding_boxes` and a series of constraints.
+//
+// The output of this Op is a single bounding box that may be used to crop the
+// original image. The output is returned as 3 tensors: `begin`, `size` and
+// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
+// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
+// what the bounding box looks like.
+//
+// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
+// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
+// height of the underlying image.
+//
+// For example,
+//
+// ```python
+// # Generate a single distorted bounding box.
+// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
+// tf.shape(image),
+// bounding_boxes=bounding_boxes)
+//
+// # Draw the bounding box in an image summary.
+// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
+// bbox_for_draw)
+// tf.summary.image('images_with_box', image_with_box)
+//
+// # Employ the bounding box to distort the image.
+// distorted_image = tf.slice(image, begin, size)
+// ```
+//
+// Note that if no bounding box information is available, setting
+// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
+// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
+// false and no bounding boxes are supplied, an error is raised.
+//
+// Arguments:
+// image_size: 1-D, containing `[height, width, channels]`.
+// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
+// associated with the image.
+//
+// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
+// `tf.slice`. 1-D, containing `[target_height, target_width, -1]`. Provide as input to
+// `tf.slice`. 3-D with shape `[1, 1, 4]` containing the distorted bounding box.
+// Provide as input to `tf.image.draw_bounding_boxes`.
+func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "SampleDistortedBoundingBox",
+ Input: []tf.Input{
+ image_size, bounding_boxes,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2)
+}
+
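A Go counterpart of the Python snippet in the comment above, assuming the same packages as the earlier examples; image is any 3-D image tensor and the box coordinates are illustrative:

// buildDistortedCrop samples a distorted box and slices it out of the image.
func buildDistortedCrop(s *op.Scope, image tf.Output) tf.Output {
	boxes := op.Const(s, [][][]float32{{{0.1, 0.1, 0.9, 0.9}}})
	begin, size, _ := op.SampleDistortedBoundingBox(s, op.Shape(s, image), boxes,
		op.SampleDistortedBoundingBoxMinObjectCovered(0.5))
	// begin and size feed straight into Slice, mirroring tf.slice in the Python example.
	return op.Slice(s, image, begin, size)
}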
// LRNAttr is an optional argument to LRN.
type LRNAttr func(optionalAttr)
@@ -10010,6 +10803,14 @@ func ResourceSparseApplyAdagradUseLocking(value bool) ResourceSparseApplyAdagrad
}
}
+// ResourceSparseApplyAdagradUpdateSlots sets the optional update_slots attribute to value.
+// If not specified, defaults to true
+func ResourceSparseApplyAdagradUpdateSlots(value bool) ResourceSparseApplyAdagradAttr {
+ return func(m optionalAttr) {
+ m["update_slots"] = value
+ }
+}
+
// Update relevant entries in '*var' and '*accum' according to the adagrad scheme.
//
// That is for rows we have grad for, we update var and accum as follows:
@@ -10042,159 +10843,6 @@ func ResourceSparseApplyAdagrad(scope *Scope, var_ tf.Output, accum tf.Output, l
return scope.AddOperation(opspec)
}
-// 2D real-valued fast Fourier transform.
-//
-// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
-// over the inner-most 2 dimensions of `input`.
-//
-// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
-// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
-// of `output`: the zero-frequency term, followed by the `fft_length / 2`
-// positive-frequency terms.
-//
-// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
-// corresponding dimension of `input`, the dimension is cropped. If it is larger,
-// the dimension is padded with zeros.
-//
-// Arguments:
-// input: A float32 tensor.
-// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
-//
-// Returns A complex64 tensor of the same rank as `input`. The inner-most 2
-// dimensions of `input` are replaced with their 2D Fourier transform. The
-// inner-most dimension contains `fft_length / 2 + 1` unique frequency
-// components.
-//
-// @compatibility(numpy)
-// Equivalent to np.fft.rfft2
-// @end_compatibility
-func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "RFFT2D",
- Input: []tf.Input{
- input, fft_length,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// ResizeAreaAttr is an optional argument to ResizeArea.
-type ResizeAreaAttr func(optionalAttr)
-
-// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
-//
-// value: If true, the centers of the 4 corner pixels of the input and output tensors are
-// aligned, preserving the values at the corner pixels. Defaults to false.
-// If not specified, defaults to false
-func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
- return func(m optionalAttr) {
- m["align_corners"] = value
- }
-}
-
-// Resize `images` to `size` using area interpolation.
-//
-// Input images can be of different types but output images are always float.
-//
-// The range of pixel values for the output image might be slightly different
-// from the range for the input image because of limited numerical precision.
-// To guarantee an output range, for example `[0.0, 1.0]`, apply
-// `tf.clip_by_value` to the output.
-//
-// Each output pixel is computed by first transforming the pixel's footprint into
-// the input tensor and then averaging the pixels that intersect the footprint. An
-// input pixel's contribution to the average is weighted by the fraction of its
-// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
-//
-// Arguments:
-// images: 4-D with shape `[batch, height, width, channels]`.
-// size: = A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
-// new size for the images.
-//
-// Returns 4-D with shape
-// `[batch, new_height, new_width, channels]`.
-func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResizeArea",
- Input: []tf.Input{
- images, size,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Pads a tensor with zeros.
-//
-// This operation pads a `input` with zeros according to the `paddings` you
-// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
-// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
-// how many zeros to add before the contents of `input` in that dimension, and
-// `paddings[D, 1]` indicates how many zeros to add after the contents of `input`
-// in that dimension.
-//
-// The padded size of each dimension D of the output is:
-//
-// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
-//
-// For example:
-//
-// ```
-// # 't' is [[1, 1], [2, 2]]
-// # 'paddings' is [[1, 1], [2, 2]]
-// # rank of 't' is 2
-// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
-// [0, 0, 1, 1, 0, 0]
-// [0, 0, 2, 2, 0, 0]
-// [0, 0, 0, 0, 0, 0]]
-// ```
-func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Pad",
- Input: []tf.Input{
- input, paddings,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// Checks whether a resource handle-based variable has been initialized.
-//
-// Arguments:
-// resource: the input resource handle.
-//
-// Returns a scalar boolean which is true if the variable has been
-// initialized.
-func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "VarIsInitializedOp",
- Input: []tf.Input{
- resource,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// StatelessRandomUniformAttr is an optional argument to StatelessRandomUniform.
type StatelessRandomUniformAttr func(optionalAttr)
@@ -10804,47 +11452,42 @@ func SparseDenseCwiseMul(scope *Scope, sp_indices tf.Output, sp_values tf.Output
return op.Output(0)
}
-// ResourceSparseApplyRMSPropAttr is an optional argument to ResourceSparseApplyRMSProp.
-type ResourceSparseApplyRMSPropAttr func(optionalAttr)
+// ResizeAreaAttr is an optional argument to ResizeArea.
+type ResizeAreaAttr func(optionalAttr)
-// ResourceSparseApplyRMSPropUseLocking sets the optional use_locking attribute to value.
+// ResizeAreaAlignCorners sets the optional align_corners attribute to value.
//
-// value: If `True`, updating of the var, ms, and mom tensors is protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
+// value: If true, the centers of the 4 corner pixels of the input and output tensors are
+// aligned, preserving the values at the corner pixels. Defaults to false.
// If not specified, defaults to false
-func ResourceSparseApplyRMSPropUseLocking(value bool) ResourceSparseApplyRMSPropAttr {
+func ResizeAreaAlignCorners(value bool) ResizeAreaAttr {
return func(m optionalAttr) {
- m["use_locking"] = value
+ m["align_corners"] = value
}
}
-// Update '*var' according to the RMSProp algorithm.
+// Resize `images` to `size` using area interpolation.
//
-// Note that in dense implementation of this algorithm, ms and mom will
-// update even if the grad is zero, but in this sparse implementation, ms
-// and mom will not update in iterations during which the grad is zero.
+// Input images can be of different types but output images are always float.
//
-// mean_square = decay * mean_square + (1-decay) * gradient ** 2
-// Delta = learning_rate * gradient / sqrt(mean_square + epsilon)
+// The range of pixel values for the output image might be slightly different
+// from the range for the input image because of limited numerical precision.
+// To guarantee an output range, for example `[0.0, 1.0]`, apply
+// `tf.clip_by_value` to the output.
//
-// ms <- rho * ms_{t-1} + (1-rho) * grad * grad
-// mom <- momentum * mom_{t-1} + lr * grad / sqrt(ms + epsilon)
-// var <- var - mom
+// Each output pixel is computed by first transforming the pixel's footprint into
+// the input tensor and then averaging the pixels that intersect the footprint. An
+// input pixel's contribution to the average is weighted by the fraction of its
+// area that intersects the footprint. This is the same as OpenCV's INTER_AREA.
//
// Arguments:
-// var_: Should be from a Variable().
-// ms: Should be from a Variable().
-// mom: Should be from a Variable().
-// lr: Scaling factor. Must be a scalar.
-// rho: Decay rate. Must be a scalar.
-//
-// epsilon: Ridge term. Must be a scalar.
-// grad: The gradient.
-// indices: A vector of indices into the first dimension of var, ms and mom.
+// images: 4-D with shape `[batch, height, width, channels]`.
+// size: A 1-D int32 Tensor of 2 elements: `new_height, new_width`. The
+// new size for the images.
//
-// Returns the created operation.
-func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom tf.Output, lr tf.Output, rho tf.Output, momentum tf.Output, epsilon tf.Output, grad tf.Output, indices tf.Output, optional ...ResourceSparseApplyRMSPropAttr) (o *tf.Operation) {
+// Returns 4-D with shape
+// `[batch, new_height, new_width, channels]`.
+func ResizeArea(scope *Scope, images tf.Output, size tf.Output, optional ...ResizeAreaAttr) (resized_images tf.Output) {
if scope.Err() != nil {
return
}
@@ -10853,184 +11496,113 @@ func ResourceSparseApplyRMSProp(scope *Scope, var_ tf.Output, ms tf.Output, mom
a(attrs)
}
opspec := tf.OpSpec{
- Type: "ResourceSparseApplyRMSProp",
+ Type: "ResizeArea",
Input: []tf.Input{
- var_, ms, mom, lr, rho, momentum, epsilon, grad, indices,
+ images, size,
},
Attrs: attrs,
}
- return scope.AddOperation(opspec)
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Returns the truth value of (x > y) element-wise.
+// 2D real-valued fast Fourier transform.
//
-// *NOTE*: `Greater` supports broadcasting. More about broadcasting
-// [here](http://docs.scipy.org/doc/numpy/user/basics.broadcasting.html)
-func Greater(scope *Scope, x tf.Output, y tf.Output) (z tf.Output) {
+// Computes the 2-dimensional discrete Fourier transform of a real-valued signal
+// over the inner-most 2 dimensions of `input`.
+//
+// Since the DFT of a real signal is Hermitian-symmetric, `RFFT2D` only returns the
+// `fft_length / 2 + 1` unique components of the FFT for the inner-most dimension
+// of `output`: the zero-frequency term, followed by the `fft_length / 2`
+// positive-frequency terms.
+//
+// Along each axis `RFFT2D` is computed on, if `fft_length` is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
+// Arguments:
+// input: A float32 tensor.
+// fft_length: An int32 tensor of shape [2]. The FFT length for each dimension.
+//
+// Returns A complex64 tensor of the same rank as `input`. The inner-most 2
+// dimensions of `input` are replaced with their 2D Fourier transform. The
+// inner-most dimension contains `fft_length / 2 + 1` unique frequency
+// components.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.rfft2
+// @end_compatibility
+func RFFT2D(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
if scope.Err() != nil {
return
}
opspec := tf.OpSpec{
- Type: "Greater",
+ Type: "RFFT2D",
Input: []tf.Input{
- x, y,
+ input, fft_length,
},
}
op := scope.AddOperation(opspec)
return op.Output(0)
}
-// SampleDistortedBoundingBoxAttr is an optional argument to SampleDistortedBoundingBox.
-type SampleDistortedBoundingBoxAttr func(optionalAttr)
-
-// SampleDistortedBoundingBoxSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to non-zero, the random number
-// generator is seeded by the given `seed`. Otherwise, it is seeded by a random
-// seed.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// SampleDistortedBoundingBoxSeed2 sets the optional seed2 attribute to value.
+// Pads a tensor with zeros.
//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func SampleDistortedBoundingBoxSeed2(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMinObjectCovered sets the optional min_object_covered attribute to value.
+// This operation pads a `input` with zeros according to the `paddings` you
+// specify. `paddings` is an integer tensor with shape `[Dn, 2]`, where n is the
+// rank of `input`. For each dimension D of `input`, `paddings[D, 0]` indicates
+// how many zeros to add before the contents of `input` in that dimension, and
+// `paddings[D, 1]` indicates how many zeros to add after the contents of `input`
+// in that dimension.
//
-// value: The cropped area of the image must contain at least this
-// fraction of any bounding box supplied. The value of this parameter should be
-// non-negative. In the case of 0, the cropped area does not need to overlap
-// any of the bounding boxes supplied.
-// If not specified, defaults to 0.1
-func SampleDistortedBoundingBoxMinObjectCovered(value float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["min_object_covered"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAspectRatioRange sets the optional aspect_ratio_range attribute to value.
+// The padded size of each dimension D of the output is:
//
-// value: The cropped area of the image must have an aspect ratio =
-// width / height within this range.
-// If not specified, defaults to <f:0.75 f:1.33 >
-func SampleDistortedBoundingBoxAspectRatioRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["aspect_ratio_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxAreaRange sets the optional area_range attribute to value.
+// `paddings(D, 0) + input.dim_size(D) + paddings(D, 1)`
//
-// value: The cropped area of the image must contain a fraction of the
-// supplied image within in this range.
-// If not specified, defaults to <f:0.05 f:1 >
-func SampleDistortedBoundingBoxAreaRange(value []float32) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["area_range"] = value
- }
-}
-
-// SampleDistortedBoundingBoxMaxAttempts sets the optional max_attempts attribute to value.
+// For example:
//
-// value: Number of attempts at generating a cropped region of the image
-// of the specified constraints. After `max_attempts` failures, return the entire
-// image.
-// If not specified, defaults to 100
-func SampleDistortedBoundingBoxMaxAttempts(value int64) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["max_attempts"] = value
+// ```
+// # 't' is [[1, 1], [2, 2]]
+// # 'paddings' is [[1, 1], [2, 2]]
+// # rank of 't' is 2
+// pad(t, paddings) ==> [[0, 0, 0, 0, 0, 0]
+// [0, 0, 1, 1, 0, 0]
+// [0, 0, 2, 2, 0, 0]
+// [0, 0, 0, 0, 0, 0]]
+// ```
+func Pad(scope *Scope, input tf.Output, paddings tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
}
-}
-
-// SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes sets the optional use_image_if_no_bounding_boxes attribute to value.
-//
-// value: Controls behavior if no bounding boxes supplied.
-// If true, assume an implicit bounding box covering the whole input. If false,
-// raise an error.
-// If not specified, defaults to false
-func SampleDistortedBoundingBoxUseImageIfNoBoundingBoxes(value bool) SampleDistortedBoundingBoxAttr {
- return func(m optionalAttr) {
- m["use_image_if_no_bounding_boxes"] = value
+ opspec := tf.OpSpec{
+ Type: "Pad",
+ Input: []tf.Input{
+ input, paddings,
+ },
}
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
}
-// Generate a single randomly distorted bounding box for an image.
-//
-// Bounding box annotations are often supplied in addition to ground-truth labels
-// in image recognition or object localization tasks. A common technique for
-// training such a system is to randomly distort an image while preserving
-// its content, i.e. *data augmentation*. This Op outputs a randomly distorted
-// localization of an object, i.e. bounding box, given an `image_size`,
-// `bounding_boxes` and a series of constraints.
-//
-// The output of this Op is a single bounding box that may be used to crop the
-// original image. The output is returned as 3 tensors: `begin`, `size` and
-// `bboxes`. The first 2 tensors can be fed directly into `tf.slice` to crop the
-// image. The latter may be supplied to `tf.image.draw_bounding_boxes` to visualize
-// what the bounding box looks like.
-//
-// Bounding boxes are supplied and returned as `[y_min, x_min, y_max, x_max]`. The
-// bounding box coordinates are floats in `[0.0, 1.0]` relative to the width and
-// height of the underlying image.
-//
-// For example,
-//
-// ```python
-// # Generate a single distorted bounding box.
-// begin, size, bbox_for_draw = tf.image.sample_distorted_bounding_box(
-// tf.shape(image),
-// bounding_boxes=bounding_boxes)
-//
-// # Draw the bounding box in an image summary.
-// image_with_box = tf.image.draw_bounding_boxes(tf.expand_dims(image, 0),
-// bbox_for_draw)
-// tf.summary.image('images_with_box', image_with_box)
-//
-// # Employ the bounding box to distort the image.
-// distorted_image = tf.slice(image, begin, size)
-// ```
-//
-// Note that if no bounding box information is available, setting
-// `use_image_if_no_bounding_boxes = true` will assume there is a single implicit
-// bounding box covering the whole image. If `use_image_if_no_bounding_boxes` is
-// false and no bounding boxes are supplied, an error is raised.
+// Checks whether a resource handle-based variable has been initialized.
//
// Arguments:
-// image_size: 1-D, containing `[height, width, channels]`.
-// bounding_boxes: 3-D with shape `[batch, N, 4]` describing the N bounding boxes
-// associated with the image.
+// resource: the input resource handle.
//
-// Returns 1-D, containing `[offset_height, offset_width, 0]`. Provide as input to
-// `tf.slice`.1-D, containing `[target_height, target_width, -1]`. Provide as input to
-// `tf.slice`.3-D with shape `[1, 1, 4]` containing the distorted bounding box.
-// Provide as input to `tf.image.draw_bounding_boxes`.
-func SampleDistortedBoundingBox(scope *Scope, image_size tf.Output, bounding_boxes tf.Output, optional ...SampleDistortedBoundingBoxAttr) (begin tf.Output, size tf.Output, bboxes tf.Output) {
+// Returns a scalar boolean which is true if the variable has been
+// initialized.
+func VarIsInitializedOp(scope *Scope, resource tf.Output) (is_initialized tf.Output) {
if scope.Err() != nil {
return
}
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
opspec := tf.OpSpec{
- Type: "SampleDistortedBoundingBox",
+ Type: "VarIsInitializedOp",
Input: []tf.Input{
- image_size, bounding_boxes,
+ resource,
},
- Attrs: attrs,
}
op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
+ return op.Output(0)
}
// Converts each string in the input Tensor to its hash mod by a number of buckets.
@@ -12548,39 +13120,6 @@ func IFFT2D(scope *Scope, input tf.Output) (output tf.Output) {
return op.Output(0)
}
-// Creates a tensor filled with a scalar value.
-//
-// This operation creates a tensor of shape `dims` and fills it with `value`.
-//
-// For example:
-//
-// ```
-// # Output tensor has shape [2, 3].
-// fill([2, 3], 9) ==> [[9, 9, 9]
-// [9, 9, 9]]
-// ```
-//
-// Arguments:
-// dims: 1-D. Represents the shape of the output tensor.
-// value: 0-D (scalar). Value to fill the returned tensor.
-//
-// @compatibility(numpy)
-// Equivalent to np.full
-// @end_compatibility
-func Fill(scope *Scope, dims tf.Output, value tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "Fill",
- Input: []tf.Input{
- dims, value,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
// 2D fast Fourier transform.
//
// Computes the 2-dimensional discrete Fourier transform over the inner-most
@@ -13698,6 +14237,44 @@ func InitializeTableFromTextFileV2(scope *Scope, table_handle tf.Output, filenam
return scope.AddOperation(opspec)
}
+// Real-valued fast Fourier transform.
+//
+// Computes the 1-dimensional discrete Fourier transform of a real-valued signal
+// over the inner-most dimension of `input`.
+//
+// Since the DFT of a real signal is Hermitian-symmetric, `RFFT` only returns the
+// `fft_length / 2 + 1` unique components of the FFT: the zero-frequency term,
+// followed by the `fft_length / 2` positive-frequency terms.
+//
+// Along the axis `RFFT` is computed on, if `fft_length` is smaller than the
+// corresponding dimension of `input`, the dimension is cropped. If it is larger,
+// the dimension is padded with zeros.
+//
+// Arguments:
+// input: A float32 tensor.
+// fft_length: An int32 tensor of shape [1]. The FFT length.
+//
+// Returns A complex64 tensor of the same rank as `input`. The inner-most
+// dimension of `input` is replaced with the `fft_length / 2 + 1` unique
+// frequency components of its 1D Fourier transform.
+//
+// @compatibility(numpy)
+// Equivalent to np.fft.rfft
+// @end_compatibility
+func RFFT(scope *Scope, input tf.Output, fft_length tf.Output) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ opspec := tf.OpSpec{
+ Type: "RFFT",
+ Input: []tf.Input{
+ input, fft_length,
+ },
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
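A small construction sketch under the same assumptions as the earlier examples; the signal values are illustrative:

// buildRFFT transforms an 8-sample real signal; the output is a complex64
// tensor with fft_length/2 + 1 = 5 frequency components.
func buildRFFT(s *op.Scope) tf.Output {
	signal := op.Const(s, []float32{0, 1, 0, -1, 0, 1, 0, -1})
	fftLength := op.Const(s, []int32{8})
	return op.RFFT(s, signal, fftLength)
}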
// QuantizedReluAttr is an optional argument to QuantizedRelu.
type QuantizedReluAttr func(optionalAttr)
@@ -15418,6 +15995,216 @@ func MaxPoolV2(scope *Scope, input tf.Output, ksize tf.Output, strides tf.Output
return op.Output(0)
}
+// SkipgramAttr is an optional argument to Skipgram.
+type SkipgramAttr func(optionalAttr)
+
+// SkipgramWindowSize sets the optional window_size attribute to value.
+//
+// value: The number of words to predict to the left and right of the target.
+// If not specified, defaults to 5
+func SkipgramWindowSize(value int64) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["window_size"] = value
+ }
+}
+
+// SkipgramMinCount sets the optional min_count attribute to value.
+//
+// value: The minimum number of word occurrences for it to be included in the
+// vocabulary.
+// If not specified, defaults to 5
+func SkipgramMinCount(value int64) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["min_count"] = value
+ }
+}
+
+// SkipgramSubsample sets the optional subsample attribute to value.
+//
+// value: Threshold for word occurrence. Words that appear with higher
+// frequency will be randomly down-sampled. Set to 0 to disable.
+// If not specified, defaults to 0.001
+func SkipgramSubsample(value float32) SkipgramAttr {
+ return func(m optionalAttr) {
+ m["subsample"] = value
+ }
+}
+
+// Parses a text file and creates a batch of examples.
+//
+// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result
+//
+// Arguments:
+// filename: The corpus's text file name.
+// batch_size: The size of produced batch.
+//
+// Returns A vector of words in the corpus. Frequencies of words, sorted in non-ascending order. Number of words per epoch in the data file. The current epoch number. The total number of words processed so far. A vector of word ids (examples). A vector of word ids (labels).
+func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "Skipgram",
+
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6)
+}
+
+// StringToNumberAttr is an optional argument to StringToNumber.
+type StringToNumberAttr func(optionalAttr)
+
+// StringToNumberOutType sets the optional out_type attribute to value.
+//
+// value: The numeric type to interpret each string in `string_tensor` as.
+// If not specified, defaults to DT_FLOAT
+func StringToNumberOutType(value tf.DataType) StringToNumberAttr {
+ return func(m optionalAttr) {
+ m["out_type"] = value
+ }
+}
+
+// Converts each string in the input Tensor to the specified numeric type.
+//
+// (Note that int32 overflow results in an error while float overflow
+// results in a rounded value.)
+//
+// Returns A Tensor of the same shape as the input `string_tensor`.
+func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "StringToNumber",
+ Input: []tf.Input{
+ string_tensor,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
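A short sketch of overriding the default DT_FLOAT output type, under the same assumptions as the earlier examples:

// buildStringToNumber parses integer strings into an int32 tensor.
func buildStringToNumber(s *op.Scope) tf.Output {
	strs := op.Const(s, []string{"7", "-3"})
	return op.StringToNumber(s, strs, op.StringToNumberOutType(tf.Int32))
}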
+// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2.
+type ResourceApplyFtrlV2Attr func(optionalAttr)
+
+// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value.
+//
+// value: If `True`, updating of the var and accum tensors will be protected
+// by a lock; otherwise the behavior is undefined, but may exhibit less
+// contention.
+// If not specified, defaults to false
+func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr {
+ return func(m optionalAttr) {
+ m["use_locking"] = value
+ }
+}
+
+// Update '*var' according to the Ftrl-proximal scheme.
+//
+// grad_with_shrinkage = grad + 2 * l2_shrinkage * var
+// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
+// linear += grad_with_shrinkage +
+// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
+// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
+// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
+// accum = accum_new
+//
+// Arguments:
+// var_: Should be from a Variable().
+// accum: Should be from a Variable().
+// linear: Should be from a Variable().
+// grad: The gradient.
+// lr: Scaling factor. Must be a scalar.
+// l1: L1 regularization. Must be a scalar.
+// l2: L2 shrinkage regularization. Must be a scalar.
+//
+// lr_power: Scaling factor. Must be a scalar.
+//
+// Returns the created operation.
+func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "ResourceApplyFtrlV2",
+ Input: []tf.Input{
+ var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power,
+ },
+ Attrs: attrs,
+ }
+ return scope.AddOperation(opspec)
+}
+
+// TruncatedNormalAttr is an optional argument to TruncatedNormal.
+type TruncatedNormalAttr func(optionalAttr)
+
+// TruncatedNormalSeed sets the optional seed attribute to value.
+//
+// value: If either `seed` or `seed2` are set to be non-zero, the random number
+// generator is seeded by the given seed. Otherwise, it is seeded by a
+// random seed.
+// If not specified, defaults to 0
+func TruncatedNormalSeed(value int64) TruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed"] = value
+ }
+}
+
+// TruncatedNormalSeed2 sets the optional seed2 attribute to value.
+//
+// value: A second seed to avoid seed collision.
+// If not specified, defaults to 0
+func TruncatedNormalSeed2(value int64) TruncatedNormalAttr {
+ return func(m optionalAttr) {
+ m["seed2"] = value
+ }
+}
+
+// Outputs random values from a truncated normal distribution.
+//
+// The generated values follow a normal distribution with mean 0 and standard
+// deviation 1, except that values whose magnitude is more than 2 standard
+// deviations from the mean are dropped and re-picked.
+//
+// Arguments:
+// shape: The shape of the output tensor.
+// dtype: The type of the output.
+//
+// Returns A tensor of the specified shape filled with random truncated normal
+// values.
+func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{"dtype": dtype}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "TruncatedNormal",
+ Input: []tf.Input{
+ shape,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
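A minimal sketch of drawing a fixed-seed sample, under the same assumptions as the earlier examples:

// buildTruncatedNormal samples a 2x3 float32 tensor from the truncated normal
// distribution; fixing the seed makes the values reproducible across runs.
func buildTruncatedNormal(s *op.Scope) tf.Output {
	shape := op.Const(s, []int32{2, 3})
	return op.TruncatedNormal(s, shape, tf.Float, op.TruncatedNormalSeed(42))
}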
// MutableDenseHashTableV2Attr is an optional argument to MutableDenseHashTableV2.
type MutableDenseHashTableV2Attr func(optionalAttr)
@@ -16053,6 +16840,62 @@ func MatchingFiles(scope *Scope, pattern tf.Output) (filenames tf.Output) {
return op.Output(0)
}
+// HistogramFixedWidthAttr is an optional argument to HistogramFixedWidth.
+type HistogramFixedWidthAttr func(optionalAttr)
+
+// HistogramFixedWidthDtype sets the optional dtype attribute to value.
+// If not specified, defaults to DT_INT32
+func HistogramFixedWidthDtype(value tf.DataType) HistogramFixedWidthAttr {
+ return func(m optionalAttr) {
+ m["dtype"] = value
+ }
+}
+
+// Return histogram of values.
+//
+// Given the tensor `values`, this operation returns a rank 1 histogram counting
+// the number of entries in `values` that fall into every bin. The bins are
+// equal width and determined by the arguments `value_range` and `nbins`.
+//
+// ```python
+// # Bins will be: (-inf, 1), [1, 2), [2, 3), [3, 4), [4, inf)
+// nbins = 5
+// value_range = [0.0, 5.0]
+// new_values = [-1.0, 0.0, 1.5, 2.0, 5.0, 15]
+//
+// with tf.Session() as sess:
+//   hist = tf.histogram_fixed_width(new_values, value_range, nbins=5)
+// sess.run(hist) => [2, 1, 1, 0, 2]
+// ```
+//
+// Arguments:
+// values: Numeric `Tensor`.
+// value_range: Shape [2] `Tensor` of same `dtype` as `values`.
+// values <= value_range[0] will be mapped to hist[0],
+// values >= value_range[1] will be mapped to hist[-1].
+// nbins: Scalar `int32 Tensor`. Number of histogram bins.
+//
+// Returns A 1-D `Tensor` holding histogram of values.
+func HistogramFixedWidth(scope *Scope, values tf.Output, value_range tf.Output, nbins tf.Output, optional ...HistogramFixedWidthAttr) (out tf.Output) {
+ if scope.Err() != nil {
+ return
+ }
+ attrs := map[string]interface{}{}
+ for _, a := range optional {
+ a(attrs)
+ }
+ opspec := tf.OpSpec{
+ Type: "HistogramFixedWidth",
+ Input: []tf.Input{
+ values, value_range, nbins,
+ },
+ Attrs: attrs,
+ }
+ op := scope.AddOperation(opspec)
+ return op.Output(0)
+}
+
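A Go counterpart of the Python snippet in the comment above, under the same assumptions as the earlier examples:

// buildHistogram bins the example values into 5 equal-width bins over [0, 5);
// running it yields [2 1 1 0 2], matching the Python example.
func buildHistogram(s *op.Scope) tf.Output {
	values := op.Const(s, []float32{-1.0, 0.0, 1.5, 2.0, 5.0, 15})
	valueRange := op.Const(s, []float32{0.0, 5.0})
	nbins := op.Const(s, int32(5))
	return op.HistogramFixedWidth(s, values, valueRange, nbins)
}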
// Returns the truth value of (x >= y) element-wise.
//
// *NOTE*: `GreaterEqual` supports broadcasting. More about broadcasting
@@ -16561,305 +17404,6 @@ func TensorArrayCloseV3(scope *Scope, handle tf.Output) (o *tf.Operation) {
return scope.AddOperation(opspec)
}
-// Adds two `SparseTensor` objects to produce another `SparseTensor`.
-//
-// The input `SparseTensor` objects' indices are assumed ordered in standard
-// lexicographic order. If this is not the case, before this step run
-// `SparseReorder` to restore index ordering.
-//
-// By default, if two values sum to zero at some index, the output `SparseTensor`
-// would still include that particular location in its index, storing a zero in the
-// corresponding value slot. To override this, callers can specify `thresh`,
-// indicating that if the sum has a magnitude strictly smaller than `thresh`, its
-// corresponding value and index would then not be included. In particular,
-// `thresh == 0` (default) means everything is kept and actual thresholding happens
-// only for a positive value.
-//
-// In the following shapes, `nnz` is the count after taking `thresh` into account.
-//
-// Arguments:
-// a_indices: 2-D. The `indices` of the first `SparseTensor`, size `[nnz, ndims]` Matrix.
-// a_values: 1-D. The `values` of the first `SparseTensor`, size `[nnz]` Vector.
-// a_shape: 1-D. The `shape` of the first `SparseTensor`, size `[ndims]` Vector.
-// b_indices: 2-D. The `indices` of the second `SparseTensor`, size `[nnz, ndims]` Matrix.
-// b_values: 1-D. The `values` of the second `SparseTensor`, size `[nnz]` Vector.
-// b_shape: 1-D. The `shape` of the second `SparseTensor`, size `[ndims]` Vector.
-// thresh: 0-D. The magnitude threshold that determines if an output value/index
-// pair takes space.
-func SparseAdd(scope *Scope, a_indices tf.Output, a_values tf.Output, a_shape tf.Output, b_indices tf.Output, b_values tf.Output, b_shape tf.Output, thresh tf.Output) (sum_indices tf.Output, sum_values tf.Output, sum_shape tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "SparseAdd",
- Input: []tf.Input{
- a_indices, a_values, a_shape, b_indices, b_values, b_shape, thresh,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
-
-// OrderedMapPeekAttr is an optional argument to OrderedMapPeek.
-type OrderedMapPeekAttr func(optionalAttr)
-
-// OrderedMapPeekCapacity sets the optional capacity attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func OrderedMapPeekCapacity(value int64) OrderedMapPeekAttr {
- return func(m optionalAttr) {
- m["capacity"] = value
- }
-}
-
-// OrderedMapPeekMemoryLimit sets the optional memory_limit attribute to value.
-// If not specified, defaults to 0
-//
-// REQUIRES: value >= 0
-func OrderedMapPeekMemoryLimit(value int64) OrderedMapPeekAttr {
- return func(m optionalAttr) {
- m["memory_limit"] = value
- }
-}
-
-// OrderedMapPeekContainer sets the optional container attribute to value.
-// If not specified, defaults to ""
-func OrderedMapPeekContainer(value string) OrderedMapPeekAttr {
- return func(m optionalAttr) {
- m["container"] = value
- }
-}
-
-// OrderedMapPeekSharedName sets the optional shared_name attribute to value.
-// If not specified, defaults to ""
-func OrderedMapPeekSharedName(value string) OrderedMapPeekAttr {
- return func(m optionalAttr) {
- m["shared_name"] = value
- }
-}
-
-// Op peeks at the values at the specified key. If the
-//
-// underlying container does not contain this key
-// this op will block until it does. This Op is optimized for
-// performance.
-func OrderedMapPeek(scope *Scope, key tf.Output, indices tf.Output, dtypes []tf.DataType, optional ...OrderedMapPeekAttr) (values []tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtypes": dtypes}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "OrderedMapPeek",
- Input: []tf.Input{
- key, indices,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- if scope.Err() != nil {
- return
- }
- var idx int
- var err error
- if values, idx, err = makeOutputList(op, idx, "values"); err != nil {
- scope.UpdateErr("OrderedMapPeek", err)
- return
- }
- return values
-}
-
-// DecodeAndCropJpegAttr is an optional argument to DecodeAndCropJpeg.
-type DecodeAndCropJpegAttr func(optionalAttr)
-
-// DecodeAndCropJpegChannels sets the optional channels attribute to value.
-//
-// value: Number of color channels for the decoded image.
-// If not specified, defaults to 0
-func DecodeAndCropJpegChannels(value int64) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["channels"] = value
- }
-}
-
-// DecodeAndCropJpegRatio sets the optional ratio attribute to value.
-//
-// value: Downscaling ratio.
-// If not specified, defaults to 1
-func DecodeAndCropJpegRatio(value int64) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["ratio"] = value
- }
-}
-
-// DecodeAndCropJpegFancyUpscaling sets the optional fancy_upscaling attribute to value.
-//
-// value: If true use a slower but nicer upscaling of the
-// chroma planes (yuv420/422 only).
-// If not specified, defaults to true
-func DecodeAndCropJpegFancyUpscaling(value bool) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["fancy_upscaling"] = value
- }
-}
-
-// DecodeAndCropJpegTryRecoverTruncated sets the optional try_recover_truncated attribute to value.
-//
-// value: If true try to recover an image from truncated input.
-// If not specified, defaults to false
-func DecodeAndCropJpegTryRecoverTruncated(value bool) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["try_recover_truncated"] = value
- }
-}
-
-// DecodeAndCropJpegAcceptableFraction sets the optional acceptable_fraction attribute to value.
-//
-// value: The minimum required fraction of lines before a truncated
-// input is accepted.
-// If not specified, defaults to 1
-func DecodeAndCropJpegAcceptableFraction(value float32) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["acceptable_fraction"] = value
- }
-}
-
-// DecodeAndCropJpegDctMethod sets the optional dct_method attribute to value.
-//
-// value: string specifying a hint about the algorithm used for
-// decompression. Defaults to "" which maps to a system-specific
-// default. Currently valid values are ["INTEGER_FAST",
-// "INTEGER_ACCURATE"]. The hint may be ignored (e.g., the internal
-// jpeg library changes to a version that does not have that specific
-// option.)
-// If not specified, defaults to ""
-func DecodeAndCropJpegDctMethod(value string) DecodeAndCropJpegAttr {
- return func(m optionalAttr) {
- m["dct_method"] = value
- }
-}
-
-// Decode and Crop a JPEG-encoded image to a uint8 tensor.
-//
-// The attr `channels` indicates the desired number of color channels for the
-// decoded image.
-//
-// Accepted values are:
-//
-// * 0: Use the number of channels in the JPEG-encoded image.
-// * 1: output a grayscale image.
-// * 3: output an RGB image.
-//
-// If needed, the JPEG-encoded image is transformed to match the requested number
-// of color channels.
-//
-// The attr `ratio` allows downscaling the image by an integer factor during
-// decoding. Allowed values are: 1, 2, 4, and 8. This is much faster than
-// downscaling the image later.
-//
-//
-// It is equivalent to a combination of decode and crop, but much faster by only
-// decoding partial jpeg image.
-//
-// Arguments:
-// contents: 0-D. The JPEG-encoded image.
-// crop_window: 1-D. The crop window: [crop_y, crop_x, crop_height, crop_width].
-//
-// Returns 3-D with shape `[height, width, channels]`..
-func DecodeAndCropJpeg(scope *Scope, contents tf.Output, crop_window tf.Output, optional ...DecodeAndCropJpegAttr) (image tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodeAndCropJpeg",
- Input: []tf.Input{
- contents, crop_window,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
-
-// AllCandidateSamplerAttr is an optional argument to AllCandidateSampler.
-type AllCandidateSamplerAttr func(optionalAttr)
-
-// AllCandidateSamplerSeed sets the optional seed attribute to value.
-//
-// value: If either seed or seed2 are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// AllCandidateSamplerSeed2 sets the optional seed2 attribute to value.
-//
-// value: An second seed to avoid seed collision.
-// If not specified, defaults to 0
-func AllCandidateSamplerSeed2(value int64) AllCandidateSamplerAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Generates labels for candidate sampling with a learned unigram distribution.
-//
-// See explanations of candidate sampling and the data formats at
-// go/candidate-sampling.
-//
-// For each batch, this op picks a single set of sampled candidate labels.
-//
-// The advantages of sampling candidates per-batch are simplicity and the
-// possibility of efficient dense matrix multiplication. The disadvantage is that
-// the sampled candidates must be chosen independently of the context and of the
-// true labels.
-//
-// Arguments:
-// true_classes: A batch_size * num_true matrix, in which each row contains the
-// IDs of the num_true target_classes in the corresponding original label.
-// num_true: Number of true labels per context.
-// num_sampled: Number of candidates to produce.
-// unique: If unique is true, we sample with rejection, so that all sampled
-// candidates in a batch are unique. This requires some approximation to
-// estimate the post-rejection sampling probabilities.
-//
-// Returns:
-// sampled_candidates: A vector of length num_sampled, in which each element is
-// the ID of a sampled candidate.
-// true_expected_count: A batch_size * num_true matrix, representing the number
-// of times each candidate is expected to occur in a batch of sampled
-// candidates. If unique=true, then this is a probability.
-// sampled_expected_count: A vector of length num_sampled, for each sampled
-// candidate representing the number of times the candidate is expected to
-// occur in a batch of sampled candidates. If unique=true, then this is a
-// probability.
-func AllCandidateSampler(scope *Scope, true_classes tf.Output, num_true int64, num_sampled int64, unique bool, optional ...AllCandidateSamplerAttr) (sampled_candidates tf.Output, true_expected_count tf.Output, sampled_expected_count tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"num_true": num_true, "num_sampled": num_sampled, "unique": unique}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "AllCandidateSampler",
- Input: []tf.Input{
- true_classes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2)
-}
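A sketch of wiring the sampler into a graph through the Go bindings; the batch contents, seeds, and sampling sizes are illustrative, and the session-handling boilerplate assumes the standard `tensorflow/go` API:

```go
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	// One row per example; each row holds num_true=2 true class IDs.
	trueClasses := op.Const(s, [][]int64{{1, 3}, {0, 2}})
	sampled, trueExpected, sampledExpected := op.AllCandidateSampler(
		s, trueClasses,
		2,    // num_true
		4,    // num_sampled
		true, // unique: sample with rejection so candidates are distinct
		op.AllCandidateSamplerSeed(7))

	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil,
		[]tf.Output{sampled, trueExpected, sampledExpected}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("sampled_candidates:", out[0].Value())
	fmt.Println("true_expected_count:", out[1].Value())
	fmt.Println("sampled_expected_count:", out[2].Value())
}
```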
-
// Saves the input tensors to disk.
//
// The size of `tensor_names` must match the number of tensors in `data`. `data[i]`
@@ -18997,216 +19541,6 @@ func RandomUniformInt(scope *Scope, shape tf.Output, minval tf.Output, maxval tf
return op.Output(0)
}
-// SkipgramAttr is an optional argument to Skipgram.
-type SkipgramAttr func(optionalAttr)
-
-// SkipgramWindowSize sets the optional window_size attribute to value.
-//
-// value: The number of words to predict to the left and right of the target.
-// If not specified, defaults to 5
-func SkipgramWindowSize(value int64) SkipgramAttr {
- return func(m optionalAttr) {
- m["window_size"] = value
- }
-}
-
-// SkipgramMinCount sets the optional min_count attribute to value.
-//
-// value: The minimum number of word occurrences for it to be included in the
-// vocabulary.
-// If not specified, defaults to 5
-func SkipgramMinCount(value int64) SkipgramAttr {
- return func(m optionalAttr) {
- m["min_count"] = value
- }
-}
-
-// SkipgramSubsample sets the optional subsample attribute to value.
-//
-// value: Threshold for word occurrence. Words that appear with higher
-// frequency will be randomly down-sampled. Set to 0 to disable.
-// If not specified, defaults to 0.001
-func SkipgramSubsample(value float32) SkipgramAttr {
- return func(m optionalAttr) {
- m["subsample"] = value
- }
-}
-
-// Parses a text file and creates a batch of examples.
-//
-// DEPRECATED at GraphDef version 19: Moving word2vec into tensorflow_models/tutorials and deprecating its ops here as a result
-//
-// Arguments:
-// filename: The corpus's text file name.
-// batch_size: The size of produced batch.
-//
-// Returns:
-// vocab_word: A vector of words in the corpus.
-// vocab_freq: Frequencies of words. Sorted in the non-ascending order.
-// words_per_epoch: Number of words per epoch in the data file.
-// current_epoch: The current epoch number.
-// total_words_processed: The total number of words processed so far.
-// examples: A vector of word ids.
-// labels: A vector of word ids.
-func Skipgram(scope *Scope, filename string, batch_size int64, optional ...SkipgramAttr) (vocab_word tf.Output, vocab_freq tf.Output, words_per_epoch tf.Output, current_epoch tf.Output, total_words_processed tf.Output, examples tf.Output, labels tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"filename": filename, "batch_size": batch_size}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "Skipgram",
-
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0), op.Output(1), op.Output(2), op.Output(3), op.Output(4), op.Output(5), op.Output(6)
-}
-
-// StringToNumberAttr is an optional argument to StringToNumber.
-type StringToNumberAttr func(optionalAttr)
-
-// StringToNumberOutType sets the optional out_type attribute to value.
-//
-// value: The numeric type to interpret each string in `string_tensor` as.
-// If not specified, defaults to DT_FLOAT
-func StringToNumberOutType(value tf.DataType) StringToNumberAttr {
- return func(m optionalAttr) {
- m["out_type"] = value
- }
-}
-
-// Converts each string in the input Tensor to the specified numeric type.
-//
-// (Note that int32 overflow results in an error while float overflow
-// results in a rounded value.)
-//
-// Returns A Tensor of the same shape as the input `string_tensor`.
-func StringToNumber(scope *Scope, string_tensor tf.Output, optional ...StringToNumberAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "StringToNumber",
- Input: []tf.Input{
- string_tensor,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
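A short usage sketch with the Go bindings (input strings are illustrative; the default `out_type` is kept so the parse succeeds for fractional values):

```go
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	strs := op.Const(s, []string{"3.14", "-7.5", "42"})
	// out_type defaults to DT_FLOAT; pass op.StringToNumberOutType(tf.Int32)
	// instead to parse integer strings.
	nums := op.StringToNumber(s, strs)

	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{nums}, nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out[0].Value()) // [3.14 -7.5 42]
}
```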
-
-// ResourceApplyFtrlV2Attr is an optional argument to ResourceApplyFtrlV2.
-type ResourceApplyFtrlV2Attr func(optionalAttr)
-
-// ResourceApplyFtrlV2UseLocking sets the optional use_locking attribute to value.
-//
-// value: If `True`, updating of the var and accum tensors will be protected
-// by a lock; otherwise the behavior is undefined, but may exhibit less
-// contention.
-// If not specified, defaults to false
-func ResourceApplyFtrlV2UseLocking(value bool) ResourceApplyFtrlV2Attr {
- return func(m optionalAttr) {
- m["use_locking"] = value
- }
-}
-
-// Update '*var' according to the Ftrl-proximal scheme.
-//
-// grad_with_shrinkage = grad + 2 * l2_shrinkage * var
-// accum_new = accum + grad_with_shrinkage * grad_with_shrinkage
-// linear += grad_with_shrinkage +
-// (accum_new^(-lr_power) - accum^(-lr_power)) / lr * var
-// quadratic = 1.0 / (accum_new^(lr_power) * lr) + 2 * l2
-// var = (sign(linear) * l1 - linear) / quadratic if |linear| > l1 else 0.0
-// accum = accum_new
-//
-// Arguments:
-// var_: Should be from a Variable().
-// accum: Should be from a Variable().
-// linear: Should be from a Variable().
-// grad: The gradient.
-// lr: Scaling factor. Must be a scalar.
-// l1: L1 regularization. Must be a scalar.
-// l2: L2 shrinkage regularization. Must be a scalar.
-//
-// lr_power: Scaling factor. Must be a scalar.
-//
-// Returns the created operation.
-func ResourceApplyFtrlV2(scope *Scope, var_ tf.Output, accum tf.Output, linear tf.Output, grad tf.Output, lr tf.Output, l1 tf.Output, l2 tf.Output, l2_shrinkage tf.Output, lr_power tf.Output, optional ...ResourceApplyFtrlV2Attr) (o *tf.Operation) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "ResourceApplyFtrlV2",
- Input: []tf.Input{
- var_, accum, linear, grad, lr, l1, l2, l2_shrinkage, lr_power,
- },
- Attrs: attrs,
- }
- return scope.AddOperation(opspec)
-}
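The update rule documented above can be sanity-checked on scalars with a small stand-alone Go sketch (no TensorFlow required; the numeric inputs are arbitrary, and the helper name is ours):

```go
package main

import (
	"fmt"
	"math"
)

// ftrlV2Step applies one FTRL-proximal update to scalar state, following the
// update equations documented for ResourceApplyFtrlV2.
func ftrlV2Step(v, accum, linear, grad, lr, l1, l2, l2Shrinkage, lrPower float64) (newV, newAccum, newLinear float64) {
	gradShrunk := grad + 2*l2Shrinkage*v
	newAccum = accum + gradShrunk*gradShrunk
	newLinear = linear + gradShrunk +
		(math.Pow(newAccum, -lrPower)-math.Pow(accum, -lrPower))/lr*v
	quadratic := 1.0/(math.Pow(newAccum, lrPower)*lr) + 2*l2
	if math.Abs(newLinear) > l1 {
		sign := 1.0
		if newLinear < 0 {
			sign = -1.0
		}
		newV = (sign*l1 - newLinear) / quadratic
	} else {
		newV = 0.0
	}
	return newV, newAccum, newLinear
}

func main() {
	v, accum, linear := 1.0, 0.1, 0.0
	v, accum, linear = ftrlV2Step(v, accum, linear,
		0.5,   // grad
		0.1,   // lr
		0.001, // l1
		0.001, // l2
		0.01,  // l2_shrinkage
		0.5)   // lr_power
	fmt.Printf("var=%.6f accum=%.6f linear=%.6f\n", v, accum, linear)
}
```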
-
-// TruncatedNormalAttr is an optional argument to TruncatedNormal.
-type TruncatedNormalAttr func(optionalAttr)
-
-// TruncatedNormalSeed sets the optional seed attribute to value.
-//
-// value: If either `seed` or `seed2` are set to be non-zero, the random number
-// generator is seeded by the given seed. Otherwise, it is seeded by a
-// random seed.
-// If not specified, defaults to 0
-func TruncatedNormalSeed(value int64) TruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed"] = value
- }
-}
-
-// TruncatedNormalSeed2 sets the optional seed2 attribute to value.
-//
-// value: A second seed to avoid seed collision.
-// If not specified, defaults to 0
-func TruncatedNormalSeed2(value int64) TruncatedNormalAttr {
- return func(m optionalAttr) {
- m["seed2"] = value
- }
-}
-
-// Outputs random values from a truncated normal distribution.
-//
-// The generated values follow a normal distribution with mean 0 and standard
-// deviation 1, except that values whose magnitude is more than 2 standard
-// deviations from the mean are dropped and re-picked.
-//
-// Arguments:
-// shape: The shape of the output tensor.
-// dtype: The type of the output.
-//
-// Returns A tensor of the specified shape filled with random truncated normal
-// values.
-func TruncatedNormal(scope *Scope, shape tf.Output, dtype tf.DataType, optional ...TruncatedNormalAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"dtype": dtype}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "TruncatedNormal",
- Input: []tf.Input{
- shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
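A minimal draw from the op via the Go bindings; fixed seeds make the run reproducible, and the shape and seed values are illustrative:

```go
package main

import (
	"fmt"
	"log"

	tf "github.com/tensorflow/tensorflow/tensorflow/go"
	"github.com/tensorflow/tensorflow/tensorflow/go/op"
)

func main() {
	s := op.NewScope()
	shape := op.Const(s, []int32{2, 3})
	samples := op.TruncatedNormal(s, shape, tf.Float,
		op.TruncatedNormalSeed(1), op.TruncatedNormalSeed2(2))

	graph, err := s.Finalize()
	if err != nil {
		log.Fatal(err)
	}
	sess, err := tf.NewSession(graph, nil)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	out, err := sess.Run(nil, []tf.Output{samples}, nil)
	if err != nil {
		log.Fatal(err)
	}
	// All values lie within 2 standard deviations of the mean (0).
	fmt.Println(out[0].Value())
}
```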
-
// RandomShuffleAttr is an optional argument to RandomShuffle.
type RandomShuffleAttr func(optionalAttr)
@@ -19325,113 +19659,6 @@ func OrderedMapIncompleteSize(scope *Scope, dtypes []tf.DataType, optional ...Or
return op.Output(0)
}
-// DecodeRawAttr is an optional argument to DecodeRaw.
-type DecodeRawAttr func(optionalAttr)
-
-// DecodeRawLittleEndian sets the optional little_endian attribute to value.
-//
-// value: Whether the input `bytes` are in little-endian order.
-// Ignored for `out_type` values that are stored in a single byte like
-// `uint8`.
-// If not specified, defaults to true
-func DecodeRawLittleEndian(value bool) DecodeRawAttr {
- return func(m optionalAttr) {
- m["little_endian"] = value
- }
-}
-
-// Reinterpret the bytes of a string as a vector of numbers.
-//
-// Arguments:
-// bytes: All the elements must have the same length.
-//
-//
-// Returns A Tensor with one more dimension than the input `bytes`. The
-// added dimension will have size equal to the length of the elements
-// of `bytes` divided by the number of bytes to represent `out_type`.
-func DecodeRaw(scope *Scope, bytes tf.Output, out_type tf.DataType, optional ...DecodeRawAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{"out_type": out_type}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "DecodeRaw",
- Input: []tf.Input{
- bytes,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
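What the op computes for `out_type=DT_INT32` can be illustrated with a stand-alone Go sketch of the same byte reinterpretation (the helper name is ours, not part of the API):

```go
package main

import (
	"encoding/binary"
	"fmt"
)

// decodeRawInt32 mirrors DecodeRaw with out_type=DT_INT32: each element of
// bytes is reinterpreted as a vector of int32, so the result has one extra
// dimension of size len(element)/4.
func decodeRawInt32(bytes []string, littleEndian bool) [][]int32 {
	var order binary.ByteOrder = binary.BigEndian
	if littleEndian {
		order = binary.LittleEndian
	}
	out := make([][]int32, len(bytes))
	for i, s := range bytes {
		row := make([]int32, len(s)/4)
		for j := range row {
			row[j] = int32(order.Uint32([]byte(s[4*j : 4*j+4])))
		}
		out[i] = row
	}
	return out
}

func main() {
	// Two elements, each 8 bytes long (all elements must have the same length).
	data := []string{
		"\x01\x00\x00\x00\x02\x00\x00\x00",
		"\x03\x00\x00\x00\x04\x00\x00\x00",
	}
	fmt.Println(decodeRawInt32(data, true)) // [[1 2] [3 4]]
}
```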
-
-// Copy a tensor setting everything outside a central band in each innermost matrix
-//
-// to zero.
-//
-// The `band` part is computed as follows:
-// Assume `input` has `k` dimensions `[I, J, K, ..., M, N]`, then the output is a
-// tensor with the same shape where
-//
-// `band[i, j, k, ..., m, n] = in_band(m, n) * input[i, j, k, ..., m, n]`.
-//
-// The indicator function
-//
-// `in_band(m, n) = (num_lower < 0 || (m-n) <= num_lower) &&
-// (num_upper < 0 || (n-m) <= num_upper)`.
-//
-// For example:
-//
-// ```
-// # if 'input' is [[ 0, 1, 2, 3]
-// [-1, 0, 1, 2]
-// [-2, -1, 0, 1]
-// [-3, -2, -1, 0]],
-//
-// tf.matrix_band_part(input, 1, -1) ==> [[ 0, 1, 2, 3]
-// [-1, 0, 1, 2]
-// [ 0, -1, 0, 1]
-// [ 0, 0, -1, 0]],
-//
-// tf.matrix_band_part(input, 2, 1) ==> [[ 0, 1, 0, 0]
-// [-1, 0, 1, 0]
-// [-2, -1, 0, 1]
-// [ 0, -2, -1, 0]]
-// ```
-//
-// Useful special cases:
-//
-// ```
-// tf.matrix_band_part(input, 0, -1) ==> Upper triangular part.
-// tf.matrix_band_part(input, -1, 0) ==> Lower triangular part.
-// tf.matrix_band_part(input, 0, 0) ==> Diagonal.
-// ```
-//
-// Arguments:
-// input: Rank `k` tensor.
-// num_lower: 0-D tensor. Number of subdiagonals to keep. If negative, keep entire
-// lower triangle.
-// num_upper: 0-D tensor. Number of superdiagonals to keep. If negative, keep
-// entire upper triangle.
-//
-// Returns Rank `k` tensor of the same shape as input. The extracted banded tensor.
-func MatrixBandPart(scope *Scope, input tf.Output, num_lower tf.Output, num_upper tf.Output) (band tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "MatrixBandPart",
- Input: []tf.Input{
- input, num_lower, num_upper,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
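The `in_band` indicator above is easy to check directly; a stand-alone Go sketch reproducing the documented 4x4 example (the helper name is ours):

```go
package main

import "fmt"

// matrixBandPart keeps the band defined by numLower/numUpper and zeroes the
// rest, following the in_band(m, n) indicator documented for MatrixBandPart.
func matrixBandPart(in [][]int, numLower, numUpper int) [][]int {
	out := make([][]int, len(in))
	for m := range in {
		out[m] = make([]int, len(in[m]))
		for n := range in[m] {
			inBand := (numLower < 0 || m-n <= numLower) &&
				(numUpper < 0 || n-m <= numUpper)
			if inBand {
				out[m][n] = in[m][n]
			}
		}
	}
	return out
}

func main() {
	input := [][]int{
		{0, 1, 2, 3},
		{-1, 0, 1, 2},
		{-2, -1, 0, 1},
		{-3, -2, -1, 0},
	}
	// Keep 1 subdiagonal and the full upper triangle.
	fmt.Println(matrixBandPart(input, 1, -1))
	// Keep 2 subdiagonals and 1 superdiagonal.
	fmt.Println(matrixBandPart(input, 2, 1))
}
```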
-
// Counts the number of occurrences of each value in an integer array.
//
// Outputs a vector with length `size` and the same dtype as `weights`. If
@@ -30483,214 +30710,3 @@ func MatrixSetDiag(scope *Scope, input tf.Output, diagonal tf.Output) (output tf
op := scope.AddOperation(opspec)
return op.Output(0)
}
-
-// EditDistanceAttr is an optional argument to EditDistance.
-type EditDistanceAttr func(optionalAttr)
-
-// EditDistanceNormalize sets the optional normalize attribute to value.
-//
-// value: boolean (if true, edit distances are normalized by the length of truth).
-// If not specified, defaults to true
-func EditDistanceNormalize(value bool) EditDistanceAttr {
- return func(m optionalAttr) {
- m["normalize"] = value
- }
-}
-
-// Computes the (possibly normalized) Levenshtein Edit Distance.
-//
-// The inputs are variable-length sequences provided by SparseTensors
-// (hypothesis_indices, hypothesis_values, hypothesis_shape)
-// and
-// (truth_indices, truth_values, truth_shape).
-//
-// Arguments:
-// hypothesis_indices: The indices of the hypothesis list SparseTensor.
-// This is an N x R int64 matrix.
-// hypothesis_values: The values of the hypothesis list SparseTensor.
-// This is an N-length vector.
-// hypothesis_shape: The shape of the hypothesis list SparseTensor.
-// This is an R-length vector.
-// truth_indices: The indices of the truth list SparseTensor.
-// This is an M x R int64 matrix.
-// truth_values: The values of the truth list SparseTensor.
-// This is an M-length vector.
-// truth_shape: The shape of the truth list SparseTensor.
-// This is an R-length vector.
-//
-// Returns A dense float tensor with rank R - 1.
-//
-// For the example input:
-//
-// // hypothesis represents a 2x1 matrix with variable-length values:
-// // (0,0) = ["a"]
-// // (1,0) = ["b"]
-// hypothesis_indices = [[0, 0, 0],
-// [1, 0, 0]]
-// hypothesis_values = ["a", "b"]
-// hypothesis_shape = [2, 1, 1]
-//
-// // truth represents a 2x2 matrix with variable-length values:
-// // (0,0) = []
-// // (0,1) = ["a"]
-// // (1,0) = ["b", "c"]
-// // (1,1) = ["a"]
-// truth_indices = [[0, 1, 0],
-// [1, 0, 0],
-// [1, 0, 1],
-// [1, 1, 0]]
-// truth_values = ["a", "b", "c", "a"]
-// truth_shape = [2, 2, 2]
-// normalize = true
-//
-// The output will be:
-//
-// // output is a 2x2 matrix with edit distances normalized by truth lengths.
-// output = [[inf, 1.0], // (0,0): no truth, (0,1): no hypothesis
-// [0.5, 1.0]] // (1,0): addition, (1,1): no hypothesis
-func EditDistance(scope *Scope, hypothesis_indices tf.Output, hypothesis_values tf.Output, hypothesis_shape tf.Output, truth_indices tf.Output, truth_values tf.Output, truth_shape tf.Output, optional ...EditDistanceAttr) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- attrs := map[string]interface{}{}
- for _, a := range optional {
- a(attrs)
- }
- opspec := tf.OpSpec{
- Type: "EditDistance",
- Input: []tf.Input{
- hypothesis_indices, hypothesis_values, hypothesis_shape, truth_indices, truth_values, truth_shape,
- },
- Attrs: attrs,
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
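The normalization in the worked example above can be reproduced with a small stand-alone Levenshtein sketch over token sequences (plain Go, no TensorFlow; it checks the (1,0) entry of the example output):

```go
package main

import "fmt"

// levenshtein returns the edit distance between two token sequences,
// using a two-row dynamic-programming table.
func levenshtein(hyp, truth []string) int {
	prev := make([]int, len(truth)+1)
	curr := make([]int, len(truth)+1)
	for j := range prev {
		prev[j] = j
	}
	for i := 1; i <= len(hyp); i++ {
		curr[0] = i
		for j := 1; j <= len(truth); j++ {
			cost := 1
			if hyp[i-1] == truth[j-1] {
				cost = 0
			}
			curr[j] = min(min(curr[j-1]+1, prev[j]+1), prev[j-1]+cost)
		}
		prev, curr = curr, prev
	}
	return prev[len(truth)]
}

func min(a, b int) int {
	if a < b {
		return a
	}
	return b
}

func main() {
	// Entry (1,0) of the example: hypothesis ["b"] against truth ["b", "c"].
	hyp := []string{"b"}
	truth := []string{"b", "c"}
	d := levenshtein(hyp, truth)
	normalized := float64(d) / float64(len(truth))
	fmt.Println(d, normalized) // 1 0.5, matching output[1][0] above
}
```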
-
-// Gather slices from `params` into a Tensor with shape specified by `indices`.
-//
-// `indices` is a K-dimensional integer tensor, best thought of as a
-// (K-1)-dimensional tensor of indices into `params`, where each element defines a
-// slice of `params`:
-//
-// output[i_0, ..., i_{K-2}] = params[indices[i0, ..., i_{K-2}]]
-//
-// Whereas in @{tf.gather} `indices` defines slices into the first
-// dimension of `params`, in `tf.gather_nd`, `indices` defines slices into the
-// first `N` dimensions of `params`, where `N = indices.shape[-1]`.
-//
-// The last dimension of `indices` can be at most the rank of
-// `params`:
-//
-// indices.shape[-1] <= params.rank
-//
-// The last dimension of `indices` corresponds to elements
-// (if `indices.shape[-1] == params.rank`) or slices
-// (if `indices.shape[-1] < params.rank`) along dimension `indices.shape[-1]`
-// of `params`. The output tensor has shape
-//
-// indices.shape[:-1] + params.shape[indices.shape[-1]:]
-//
-// Note that on CPU, if an out of bound index is found, an error is returned.
-// On GPU, if an out of bound index is found, a 0 is stored in the
-// corresponding output value.
-//
-// Some examples below.
-//
-// Simple indexing into a matrix:
-//
-// ```python
-// indices = [[0, 0], [1, 1]]
-// params = [['a', 'b'], ['c', 'd']]
-// output = ['a', 'd']
-// ```
-//
-// Slice indexing into a matrix:
-//
-// ```python
-// indices = [[1], [0]]
-// params = [['a', 'b'], ['c', 'd']]
-// output = [['c', 'd'], ['a', 'b']]
-// ```
-//
-// Indexing into a 3-tensor:
-//
-// ```python
-// indices = [[1]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = [[['a1', 'b1'], ['c1', 'd1']]]
-//
-//
-// indices = [[0, 1], [1, 0]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = [['c0', 'd0'], ['a1', 'b1']]
-//
-//
-// indices = [[0, 0, 1], [1, 0, 1]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = ['b0', 'b1']
-// ```
-//
-// Batched indexing into a matrix:
-//
-// ```python
-// indices = [[[0, 0]], [[0, 1]]]
-// params = [['a', 'b'], ['c', 'd']]
-// output = [['a'], ['b']]
-// ```
-//
-// Batched slice indexing into a matrix:
-//
-// ```python
-// indices = [[[1]], [[0]]]
-// params = [['a', 'b'], ['c', 'd']]
-// output = [[['c', 'd']], [['a', 'b']]]
-// ```
-//
-// Batched indexing into a 3-tensor:
-//
-// ```python
-// indices = [[[1]], [[0]]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = [[[['a1', 'b1'], ['c1', 'd1']]],
-// [[['a0', 'b0'], ['c0', 'd0']]]]
-//
-// indices = [[[0, 1], [1, 0]], [[0, 0], [1, 1]]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = [[['c0', 'd0'], ['a1', 'b1']],
-// [['a0', 'b0'], ['c1', 'd1']]]
-//
-//
-// indices = [[[0, 0, 1], [1, 0, 1]], [[0, 1, 1], [1, 1, 0]]]
-// params = [[['a0', 'b0'], ['c0', 'd0']],
-// [['a1', 'b1'], ['c1', 'd1']]]
-// output = [['b0', 'b1'], ['d0', 'c1']]
-// ```
-//
-// Arguments:
-// params: The tensor from which to gather values.
-// indices: Index tensor.
-//
-// Returns Values from `params` gathered from indices given by `indices`, with
-// shape `indices.shape[:-1] + params.shape[indices.shape[-1]:]`.
-func GatherNd(scope *Scope, params tf.Output, indices tf.Output) (output tf.Output) {
- if scope.Err() != nil {
- return
- }
- opspec := tf.OpSpec{
- Type: "GatherNd",
- Input: []tf.Input{
- params, indices,
- },
- }
- op := scope.AddOperation(opspec)
- return op.Output(0)
-}
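The element- and slice-indexing semantics documented above can be reproduced directly in plain Go (no TensorFlow involved; this mirrors the first two examples):

```go
package main

import "fmt"

func main() {
	params := [][]string{{"a", "b"}, {"c", "d"}}

	// Simple indexing: indices.shape[-1] == params.rank, so each index row
	// selects a single element.
	indices := [][]int{{0, 0}, {1, 1}}
	elems := make([]string, len(indices))
	for i, idx := range indices {
		elems[i] = params[idx[0]][idx[1]]
	}
	fmt.Println(elems) // [a d]

	// Slice indexing: indices.shape[-1] < params.rank, so each index row
	// selects a whole slice along the first dimension.
	sliceIndices := [][]int{{1}, {0}}
	slices := make([][]string, len(sliceIndices))
	for i, idx := range sliceIndices {
		slices[i] = params[idx[0]]
	}
	fmt.Println(slices) // [[c d] [a b]]
}
```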
diff --git a/tensorflow/java/maven/libtensorflow/pom.xml b/tensorflow/java/maven/libtensorflow/pom.xml
index 9c1601753b..66985e3b18 100644
--- a/tensorflow/java/maven/libtensorflow/pom.xml
+++ b/tensorflow/java/maven/libtensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni/pom.xml b/tensorflow/java/maven/libtensorflow_jni/pom.xml
index 3d013e12b0..34d4ba0b08 100644
--- a/tensorflow/java/maven/libtensorflow_jni/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni</artifactId>
diff --git a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
index 40e44af1f5..1909d08e41 100644
--- a/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
+++ b/tensorflow/java/maven/libtensorflow_jni_gpu/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>libtensorflow_jni_gpu</artifactId>
diff --git a/tensorflow/java/maven/pom.xml b/tensorflow/java/maven/pom.xml
index 82bfd0c73a..ba98732f5a 100644
--- a/tensorflow/java/maven/pom.xml
+++ b/tensorflow/java/maven/pom.xml
@@ -6,7 +6,7 @@
<modelVersion>4.0.0</modelVersion>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<packaging>pom</packaging>
<url>https://www.tensorflow.org</url>
diff --git a/tensorflow/java/maven/proto/pom.xml b/tensorflow/java/maven/proto/pom.xml
index 0a2775a500..dee8c34359 100644
--- a/tensorflow/java/maven/proto/pom.xml
+++ b/tensorflow/java/maven/proto/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>proto</artifactId>
diff --git a/tensorflow/java/maven/tensorflow/pom.xml b/tensorflow/java/maven/tensorflow/pom.xml
index 61961432a7..95e024ace9 100644
--- a/tensorflow/java/maven/tensorflow/pom.xml
+++ b/tensorflow/java/maven/tensorflow/pom.xml
@@ -6,7 +6,7 @@
<parent>
<groupId>org.tensorflow</groupId>
<artifactId>parentpom</artifactId>
- <version>1.8.0-rc0</version>
+ <version>1.8.0-rc1</version>
<relativePath>../</relativePath>
</parent>
<artifactId>tensorflow</artifactId>
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index bb32f4bbe0..8e7f0cadad 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -70,6 +70,7 @@ py_library(
srcs_version = "PY2AND3",
visibility = [
"//tensorflow:__pkg__",
+ "//tensorflow/python/tools:__pkg__",
],
deps = [
":array_ops",
diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py
index 13f8420a67..cf707fb2c7 100644
--- a/tensorflow/python/__init__.py
+++ b/tensorflow/python/__init__.py
@@ -120,31 +120,9 @@ from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import sysconfig
from tensorflow.python.platform import test
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.all_util import make_all
from tensorflow.python.util.tf_export import tf_export
-# Import modules whose docstrings contribute, for use by remove_undocumented
-# below.
-from tensorflow.python.client import client_lib
-from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import framework_lib
-from tensorflow.python.framework import subscribe
-from tensorflow.python.ops import array_ops
-from tensorflow.python.ops import check_ops
-from tensorflow.python.ops import confusion_matrix as confusion_matrix_m
-from tensorflow.python.ops import control_flow_ops
-from tensorflow.python.ops import functional_ops
-from tensorflow.python.ops import histogram_ops
-from tensorflow.python.ops import io_ops
-from tensorflow.python.ops import math_ops
-from tensorflow.python.ops import script_ops
-from tensorflow.python.ops import session_ops
-from tensorflow.python.ops import sparse_ops
-from tensorflow.python.ops import state_ops
-from tensorflow.python.ops import string_ops
-from tensorflow.python.ops import tensor_array_ops
-
# Eager execution
from tensorflow.python.eager.context import executing_eagerly
from tensorflow.python.framework.ops import enable_eager_execution
@@ -160,37 +138,9 @@ nn.dynamic_rnn = rnn.dynamic_rnn
nn.static_rnn = rnn.static_rnn
nn.raw_rnn = rnn.raw_rnn
nn.bidirectional_dynamic_rnn = rnn.bidirectional_dynamic_rnn
+nn.static_state_saving_rnn = rnn.static_state_saving_rnn
nn.rnn_cell = rnn_cell
-# Symbols whitelisted for export without documentation.
-# TODO(cwhipkey): review these and move to contrib, expose through
-# documentation, or remove.
-_allowed_symbols = [
- 'AttrValue',
- 'ConfigProto',
- 'ClusterDef',
- 'DeviceSpec',
- 'Event',
- 'GPUOptions',
- 'GRAPH_DEF_VERSION',
- 'GRAPH_DEF_VERSION_MIN_CONSUMER',
- 'GRAPH_DEF_VERSION_MIN_PRODUCER',
- 'GraphDef',
- 'GraphOptions',
- 'HistogramProto',
- 'LogMessage',
- 'MetaGraphDef',
- 'NameAttrList',
- 'NodeDef',
- 'OptimizerOptions',
- 'RunOptions',
- 'RunMetadata',
- 'SessionLog',
- 'Summary',
- 'SummaryMetadata',
- 'TensorInfo', # Used for tf.saved_model functionality.
-]
-
# Export protos
# pylint: disable=undefined-variable
tf_export('AttrValue')(AttrValue)
@@ -215,121 +165,6 @@ tf_export('summary.TaggedRunMetadata')(TaggedRunMetadata)
tf_export('TensorInfo')(TensorInfo)
# pylint: enable=undefined-variable
-
-# The following symbols are kept for compatibility. It is our plan
-# to remove them in the future.
-_allowed_symbols.extend([
- 'arg_max',
- 'arg_min',
- 'create_partitioned_variables',
- 'deserialize_many_sparse',
- 'lin_space',
- 'listdiff', # Use tf.listdiff instead.
- 'parse_single_sequence_example',
- 'serialize_many_sparse',
- 'serialize_sparse',
- 'sparse_matmul', ## use tf.matmul instead.
-])
-
-# This is needed temporarily because we import it explicitly.
-_allowed_symbols.extend([
- 'pywrap_tensorflow',
-])
-
-# Dtypes exported by framework/dtypes.py.
-# TODO(cwhipkey): expose these through documentation.
-_allowed_symbols.extend([
- 'QUANTIZED_DTYPES',
- 'bfloat16',
- 'bool',
- 'complex64',
- 'complex128',
- 'double',
- 'half',
- 'float16',
- 'float32',
- 'float64',
- 'int16',
- 'int32',
- 'int64',
- 'int8',
- 'qint16',
- 'qint32',
- 'qint8',
- 'quint16',
- 'quint8',
- 'string',
- 'uint64',
- 'uint32',
- 'uint16',
- 'uint8',
- 'resource',
- 'variant',
-])
-
-# Export modules and constants.
-_allowed_symbols.extend([
- 'app',
- 'bitwise',
- 'compat',
- 'data',
- 'distributions',
- 'errors',
- 'estimator',
- 'feature_column',
- 'flags',
- 'gfile',
- 'graph_util',
- 'image',
- 'initializers',
- 'keras',
- 'layers',
- 'linalg',
- 'logging',
- 'losses',
- 'manip',
- 'metrics',
- 'newaxis',
- 'nn',
- 'profiler',
- 'python_io',
- 'resource_loader',
- 'saved_model',
- 'sets',
- 'spectral',
- 'summary',
- 'sysconfig',
- 'test',
- 'train',
- 'user_ops',
-])
-
-# Variables framework.versions:
-_allowed_symbols.extend([
- 'VERSION',
- 'GIT_VERSION',
- 'COMPILER_VERSION',
- 'CXX11_ABI_FLAG',
- 'MONOLITHIC_BUILD',
-])
-
-# Eager execution
-_allowed_symbols.extend([
- 'enable_eager_execution',
- 'executing_eagerly',
-])
-
-# Remove all extra symbols that don't have a docstring or are not explicitly
-# referenced in the whitelist.
-remove_undocumented(__name__, _allowed_symbols, [
- framework_lib, array_ops, check_ops, client_lib, compat, constant_op,
- control_flow_ops, confusion_matrix_m, data, distributions,
- functional_ops, histogram_ops, io_ops, keras, layers,
- losses, math_ops, metrics, nn, profiler, resource_loader, sets, script_ops,
- session_ops, sparse_ops, state_ops, string_ops, summary, tensor_array_ops,
- train
-])
-
# Special dunders that we choose to export:
_exported_dunders = set([
'__version__',
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index b82182d5d3..1db1432d65 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -458,7 +458,7 @@ TF_ImportGraphDefResultsMissingUnusedInputMappings_wrapper{
}
// Override default py3 behavior of attempting to encode into Unicode.
-%typemap(out) std::string tensorflow::ResourceHandleShapeAndType {
+%typemap(out) std::string tensorflow::GetResourceHandleShapeAndType {
$result = PyBytes_FromStringAndSize($1.data(), $1.size());
}
diff --git a/tensorflow/python/data/__init__.py b/tensorflow/python/data/__init__.py
index 239f9b0d59..5cedb89bf8 100644
--- a/tensorflow/python/data/__init__.py
+++ b/tensorflow/python/data/__init__.py
@@ -34,6 +34,3 @@ from tensorflow.python.data.ops.readers import FixedLengthRecordDataset
from tensorflow.python.data.ops.readers import TextLineDataset
from tensorflow.python.data.ops.readers import TFRecordDataset
# pylint: enable=unused-import
-
-from tensorflow.python.util.all_util import remove_undocumented
-remove_undocumented(__name__)
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index e90ce3fb40..eff6e02c14 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -44,7 +44,6 @@ import collections as _collections
import six as _six
from tensorflow.python.framework import sparse_tensor as _sparse_tensor
-from tensorflow.python.util.all_util import remove_undocumented
def _sorted(dict_):
@@ -538,16 +537,3 @@ def map_structure_up_to(shallow_tree, func, *inputs):
results = [func(*tensors) for tensors in zip(*all_flattened_up_to)]
return pack_sequence_as(structure=shallow_tree, flat_sequence=results)
-
-_allowed_symbols = [
- "assert_same_structure",
- "is_sequence",
- "flatten",
- "pack_sequence_as",
- "map_structure",
- "assert_shallow_structure",
- "flatten_up_to",
- "map_structure_up_to",
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 0f1170bb42..bdbbe864df 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -80,7 +80,7 @@ def capture_value(tensor_map, value, dtype, name):
if handle_data is not None and handle_data.is_set:
# pylint: disable=protected-access
if ops._USE_C_SHAPES:
- pywrap_tensorflow.TFE_SetResourceHandleShapeAndType(
+ pywrap_tensorflow.SetResourceHandleShapeAndType(
captured_value.graph._c_graph, captured_value._as_tf_output(),
handle_data.SerializeToString())
else:
@@ -405,7 +405,15 @@ class GraphModeFunction(object):
c_known_ops = set()
c_captured_tensors = set()
- def add_op_internal(op):
+ existing_op_len = len(self._graph.get_operations())
+ filtered_outputs = [x for x in self._returns if x is not None]
+ self._out_grad_placeholders = [
+ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
+ in_gradients = gradients_impl.gradients(
+ filtered_outputs,
+ self._input_placeholders,
+ grad_ys=self._out_grad_placeholders)
+ for op in self._graph.get_operations()[existing_op_len:]:
if op.type in ["Variable", "VariableV2", "VarHandleOp"]:
raise ValueError("tfe.defun cannot capture variables created without "
"using tf.get_variable. Op: %s" % op)
@@ -414,17 +422,6 @@ class GraphModeFunction(object):
if i.op not in c_known_ops:
c_captured_tensors.add(i)
- c = HelperContext(add_op_internal)
-
- with c:
- filtered_outputs = [x for x in self._returns if x is not None]
- self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
- in_gradients = gradients_impl.gradients(
- filtered_outputs,
- self._input_placeholders,
- grad_ys=self._out_grad_placeholders)
-
backward_outputs = tuple(
grad for grad in _flatten(in_gradients) if grad is not None)
output_shapes = tuple(grad.shape for grad in backward_outputs)
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index 887e9a474a..2f1212d5a2 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -100,10 +100,6 @@ class Estimator(object):
None of `Estimator`'s methods can be overridden in subclasses (its
constructor enforces this). Subclasses should use `model_fn` to configure
the base class, and may add methods implementing specialized functionality.
-
- @compatibility(eager)
- Estimators are not compatible with eager execution.
- @end_compatibility
"""
def __init__(self, model_fn, model_dir=None, config=None, params=None,
@@ -166,15 +162,10 @@ class Estimator(object):
vocabularies and Tensor names are unchanged.
Raises:
- RuntimeError: If eager execution is enabled.
ValueError: parameters of `model_fn` don't match `params`.
ValueError: if this is called via a subclass and if that class overrides
a member of `Estimator`.
"""
- if context.executing_eagerly():
- raise RuntimeError(
- 'Estimators are not supported when eager execution is enabled.')
-
Estimator._assert_members_are_not_overridden(self)
if config is None:
@@ -270,7 +261,8 @@ class Estimator(object):
ValueError: If the Estimator has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
- return training.load_variable(self.model_dir, name)
+ with context.graph_mode():
+ return training.load_variable(self.model_dir, name)
def get_variable_names(self):
"""Returns list of all variable names in this model.
@@ -282,7 +274,8 @@ class Estimator(object):
ValueError: If the Estimator has not produced a checkpoint yet.
"""
_check_checkpoint_available(self.model_dir)
- return [name for name, _ in training.list_variables(self.model_dir)]
+ with context.graph_mode():
+ return [name for name, _ in training.list_variables(self.model_dir)]
def latest_checkpoint(self):
"""Finds the filename of latest saved checkpoint file in `model_dir`.
@@ -291,7 +284,8 @@ class Estimator(object):
The full path to the latest checkpoint or `None` if no checkpoint was
found.
"""
- return saver.latest_checkpoint(self.model_dir)
+ with context.graph_mode():
+ return saver.latest_checkpoint(self.model_dir)
def train(self,
input_fn,
@@ -343,27 +337,28 @@ class Estimator(object):
ValueError: If both `steps` and `max_steps` are not `None`.
ValueError: If either `steps` or `max_steps` is <= 0.
"""
- if (steps is not None) and (max_steps is not None):
- raise ValueError('Can not provide both steps and max_steps.')
- if steps is not None and steps <= 0:
- raise ValueError('Must specify steps > 0, given: {}'.format(steps))
- if max_steps is not None and max_steps <= 0:
- raise ValueError(
- 'Must specify max_steps > 0, given: {}'.format(max_steps))
+ with context.graph_mode():
+ if (steps is not None) and (max_steps is not None):
+ raise ValueError('Can not provide both steps and max_steps.')
+ if steps is not None and steps <= 0:
+ raise ValueError('Must specify steps > 0, given: {}'.format(steps))
+ if max_steps is not None and max_steps <= 0:
+ raise ValueError(
+ 'Must specify max_steps > 0, given: {}'.format(max_steps))
- if max_steps is not None:
- start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
- if max_steps <= start_step:
- logging.info('Skipping training since max_steps has already saved.')
- return self
+ if max_steps is not None:
+ start_step = _load_global_step_from_checkpoint_dir(self._model_dir)
+ if max_steps <= start_step:
+ logging.info('Skipping training since max_steps has already saved.')
+ return self
- hooks = _check_hooks_type(hooks)
- hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
+ hooks = _check_hooks_type(hooks)
+ hooks.extend(self._convert_train_steps_to_hooks(steps, max_steps))
- saving_listeners = _check_listeners_type(saving_listeners)
- loss = self._train_model(input_fn, hooks, saving_listeners)
- logging.info('Loss for final step: %s.', loss)
- return self
+ saving_listeners = _check_listeners_type(saving_listeners)
+ loss = self._train_model(input_fn, hooks, saving_listeners)
+ logging.info('Loss for final step: %s.', loss)
+ return self
def _convert_train_steps_to_hooks(self, steps, max_steps):
if steps is not None or max_steps is not None:
@@ -416,14 +411,15 @@ class Estimator(object):
ValueError: If no model has been trained, namely `model_dir`, or the
given `checkpoint_path` is empty.
"""
- hooks = _check_hooks_type(hooks)
- hooks.extend(self._convert_eval_steps_to_hooks(steps))
+ with context.graph_mode():
+ hooks = _check_hooks_type(hooks)
+ hooks.extend(self._convert_eval_steps_to_hooks(steps))
- return self._evaluate_model(
- input_fn=input_fn,
- hooks=hooks,
- checkpoint_path=checkpoint_path,
- name=name)
+ return self._evaluate_model(
+ input_fn=input_fn,
+ hooks=hooks,
+ checkpoint_path=checkpoint_path,
+ name=name)
def _convert_eval_steps_to_hooks(self, steps):
if steps is None:
@@ -480,45 +476,48 @@ class Estimator(object):
`predictions`. For example if `predict_keys` is not `None` but
`EstimatorSpec.predictions` is not a `dict`.
"""
- hooks = _check_hooks_type(hooks)
- # Check that model has been trained.
- if not checkpoint_path:
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
- if not checkpoint_path:
- raise ValueError('Could not find trained model in model_dir: {}.'.format(
- self._model_dir))
+ with context.graph_mode():
+ hooks = _check_hooks_type(hooks)
+ # Check that model has been trained.
+ if not checkpoint_path:
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise ValueError(
+ 'Could not find trained model in model_dir: {}.'.format(
+ self._model_dir))
- with ops.Graph().as_default() as g:
- random_seed.set_random_seed(self._config.tf_random_seed)
- self._create_and_assert_global_step(g)
- features, input_hooks = self._get_features_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.PREDICT)
- estimator_spec = self._call_model_fn(
- features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
- predictions = self._extract_keys(estimator_spec.predictions, predict_keys)
- all_hooks = list(input_hooks)
- all_hooks.extend(hooks)
- all_hooks.extend(list(estimator_spec.prediction_hooks or []))
- with training.MonitoredSession(
- session_creator=training.ChiefSessionCreator(
- checkpoint_filename_with_path=checkpoint_path,
- master=self._config.master,
- scaffold=estimator_spec.scaffold,
- config=self._session_config),
- hooks=all_hooks) as mon_sess:
- while not mon_sess.should_stop():
- preds_evaluated = mon_sess.run(predictions)
- if not yield_single_examples:
- yield preds_evaluated
- elif not isinstance(predictions, dict):
- for pred in preds_evaluated:
- yield pred
- else:
- for i in range(self._extract_batch_length(preds_evaluated)):
- yield {
- key: value[i]
- for key, value in six.iteritems(preds_evaluated)
- }
+ with ops.Graph().as_default() as g:
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ self._create_and_assert_global_step(g)
+ features, input_hooks = self._get_features_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.PREDICT)
+ estimator_spec = self._call_model_fn(
+ features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
+ predictions = self._extract_keys(
+ estimator_spec.predictions, predict_keys)
+ all_hooks = list(input_hooks)
+ all_hooks.extend(hooks)
+ all_hooks.extend(list(estimator_spec.prediction_hooks or []))
+ with training.MonitoredSession(
+ session_creator=training.ChiefSessionCreator(
+ checkpoint_filename_with_path=checkpoint_path,
+ master=self._config.master,
+ scaffold=estimator_spec.scaffold,
+ config=self._session_config),
+ hooks=all_hooks) as mon_sess:
+ while not mon_sess.should_stop():
+ preds_evaluated = mon_sess.run(predictions)
+ if not yield_single_examples:
+ yield preds_evaluated
+ elif not isinstance(predictions, dict):
+ for pred in preds_evaluated:
+ yield pred
+ else:
+ for i in range(self._extract_batch_length(preds_evaluated)):
+ yield {
+ key: value[i]
+ for key, value in six.iteritems(preds_evaluated)
+ }
def _assert_members_are_not_overridden(self):
"""Asserts members of `Estimator` are not overridden."""
@@ -598,73 +597,75 @@ class Estimator(object):
are provided, or no checkpoint can be found.
"""
# pylint: enable=line-too-long
- if serving_input_receiver_fn is None:
- raise ValueError('serving_input_receiver_fn must be defined.')
-
- with ops.Graph().as_default() as g:
- self._create_and_assert_global_step(g)
- random_seed.set_random_seed(self._config.tf_random_seed)
- serving_input_receiver = serving_input_receiver_fn()
+ with context.graph_mode():
+ if serving_input_receiver_fn is None:
+ raise ValueError('serving_input_receiver_fn must be defined.')
- # Call the model_fn and collect the export_outputs.
- estimator_spec = self._call_model_fn(
- features=serving_input_receiver.features,
- labels=None,
- mode=model_fn_lib.ModeKeys.PREDICT,
- config=self.config)
-
- # Build the SignatureDefs from receivers and all outputs
- signature_def_map = build_all_signature_defs(
- serving_input_receiver.receiver_tensors,
- estimator_spec.export_outputs,
- serving_input_receiver.receiver_tensors_alternatives)
-
- if not checkpoint_path:
- # Locate the latest checkpoint
- checkpoint_path = saver.latest_checkpoint(self._model_dir)
- if not checkpoint_path:
- raise ValueError("Couldn't find trained model at %s." % self._model_dir)
-
- export_dir = get_timestamped_export_dir(export_dir_base)
- temp_export_dir = get_temp_export_dir(export_dir)
-
- # TODO(soergel): Consider whether MonitoredSession makes sense here
- with tf_session.Session(config=self._session_config) as session:
-
- saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
- sharded=True)
- saver_for_restore.restore(session, checkpoint_path)
-
- # pylint: disable=protected-access
- local_init_op = (
- estimator_spec.scaffold.local_init_op or
- monitored_session.Scaffold.default_local_init_op())
- # pylint: enable=protected-access
-
- # Perform the export
- builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
- builder.add_meta_graph_and_variables(
- session, [tag_constants.SERVING],
- signature_def_map=signature_def_map,
- assets_collection=ops.get_collection(
- ops.GraphKeys.ASSET_FILEPATHS),
- legacy_init_op=local_init_op,
- strip_default_attrs=strip_default_attrs)
- builder.save(as_text)
-
- # Add the extra assets
- if assets_extra:
- assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
- compat.as_bytes('assets.extra'))
- for dest_relative, source in assets_extra.items():
- dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
- compat.as_bytes(dest_relative))
- dest_path = os.path.dirname(dest_absolute)
- gfile.MakeDirs(dest_path)
- gfile.Copy(source, dest_absolute)
-
- gfile.Rename(temp_export_dir, export_dir)
- return export_dir
+ with ops.Graph().as_default() as g:
+ self._create_and_assert_global_step(g)
+ random_seed.set_random_seed(self._config.tf_random_seed)
+ serving_input_receiver = serving_input_receiver_fn()
+
+ # Call the model_fn and collect the export_outputs.
+ estimator_spec = self._call_model_fn(
+ features=serving_input_receiver.features,
+ labels=None,
+ mode=model_fn_lib.ModeKeys.PREDICT,
+ config=self.config)
+
+ # Build the SignatureDefs from receivers and all outputs
+ signature_def_map = build_all_signature_defs(
+ serving_input_receiver.receiver_tensors,
+ estimator_spec.export_outputs,
+ serving_input_receiver.receiver_tensors_alternatives)
+
+ if not checkpoint_path:
+ # Locate the latest checkpoint
+ checkpoint_path = saver.latest_checkpoint(self._model_dir)
+ if not checkpoint_path:
+ raise ValueError(
+ "Couldn't find trained model at %s." % self._model_dir)
+
+ export_dir = get_timestamped_export_dir(export_dir_base)
+ temp_export_dir = get_temp_export_dir(export_dir)
+
+ # TODO(soergel): Consider whether MonitoredSession makes sense here
+ with tf_session.Session(config=self._session_config) as session:
+
+ saver_for_restore = estimator_spec.scaffold.saver or saver.Saver(
+ sharded=True)
+ saver_for_restore.restore(session, checkpoint_path)
+
+ # pylint: disable=protected-access
+ local_init_op = (
+ estimator_spec.scaffold.local_init_op or
+ monitored_session.Scaffold._default_local_init_op())
+ # pylint: enable=protected-access
+
+ # Perform the export
+ builder = saved_model_builder.SavedModelBuilder(temp_export_dir)
+ builder.add_meta_graph_and_variables(
+ session, [tag_constants.SERVING],
+ signature_def_map=signature_def_map,
+ assets_collection=ops.get_collection(
+ ops.GraphKeys.ASSET_FILEPATHS),
+ legacy_init_op=local_init_op,
+ strip_default_attrs=strip_default_attrs)
+ builder.save(as_text)
+
+ # Add the extra assets
+ if assets_extra:
+ assets_extra_path = os.path.join(compat.as_bytes(temp_export_dir),
+ compat.as_bytes('assets.extra'))
+ for dest_relative, source in assets_extra.items():
+ dest_absolute = os.path.join(compat.as_bytes(assets_extra_path),
+ compat.as_bytes(dest_relative))
+ dest_path = os.path.dirname(dest_absolute)
+ gfile.MakeDirs(dest_path)
+ gfile.Copy(source, dest_absolute)
+
+ gfile.Rename(temp_export_dir, export_dir)
+ return export_dir
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
diff --git a/tensorflow/python/estimator/estimator_lib.py b/tensorflow/python/estimator/estimator_lib.py
index 60c59cbc18..3815f42470 100644
--- a/tensorflow/python/estimator/estimator_lib.py
+++ b/tensorflow/python/estimator/estimator_lib.py
@@ -47,45 +47,4 @@ from tensorflow.python.estimator.training import train_and_evaluate
from tensorflow.python.estimator.training import TrainSpec
-from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long,wildcard-import
-
-_allowed_symbols = [
- # Canned Estimators
- 'BaselineClassifier',
- 'BaselineRegressor',
- 'BoostedTreesClassifier',
- 'BoostedTreesRegressor',
- 'DNNClassifier',
- 'DNNRegressor',
- 'DNNLinearCombinedClassifier',
- 'DNNLinearCombinedRegressor',
- 'LinearClassifier',
- 'LinearRegressor',
-
- # I/O
- 'classifier_parse_example_spec',
- 'regressor_parse_example_spec',
- 'inputs',
- 'export',
-
- # Estimator
- 'Estimator',
- 'EstimatorSpec',
- 'ModeKeys',
- 'RunConfig',
-
- # Training utilities
- 'train_and_evaluate',
- 'EvalSpec',
- 'TrainSpec',
- 'Exporter',
- 'LatestExporter',
- 'FinalExporter',
-
- # Warm-starting
- 'WarmStartSettings',
- 'VocabInfo',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index f4255091bf..0fea86124c 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -679,8 +679,10 @@ class EstimatorTrainTest(test.TestCase):
ckpt = checkpoint_state_pb2.CheckpointState()
text_format.Merge(checkpoint_file_content, ckpt)
self.assertEqual(ckpt.model_checkpoint_path, 'model.ckpt-5')
+ # TODO(b/78461127): Please modify tests to not directly rely on names of
+ # checkpoints.
self.assertAllEqual(
- ['model.ckpt-1', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
+ ['model.ckpt-0', 'model.ckpt-5'], ckpt.all_model_checkpoint_paths)
def test_train_save_copy_reload(self):
tmpdir = tempfile.mkdtemp()
@@ -2287,6 +2289,7 @@ class EstimatorHookOrderingTest(test.TestCase):
class EstimatorIntegrationTest(test.TestCase):
+ @test_util.run_in_graph_and_eager_modes()
def test_complete_flow_with_a_simple_linear_model(self):
def _model_fn(features, labels, mode):
diff --git a/tensorflow/python/estimator/export/export_lib.py b/tensorflow/python/estimator/export/export_lib.py
index 226fc97fd3..f4ac8581ea 100644
--- a/tensorflow/python/estimator/export/export_lib.py
+++ b/tensorflow/python/estimator/export/export_lib.py
@@ -28,18 +28,5 @@ from tensorflow.python.estimator.export.export_output import ExportOutput
from tensorflow.python.estimator.export.export_output import PredictOutput
from tensorflow.python.estimator.export.export_output import RegressionOutput
-from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long
-_allowed_symbols = [
- 'build_parsing_serving_input_receiver_fn',
- 'build_raw_serving_input_receiver_fn',
- 'ServingInputReceiver',
- 'TensorServingInputReceiver',
- 'ClassificationOutput',
- 'ExportOutput',
- 'PredictOutput',
- 'RegressionOutput',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/estimator/inputs/inputs.py b/tensorflow/python/estimator/inputs/inputs.py
index 1a1c9a6c3f..6be168ee08 100644
--- a/tensorflow/python/estimator/inputs/inputs.py
+++ b/tensorflow/python/estimator/inputs/inputs.py
@@ -22,12 +22,4 @@ from __future__ import print_function
from tensorflow.python.estimator.inputs.numpy_io import numpy_input_fn
from tensorflow.python.estimator.inputs.pandas_io import pandas_input_fn
-from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
- 'numpy_input_fn',
- 'pandas_input_fn'
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/feature_column/feature_column_lib.py b/tensorflow/python/feature_column/feature_column_lib.py
index 505a1408d2..3b818f18b5 100644
--- a/tensorflow/python/feature_column/feature_column_lib.py
+++ b/tensorflow/python/feature_column/feature_column_lib.py
@@ -20,25 +20,4 @@ from __future__ import print_function
# pylint: disable=unused-import,line-too-long,wildcard-import
from tensorflow.python.feature_column.feature_column import *
-
-from tensorflow.python.util.all_util import remove_undocumented
# pylint: enable=unused-import,line-too-long
-
-_allowed_symbols = [
- 'input_layer',
- 'linear_model',
- 'make_parse_example_spec',
- 'embedding_column',
- 'shared_embedding_columns',
- 'crossed_column',
- 'numeric_column',
- 'bucketized_column',
- 'categorical_column_with_hash_bucket',
- 'categorical_column_with_vocabulary_file',
- 'categorical_column_with_vocabulary_list',
- 'categorical_column_with_identity',
- 'weighted_categorical_column',
- 'indicator_column',
-]
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/framework/errors.py b/tensorflow/python/framework/errors.py
index c8cf9ae39b..be0187c2ef 100644
--- a/tensorflow/python/framework/errors.py
+++ b/tensorflow/python/framework/errors.py
@@ -25,50 +25,4 @@ from tensorflow.python.framework import errors_impl as _impl
# pylint: disable=wildcard-import
from tensorflow.python.framework.errors_impl import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-# These are referenced in client/client_lib.py.
-# Unfortunately, we can't import client_lib to examine
-# the references, since it would create a dependency cycle.
-_allowed_symbols = [
- "AbortedError",
- "AlreadyExistsError",
- "CancelledError",
- "DataLossError",
- "DeadlineExceededError",
- "FailedPreconditionError",
- "InternalError",
- "InvalidArgumentError",
- "NotFoundError",
- "OpError",
- "OutOfRangeError",
- "PermissionDeniedError",
- "ResourceExhaustedError",
- "UnauthenticatedError",
- "UnavailableError",
- "UnimplementedError",
- "UnknownError",
- "error_code_from_exception_type",
- "exception_type_from_error_code",
- "raise_exception_on_not_ok_status",
- # Scalars that have no docstrings:
- "OK",
- "CANCELLED",
- "UNKNOWN",
- "INVALID_ARGUMENT",
- "DEADLINE_EXCEEDED",
- "NOT_FOUND",
- "ALREADY_EXISTS",
- "PERMISSION_DENIED",
- "UNAUTHENTICATED",
- "RESOURCE_EXHAUSTED",
- "FAILED_PRECONDITION",
- "ABORTED",
- "OUT_OF_RANGE",
- "UNIMPLEMENTED",
- "INTERNAL",
- "UNAVAILABLE",
- "DATA_LOSS",
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 9570f009a5..2432ab378c 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -703,11 +703,23 @@ class _FuncGraph(ops.Graph):
with ops.control_dependencies(None):
ph = array_ops.placeholder(tensor.dtype, shape=tensor.get_shape())
# pylint: disable=protected-access
- ph._handle_data = tensor._handle_data
+ if ops._USE_C_SHAPES:
+ handle_data = c_api.GetResourceHandleShapeAndType(tensor.graph._c_graph,
+ tensor._as_tf_output())
+ if handle_data:
+ c_api.SetResourceHandleShapeAndType(ph.graph._c_graph,
+ ph._as_tf_output(),
+ compat.as_bytes(handle_data))
+ else:
+ ph._handle_data = tensor._handle_data
# pylint: enable=protected-access
self._captured[tensor] = ph
self.extra_args.append(ph)
- return ph
+ if _is_guaranteed_const(tensor):
+ with ops.control_dependencies(None):
+ return array_ops.guarantee_const(ph)
+ else:
+ return ph
def _add_tensor_and_parents(self, tensor):
op = self._add_op_and_parents(tensor.op)
@@ -739,6 +751,57 @@ class _FuncGraph(ops.Graph):
return captured_op
+def _is_guaranteed_const(tensor):
+ """Determines whether `tensor` is guaranteed to be a constant.
+
+ A tensor is guaranteed to be a constant if either it was produced by
+ a `GuaranteeConst` op or if all of its children are guaranteed to be
+ constants.
+
+ Args:
+ tensor: The tensor for which to determine const-ness.
+
+ Returns:
+ True if `tensor` is guaranteed to be a constant, False otherwise.
+ """
+
+ if isinstance(tensor, ops.EagerTensor):
+ return False
+
+ class Work(object):
+
+ def __init__(self, op, leaving):
+ self.op = op
+ self.leaving = leaving
+
+ is_guaranteed_const = lambda op: op.node_def.op == "GuaranteeConst"
+ constants = set([])
+ def all_inputs_const(op):
+ # If all inputs of an op are guaranteed constants, then we can infer that
+ # the op produces a constant as well.
+ return op.inputs and all(inp.op in constants for inp in op.inputs)
+
+ visited = set([])
+ stack = [Work(tensor.op, leaving=False)]
+ while stack:
+ work = stack.pop()
+ if work.leaving:
+ if all_inputs_const(work.op):
+ constants.add(work.op)
+ continue
+ visited.add(work.op)
+ if is_guaranteed_const(work.op):
+ constants.add(work.op)
+ continue
+
+ # This op will be revisited after all its inputs are checked for const-ness.
+ stack.append(Work(work.op, leaving=True))
+ for inp in work.op.inputs:
+ if inp.op not in visited:
+ stack.append(Work(inp.op, leaving=False))
+ return tensor.op in constants
+
+
def _call(sig, *inputs, **kwargs):
"""Adds a node calling a function.
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index d6bc14fbc7..594596ec1e 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -85,7 +85,7 @@ def _OptimizerOptions():
yield cfg
-@test_util.with_c_api
+@test_util.with_c_shapes
class FunctionTest(test.TestCase):
"""Test methods for verifying Function support.
@@ -431,7 +431,6 @@ class FunctionTest(test.TestCase):
"assertion failed.*-3"):
self.assertAllEqual(Foo(constant_op.constant(-3.0)).eval(), 6.0)
- @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testAssertWrapper(self):
@function.Defun(dtypes.float32)
@@ -446,7 +445,6 @@ class FunctionTest(test.TestCase):
"assertion"):
_ = MyFn(100.0).eval()
- @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testWhileLoopCallsFunc(self):
with self.test_session(use_gpu=True) as sess:
@@ -466,7 +464,6 @@ class FunctionTest(test.TestCase):
ans = sess.run(loop)
self.assertAllClose(ans, 131072.)
- @test_util.disable_c_api # Op._add_control_inputs doesn't work with C API
def testControlFlowStrictness(self):
"""Inlined functions must not execute in a untaken control flow branch."""
@@ -1053,8 +1050,31 @@ class FunctionTest(test.TestCase):
self.assertEqual(44.0, sess.run(f_1))
self.assertEqual((42.0, 44.0), sess.run((f_0, f_1)))
+ def testGuaranteedConstsAreCaptured(self):
+ var = variables.Variable(1.0)
+ const = array_ops.guarantee_const(var)
+ also_const = array_ops.identity(const)
+ still_const = array_ops.identity(also_const)
+ not_const = still_const + var
+ also_not_const = array_ops.placeholder(dtypes.float32)
-@test_util.with_c_api
+ @function.Defun()
+ def CapturesGuaranteedConst():
+ output = const + also_const + still_const + not_const + also_not_const
+ first, second, third, fourth, fifth = function.get_extra_args()
+ self.assertEqual("GuaranteeConst", first.consumers()[0].node_def.op)
+ self.assertEqual("GuaranteeConst", second.consumers()[0].node_def.op)
+ self.assertEqual("GuaranteeConst", third.consumers()[0].node_def.op)
+ self.assertNotEqual("GuaranteeConst", fourth.consumers()[0].node_def.op)
+ self.assertNotEqual("GuaranteeConst", fifth.consumers()[0].node_def.op)
+ return output
+
+ with self.test_session(use_gpu=False) as sess:
+ sess.run(var.initializer)
+ _ = sess.run(CapturesGuaranteedConst(), {also_not_const: 1.0})
+
+
+@test_util.with_c_shapes
class FunctionsFromProtos(test.TestCase):
def expectFunctionsEqual(self, func, grad_func=None, new_func=None):
@@ -1256,7 +1276,7 @@ class FunctionsFromProtos(test.TestCase):
FunctionWithAttr.definition.attr["experimental_tag"].s, b"tag_value")
-@test_util.with_c_api
+@test_util.with_c_shapes
class FunctionOverloadTest(test.TestCase):
def testBasic(self):
@@ -1309,7 +1329,7 @@ class FunctionOverloadTest(test.TestCase):
"Successor of x.")
-@test_util.with_c_api
+@test_util.with_c_shapes
class FunctionCaptureByValueTest(test.TestCase):
def testCaptureByValue(self):
@@ -1339,7 +1359,7 @@ class FunctionCaptureByValueTest(test.TestCase):
self.assertAllEqual(y.eval(), [[12.0]])
-@test_util.with_c_api
+@test_util.with_c_shapes
class UnrollLSTMTest(test.TestCase):
BATCH_SIZE = 16
LSTM_DIMS = 32
@@ -1475,7 +1495,7 @@ class UnrollLSTMTest(test.TestCase):
self.assertAllClose(d0, d3, rtol=1e-4, atol=1e-4)
-@test_util.with_c_api
+@test_util.with_c_shapes
class FunctionInlineControlTest(test.TestCase):
def testFoo(self):
@@ -1543,10 +1563,6 @@ def Linear2(w1, b1, w2, b2, x):
return Linear(w2, b2, Linear(w1, b1, x))
-# Set C API before defining module level functions
-ops._USE_C_API = True
-
-
@function.Defun(*[dtypes.float32] * 3)
def LinearWithCApi(w, b, x):
return nn_ops.relu(math_ops.matmul(x, w) + b)
@@ -1557,10 +1573,6 @@ def Linear2WithCApi(w1, b1, w2, b2, x):
return LinearWithCApi(w2, b2, LinearWithCApi(w1, b1, x))
-# Unset C API after defining module level functions
-ops._USE_C_API = False
-
-
class ModuleFunctionTest(test.TestCase):
def testBasic(self):
@@ -1568,18 +1580,6 @@ class ModuleFunctionTest(test.TestCase):
a, b, c, d, e = [
constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
]
- y = Linear(a, b, c)
- z = Linear2(a, b, c, d, e)
- with session.Session() as sess:
- self.assertAllEqual([[1]], sess.run(y))
- self.assertAllEqual([[5]], sess.run(z))
-
- @test_util.enable_c_api
- def testBasicWithCApi(self):
- with ops.Graph().as_default():
- a, b, c, d, e = [
- constant_op.constant([[_]], dtype=dtypes.float32) for _ in range(5)
- ]
y = LinearWithCApi(a, b, c)
z = Linear2WithCApi(a, b, c, d, e)
with session.Session() as sess:
@@ -1587,7 +1587,7 @@ class ModuleFunctionTest(test.TestCase):
self.assertAllEqual([[5]], sess.run(z))
-@test_util.with_c_api
+@test_util.with_c_shapes
class VariableHoistingTest(test.TestCase):
def _testSimpleModel(self, use_forward_func, use_resource=False):
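
For reference, the new testGuaranteedConstsAreCaptured added above can be condensed into a small stand-alone sketch (graph mode assumed; the names and the printed value are illustrative and not part of this change):

    from tensorflow.python.client import session
    from tensorflow.python.framework import function, ops
    from tensorflow.python.ops import array_ops, variables

    with ops.Graph().as_default():
      var = variables.Variable(1.0)
      const = array_ops.guarantee_const(var)  # wraps the value in a GuaranteeConst op

      @function.Defun()
      def CapturesConst():
        # Per the test above, `const` is captured as an extra argument whose
        # consumer inside the function body is a GuaranteeConst op.
        return const + 1.0

      out = CapturesConst()
      with session.Session() as sess:
        sess.run(var.initializer)
        print(sess.run(out))  # expected: 2.0
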
diff --git a/tensorflow/python/framework/graph_util.py b/tensorflow/python/framework/graph_util.py
index a666630e44..c5cc110734 100644
--- a/tensorflow/python/framework/graph_util.py
+++ b/tensorflow/python/framework/graph_util.py
@@ -28,14 +28,3 @@ from tensorflow.python.framework.graph_util_impl import must_run_on_cpu
from tensorflow.python.framework.graph_util_impl import remove_training_nodes
from tensorflow.python.framework.graph_util_impl import tensor_shape_from_node_def_name
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- # TODO(drpng): find a good place to reference this.
- "convert_variables_to_constants",
- "extract_sub_graph",
- "must_run_on_cpu",
- "tensor_shape_from_node_def_name",
- "remove_training_nodes",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 3f8a8c4bef..5112bea48b 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -572,7 +572,14 @@ def import_graph_def(graph_def,
if node.name in name_to_op:
raise ValueError('Duplicate name \'%s\' in GraphDef.' % node.name)
if node.op not in op_dict:
- raise ValueError('No op named %s in defined operations.' % node.op)
+ raise ValueError(
+ 'No op named %s in defined operations. If the Graph you are '
+ 'importing uses custom ops or any parts of tf.contrib, you '
+ 'should explicitly import the libraries defining those ops '
+ 'before loading the Graph. Note that tf.contrib is lazily loaded '
+ 'when accessed, so simply referencing (e.g.) '
+ '`tf.contrib.resampler` will cause those ops to be made '
+ 'available.' % node.op)
op_def = op_dict[node.op]
output_types = _OutputTypes(node, op_dict)
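
The expanded error message above points at a concrete workaround; a minimal sketch of it follows (the model path is hypothetical and tf.contrib.resampler is only one example of a lazily loaded contrib module):

    import tensorflow as tf

    # Referencing the contrib module lazily loads it, which registers its ops
    # before the GraphDef that uses them is imported.
    _ = tf.contrib.resampler

    graph_def = tf.GraphDef()
    with tf.gfile.GFile('/path/to/frozen_graph.pb', 'rb') as f:  # hypothetical path
      graph_def.ParseFromString(f.read())
    tf.import_graph_def(graph_def, name='')
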
diff --git a/tensorflow/python/framework/meta_graph.py b/tensorflow/python/framework/meta_graph.py
index 391b17720c..923e76fc9c 100644
--- a/tensorflow/python/framework/meta_graph.py
+++ b/tensorflow/python/framework/meta_graph.py
@@ -439,9 +439,10 @@ def add_collection_def(meta_graph_def, key, graph=None,
else:
getattr(col_def, kind).value.extend([x for x in collection_list])
except Exception as e: # pylint: disable=broad-except
- logging.warning("Error encountered when serializing %s.\n"
+ logging.warning("Issue encountered when serializing %s.\n"
"Type is unsupported, or the types of the items don't "
- "match field type in CollectionDef.\n%s", key, str(e))
+ "match field type in CollectionDef. Note this is a warning "
+ "and probably safe to ignore.\n%s", key, str(e))
if key in meta_graph_def.collection_def:
del meta_graph_def.collection_def[key]
return
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 8cd6820f6a..dd9acdd9eb 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -2557,8 +2557,8 @@ def _set_shape_and_handle_data_for_outputs_c_api(op):
output._shape_val = output._c_api_shape()
# Set the resource handle data for compatibility with the Python shape
# inference code.
- serialized = c_api.ResourceHandleShapeAndType(
- op._graph._c_graph, output._as_tf_output())
+ serialized = c_api.GetResourceHandleShapeAndType(op._graph._c_graph,
+ output._as_tf_output())
if serialized:
output._handle_data = (
cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData
@@ -4998,7 +4998,7 @@ def _colocate_with_for_gradient(op, gradient_uid, ignore_existing=False):
default_graph = get_default_graph()
if isinstance(op, EagerTensor):
if default_graph.building_function:
- op = internal_convert_to_tensor(op)
+ return default_graph.device(op.device)
else:
raise ValueError("Encountered an Eager-defined Tensor during graph "
"construction, but a function was not being built.")
diff --git a/tensorflow/python/framework/test_util.py b/tensorflow/python/framework/test_util.py
index 5a8bc43727..dc56d88066 100644
--- a/tensorflow/python/framework/test_util.py
+++ b/tensorflow/python/framework/test_util.py
@@ -21,6 +21,7 @@ from __future__ import print_function
import contextlib
import gc
+import itertools
import math
import random
import re
@@ -1212,8 +1213,14 @@ class TensorFlowTestCase(googletest.TestCase):
self.assertTrue(self._NDArrayNear(ndarray1, ndarray2, err), msg=msg)
def _GetNdArray(self, a):
+    # If a is a Tensor, convert it to an ndarray first.
+ if isinstance(a, ops.Tensor):
+ if isinstance(a, ops._EagerTensorBase):
+ return a.numpy()
+ else:
+ a = self.evaluate(a)
if not isinstance(a, np.ndarray):
- a = np.array(a)
+ return np.array(a)
return a
def _assertArrayLikeAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
@@ -1286,8 +1293,8 @@ class TensorFlowTestCase(googletest.TestCase):
# Try to directly compare a, b as ndarrays; if not work, then traverse
# through the sequence, which is more expensive.
try:
- a_as_ndarray = np.array(a)
- b_as_ndarray = np.array(b)
+ a_as_ndarray = self._GetNdArray(a)
+ b_as_ndarray = self._GetNdArray(b)
self._assertArrayLikeAllClose(
a_as_ndarray,
b_as_ndarray,
@@ -1322,16 +1329,18 @@ class TensorFlowTestCase(googletest.TestCase):
raise
def assertAllClose(self, a, b, rtol=1e-6, atol=1e-6, msg=None):
- """Asserts that two structures of numpy arrays, have near values.
+ """Asserts that two structures of numpy arrays or Tensors, have near values.
`a` and `b` can be arbitrarily nested structures. A layer of a nested
structure can be a `dict`, `namedtuple`, `tuple` or `list`.
Args:
a: The expected numpy `ndarray`, or anything that can be converted into a
- numpy `ndarray`, or any arbitrarily nested of structure of these.
+        numpy `ndarray` (including Tensor), or any arbitrarily nested
+        structure of these.
b: The actual numpy `ndarray`, or anything that can be converted into a
- numpy `ndarray`, or any arbitrarily nested of structure of these.
+        numpy `ndarray` (including Tensor), or any arbitrarily nested
+        structure of these.
rtol: relative tolerance.
atol: absolute tolerance.
msg: Optional message to report on failure.
@@ -1391,8 +1400,26 @@ class TensorFlowTestCase(googletest.TestCase):
self.assertAllClose(a, b, rtol=rtol, atol=atol, msg=msg)
+ def assertNotAllClose(self, a, b, **kwargs):
+ """Assert that two numpy arrays, or or Tensors, do not have near values.
+
+ Args:
+ a: the first value to compare.
+ b: the second value to compare.
+ **kwargs: additional keyword arguments to be passed to the underlying
+ `assertAllClose` call.
+
+ Raises:
+ AssertionError: If `a` and `b` are unexpectedly close at all elements.
+ """
+ try:
+ self.assertAllClose(a, b, **kwargs)
+ except AssertionError:
+ return
+ raise AssertionError("The two values are close at all elements")
+
def assertAllEqual(self, a, b, msg=None):
- """Asserts that two numpy arrays have the same values.
+ """Asserts that two numpy arrays or Tensors have the same values.
Args:
a: the expected numpy ndarray or anything can be converted to one.
@@ -1424,6 +1451,174 @@ class TensorFlowTestCase(googletest.TestCase):
print("not equal rhs = ", y)
np.testing.assert_array_equal(a, b, err_msg=msg)
+ def assertAllGreater(self, a, comparison_target):
+ """Assert element values are all greater than a target value.
+
+ Args:
+ a: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ comparison_target: The target value of comparison.
+ """
+ a = self._GetNdArray(a)
+ self.assertGreater(np.min(a), comparison_target)
+
+ def assertAllLess(self, a, comparison_target):
+ """Assert element values are all greater than a target value.
+
+ Args:
+ a: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ comparison_target: The target value of comparison.
+ """
+ a = self._GetNdArray(a)
+ self.assertLess(np.max(a), comparison_target)
+
+ def assertAllGreaterEqual(self, a, comparison_target):
+ """Assert element values are all greater than a target value.
+
+ Args:
+ a: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ comparison_target: The target value of comparison.
+ """
+ a = self._GetNdArray(a)
+ self.assertGreaterEqual(np.min(a), comparison_target)
+
+ def assertAllLessEqual(self, a, comparison_target):
+ """Assert element values are all greater than a target value.
+
+ Args:
+ a: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ comparison_target: The target value of comparison.
+ """
+ a = self._GetNdArray(a)
+ self.assertLessEqual(np.max(a), comparison_target)
+
+ def _format_subscripts(self, subscripts, value, limit=10, indent=2):
+ """Generate a summary of ndarray subscripts as a list of str.
+
+    If limit == N, this method will print up to the first N subscripts on
+    separate lines. A line of ellipses (...) will be appended at the end if
+    the number of subscripts exceeds N.
+
+ Args:
+ subscripts: The tensor (np.ndarray) subscripts, of the same format as
+ np.where()'s return value, i.e., a tuple of arrays with each array
+ corresponding to a dimension. E.g., (array([1, 1]), array([0, 1])).
+ value: (np.ndarray) value of the tensor.
+ limit: (int) The maximum number of indices to print.
+ indent: (int) Number of characters to indent at the beginning of each
+ line.
+
+ Returns:
+ (list of str) the multi-line representation of the subscripts and values,
+ potentially with omission at the end.
+ """
+ lines = []
+ subscripts = np.transpose(subscripts)
+ prefix = " " * indent
+ for subscript in itertools.islice(subscripts, limit):
+ lines.append(prefix + str(subscript) + " : " +
+ str(value[tuple(subscript)]))
+ if len(subscripts) > limit:
+ lines.append(prefix + "...")
+ return lines
+
+ def assertAllInRange(self,
+ target,
+ lower_bound,
+ upper_bound,
+ open_lower_bound=False,
+ open_upper_bound=False):
+ """Assert that elements in a Tensor are all in a given range.
+
+ Args:
+ target: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ lower_bound: lower bound of the range
+ upper_bound: upper bound of the range
+ open_lower_bound: (`bool`) whether the lower bound is open (i.e., > rather
+ than the default >=)
+ open_upper_bound: (`bool`) whether the upper bound is open (i.e., < rather
+ than the default <=)
+
+ Raises:
+ AssertionError:
+ if the value tensor does not have an ordered numeric type (float* or
+ int*), or
+ if there are nan values, or
+ if any of the elements do not fall in the specified range.
+ """
+ target = self._GetNdArray(target)
+ if not (np.issubdtype(target.dtype, np.float) or
+ np.issubdtype(target.dtype, np.integer)):
+ raise AssertionError(
+ "The value of %s does not have an ordered numeric type, instead it "
+ "has type: %s" % (target, target.dtype))
+
+ nan_subscripts = np.where(np.isnan(target))
+ if np.size(nan_subscripts):
+ raise AssertionError(
+ "%d of the %d element(s) are NaN. "
+ "Subscripts(s) and value(s) of the NaN element(s):\n" %
+ (len(nan_subscripts[0]), np.size(target)) +
+ "\n".join(self._format_subscripts(nan_subscripts, target)))
+
+ range_str = (("(" if open_lower_bound else "[") + str(lower_bound) + ", " +
+ str(upper_bound) + (")" if open_upper_bound else "]"))
+
+ violations = (
+ np.less_equal(target, lower_bound)
+ if open_lower_bound else np.less(target, lower_bound))
+ violations = np.logical_or(
+ violations,
+ np.greater_equal(target, upper_bound)
+ if open_upper_bound else np.greater(target, upper_bound))
+ violation_subscripts = np.where(violations)
+ if np.size(violation_subscripts):
+ raise AssertionError(
+ "%d of the %d element(s) are outside the range %s. " %
+ (len(violation_subscripts[0]), np.size(target), range_str) +
+ "Subscript(s) and value(s) of the offending elements:\n" +
+ "\n".join(self._format_subscripts(violation_subscripts, target)))
+
+ def assertAllInSet(self, target, expected_set):
+ """Assert that elements of a Tensor are all in a given closed set.
+
+ Args:
+ target: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ expected_set: (`list`, `tuple` or `set`) The closed set that the elements
+ of the value of `target` are expected to fall into.
+
+ Raises:
+ AssertionError:
+ if any of the elements do not fall into `expected_set`.
+ """
+ target = self._GetNdArray(target)
+
+ # Elements in target that are not in expected_set.
+ diff = np.setdiff1d(target.flatten(), list(expected_set))
+ if np.size(diff):
+ raise AssertionError("%d unique element(s) are not in the set %s: %s" %
+ (np.size(diff), expected_set, diff))
+
+ def assertDTypeEqual(self, target, expected_dtype):
+ """Assert ndarray data type is equal to expected.
+
+ Args:
+ target: The numpy `ndarray`, or anything that can be converted into a
+ numpy `ndarray` (including Tensor).
+ expected_dtype: Expected data type.
+ """
+ target = self._GetNdArray(target)
+ if not isinstance(target, list):
+ arrays = [target]
+ for arr in arrays:
+ self.assertEqual(arr.dtype, expected_dtype)
+
# pylint: disable=g-doc-return-or-yield
@contextlib.contextmanager
def assertRaisesWithPredicateMatch(self, exception_type,
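
A quick usage sketch for the assertions added above (an illustrative test case, not part of the change; it relies on the Tensor-aware _GetNdArray, so the values may be Tensors or ndarrays):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import test_util
    from tensorflow.python.platform import googletest


    class NewAssertionsDemo(test_util.TensorFlowTestCase):

      def testNewHelpers(self):
        x = constant_op.constant([1.0, 2.0, 3.0])
        self.assertAllGreater(x, 0.0)               # min element > 0
        self.assertAllLess(x, 4.0)                  # max element < 4
        self.assertAllInRange(x, 1.0, 3.0)          # closed bounds by default
        self.assertAllInSet(x, {1.0, 2.0, 3.0})     # every element in the set
        self.assertNotAllClose(x, [1.5, 2.5, 3.5])  # at least one element differs


    if __name__ == '__main__':
      googletest.main()
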
diff --git a/tensorflow/python/framework/test_util_test.py b/tensorflow/python/framework/test_util_test.py
index 02ffa93bae..8d492256aa 100644
--- a/tensorflow/python/framework/test_util_test.py
+++ b/tensorflow/python/framework/test_util_test.py
@@ -31,13 +31,16 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_ops # pylint: disable=unused-import
from tensorflow.python.framework import test_util
from tensorflow.python.ops import control_flow_ops
+from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
@@ -209,6 +212,21 @@ class TestUtilTest(test_util.TensorFlowTestCase):
self._WeMustGoDeeper("name")
self._WeMustGoDeeper("orig")
+ def testAllCloseTensors(self):
+ a_raw_data = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
+ a = constant_op.constant(a_raw_data)
+ b = math_ops.add(1, constant_op.constant([[0, 1, 2], [3, 4, 5], [6, 7, 8]]))
+ self.assertAllClose(a, b)
+ self.assertAllClose(a, a_raw_data)
+
+ a_dict = {"key": a}
+ b_dict = {"key": b}
+ self.assertAllClose(a_dict, b_dict)
+
+ x_list = [a, b]
+ y_list = [a_raw_data, b]
+ self.assertAllClose(x_list, y_list)
+
def testAllCloseScalars(self):
self.assertAllClose(7, 7 + 1e-8)
with self.assertRaisesRegexp(AssertionError, r"Not equal to tolerance"):
@@ -317,6 +335,12 @@ class TestUtilTest(test_util.TensorFlowTestCase):
rtol=1e-8, atol=1e-8
)
+ self.assertAllCloseAccordingToType(
+ constant_op.constant([1e-8], dtype=dtypes.float64),
+ constant_op.constant([2e-8], dtype=dtypes.float64),
+ rtol=1e-8,
+ atol=1e-8)
+
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
np.asarray([1e-7], dtype=np.float64),
@@ -332,6 +356,14 @@ class TestUtilTest(test_util.TensorFlowTestCase):
float_rtol=1e-7, float_atol=1e-7
)
+ self.assertAllCloseAccordingToType(
+ constant_op.constant([1e-7], dtype=dtypes.float32),
+ constant_op.constant([2e-7], dtype=dtypes.float32),
+ rtol=1e-8,
+ atol=1e-8,
+ float_rtol=1e-7,
+ float_atol=1e-7)
+
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
np.asarray([1e-6], dtype=np.float32),
@@ -349,6 +381,16 @@ class TestUtilTest(test_util.TensorFlowTestCase):
half_rtol=1e-4, half_atol=1e-4
)
+ self.assertAllCloseAccordingToType(
+ constant_op.constant([1e-4], dtype=dtypes.float16),
+ constant_op.constant([2e-4], dtype=dtypes.float16),
+ rtol=1e-8,
+ atol=1e-8,
+ float_rtol=1e-7,
+ float_atol=1e-7,
+ half_rtol=1e-4,
+ half_atol=1e-4)
+
with (self.assertRaises(AssertionError)):
self.assertAllCloseAccordingToType(
np.asarray([1e-3], dtype=np.float16),
@@ -358,6 +400,157 @@ class TestUtilTest(test_util.TensorFlowTestCase):
half_rtol=1e-4, half_atol=1e-4
)
+ def testAssertAllEqual(self):
+ i = variables.Variable([100] * 3, dtype=dtypes.int32, name="i")
+ j = constant_op.constant([20] * 3, dtype=dtypes.int32, name="j")
+ k = math_ops.add(i, j, name="k")
+
+ self.evaluate(variables.global_variables_initializer())
+ self.assertAllEqual([120] * 3, k)
+ self.assertAllEqual([20] * 3, j)
+
+ def testAssertNotAllClose(self):
+ # Test with arrays
+ self.assertNotAllClose([0.1], [0.2])
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([-1.0, 2.0], [-1.0, 2.0])
+
+ # Test with tensors
+ x = constant_op.constant([1.0, 1.0], name="x")
+ y = math_ops.add(x, x)
+
+ self.assertAllClose([2.0, 2.0], y)
+ self.assertNotAllClose([0.9, 1.0], x)
+
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([1.0, 1.0], x)
+
+ def testAssertNotAllCloseRTol(self):
+ # Test with arrays
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], rtol=0.2)
+
+ # Test with tensors
+ x = constant_op.constant([1.0, 1.0], name="x")
+ y = math_ops.add(x, x)
+
+ self.assertAllClose([2.0, 2.0], y)
+
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([0.9, 1.0], x, rtol=0.2)
+
+ def testAssertNotAllCloseATol(self):
+ # Test with arrays
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([1.1, 2.1], [1.0, 2.0], atol=0.2)
+
+ # Test with tensors
+ x = constant_op.constant([1.0, 1.0], name="x")
+ y = math_ops.add(x, x)
+
+ self.assertAllClose([2.0, 2.0], y)
+
+ with self.assertRaises(AssertionError):
+ self.assertNotAllClose([0.9, 1.0], x, atol=0.2)
+
+ def testAssertAllGreaterLess(self):
+ x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
+ y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
+ z = math_ops.add(x, y)
+
+ self.assertAllClose([110.0, 120.0, 130.0], z)
+
+ self.assertAllGreater(x, 95.0)
+ self.assertAllLess(x, 125.0)
+
+ with self.assertRaises(AssertionError):
+ self.assertAllGreater(x, 105.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllGreater(x, 125.0)
+
+ with self.assertRaises(AssertionError):
+ self.assertAllLess(x, 115.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllLess(x, 95.0)
+
+ def testAssertAllGreaterLessEqual(self):
+ x = constant_op.constant([100.0, 110.0, 120.0], dtype=dtypes.float32)
+ y = constant_op.constant([10.0] * 3, dtype=dtypes.float32)
+ z = math_ops.add(x, y)
+
+ self.assertAllEqual([110.0, 120.0, 130.0], z)
+
+ self.assertAllGreaterEqual(x, 95.0)
+ self.assertAllLessEqual(x, 125.0)
+
+ with self.assertRaises(AssertionError):
+ self.assertAllGreaterEqual(x, 105.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllGreaterEqual(x, 125.0)
+
+ with self.assertRaises(AssertionError):
+ self.assertAllLessEqual(x, 115.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllLessEqual(x, 95.0)
+
+ def testAssertAllInRangeWithNonNumericValuesFails(self):
+ s1 = constant_op.constant("Hello, ", name="s1")
+ c = constant_op.constant([1 + 2j, -3 + 5j], name="c")
+ b = constant_op.constant([False, True], name="b")
+
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(s1, 0.0, 1.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(c, 0.0, 1.0)
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(b, 0, 1)
+
+ def testAssertAllInRange(self):
+ x = constant_op.constant([10.0, 15.0], name="x")
+ self.assertAllInRange(x, 10, 15)
+
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(x, 10, 15, open_lower_bound=True)
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(x, 10, 15, open_upper_bound=True)
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(
+ x, 10, 15, open_lower_bound=True, open_upper_bound=True)
+
+ def testAssertAllInRangeErrorMessageEllipses(self):
+ x_init = np.array([[10.0, 15.0]] * 12)
+ x = constant_op.constant(x_init, name="x")
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(x, 5, 10)
+
+ def testAssertAllInRangeDetectsNaNs(self):
+ x = constant_op.constant(
+ [[np.nan, 0.0], [np.nan, np.inf], [np.inf, np.nan]], name="x")
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(x, 0.0, 2.0)
+
+ def testAssertAllInRangeWithInfinities(self):
+ x = constant_op.constant([10.0, np.inf], name="x")
+ self.assertAllInRange(x, 10, np.inf)
+ with self.assertRaises(AssertionError):
+ self.assertAllInRange(x, 10, np.inf, open_upper_bound=True)
+
+ def testAssertAllInSet(self):
+ b = constant_op.constant([True, False], name="b")
+ x = constant_op.constant([13, 37], name="x")
+
+ self.assertAllInSet(b, [False, True])
+ self.assertAllInSet(b, (False, True))
+ self.assertAllInSet(b, {False, True})
+ self.assertAllInSet(x, [0, 13, 37, 42])
+ self.assertAllInSet(x, (0, 13, 37, 42))
+ self.assertAllInSet(x, {0, 13, 37, 42})
+
+ with self.assertRaises(AssertionError):
+ self.assertAllInSet(b, [False])
+ with self.assertRaises(AssertionError):
+ self.assertAllInSet(x, (42,))
+
def testRandomSeed(self):
    # Call setUp again for WithCApi case (since it makes a new default graph
# after setup).
diff --git a/tensorflow/python/keras/BUILD b/tensorflow/python/keras/BUILD
index 1c58553156..a14a121b6e 100755
--- a/tensorflow/python/keras/BUILD
+++ b/tensorflow/python/keras/BUILD
@@ -382,6 +382,7 @@ py_test(
size = "large",
srcs = ["_impl/keras/applications/nasnet_test.py"],
srcs_version = "PY2AND3",
+ tags = ["nomsan"], # times out, http://b/78573625
deps = [
":keras",
"//tensorflow/python:client_testlib",
diff --git a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
index 12775fccec..7b7288793d 100644
--- a/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
+++ b/tensorflow/python/keras/_impl/keras/applications/mobilenet.py
@@ -79,7 +79,6 @@ from tensorflow.python.keras._impl.keras.applications import imagenet_utils
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import _obtain_input_shape
from tensorflow.python.keras._impl.keras.applications.imagenet_utils import decode_predictions
from tensorflow.python.keras._impl.keras.engine import InputSpec
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.engine.network import get_source_inputs
from tensorflow.python.keras._impl.keras.layers import Activation
from tensorflow.python.keras._impl.keras.layers import BatchNormalization
diff --git a/tensorflow/python/keras/_impl/keras/engine/base_layer.py b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
index 6c68d25127..a3e78c95dc 100644
--- a/tensorflow/python/keras/_impl/keras/engine/base_layer.py
+++ b/tensorflow/python/keras/_impl/keras/engine/base_layer.py
@@ -20,7 +20,6 @@ from __future__ import print_function
import collections
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
-import re
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
@@ -35,6 +34,10 @@ from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
+# Modules that only depend on `keras.layers` import these from here.
+from tensorflow.python.keras._impl.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
+from tensorflow.python.keras._impl.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
@@ -177,7 +180,8 @@ class Layer(checkpointable.CheckpointableBase):
def _init_set_name(self, name, zero_based=True):
if not name:
self._name = unique_layer_name(
- to_snake_case(self.__class__.__name__), zero_based=zero_based)
+ generic_utils.to_snake_case(self.__class__.__name__),
+ zero_based=zero_based)
else:
self._name = name
@@ -318,7 +322,7 @@ class Layer(checkpointable.CheckpointableBase):
# Requesting input-conditional updates.
inputs = nest.flatten(inputs)
- reachable = get_reachable_from_inputs(inputs, self.updates)
+ reachable = tf_utils.get_reachable_from_inputs(inputs, self.updates)
updates = []
for update in self.updates:
if update in reachable:
@@ -419,7 +423,7 @@ class Layer(checkpointable.CheckpointableBase):
# The losses we want to return will be part of this set.
# To avoid unnecessary work, we stop the search in case all of
# `self.losses` have been retrieved.
- reachable = get_reachable_from_inputs(inputs, self.losses)
+ reachable = tf_utils.get_reachable_from_inputs(inputs, self.losses)
losses = []
for loss in self.losses:
if loss in reachable:
@@ -639,7 +643,7 @@ class Layer(checkpointable.CheckpointableBase):
if not hasattr(self, '_call_fn_args'):
self._call_fn_args = estimator_util.fn_args(self.call)
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
- not is_all_none(previous_mask)):
+ not generic_utils.is_all_none(previous_mask)):
      # The previous layer generated a mask, and mask was not explicitly passed
# to __call__, hence we set previous_mask as the default value.
kwargs['mask'] = previous_mask
@@ -726,8 +730,17 @@ class Layer(checkpointable.CheckpointableBase):
if hasattr(self, '_initial_weights') and self._initial_weights is not None:
self.set_weights(self._initial_weights)
del self._initial_weights
+ self._post_build_cleanup()
return outputs
+ def _post_build_cleanup(self):
+ """Hooks to run after all sub-Layers are built."""
+ # Note that in addition to Layer.__call__, this method is called by Model
+ # after building a graph network (which skips __call__). It should be called
+ # when possible if self.built may have switched from False to True, and is
+ # idempotent.
+ pass # No-op for Layers which don't override this method.
+
def apply(self, inputs, *args, **kwargs):
"""Apply the layer on a input.
@@ -1606,9 +1619,9 @@ class Node(object):
# Following 2 properties: input and output shapes.
# List of shape tuples, shapes of input_tensors.
- self.input_shapes = [static_shape(x) for x in input_tensors]
+ self.input_shapes = [backend.int_shape(x) for x in input_tensors]
# List of shape tuples, shapes of output_tensors.
- self.output_shapes = [static_shape(x) for x in output_tensors]
+ self.output_shapes = [backend.int_shape(x) for x in output_tensors]
# Optional keyword arguments to layer's `call`.
self.arguments = arguments
@@ -1669,91 +1682,6 @@ class DeferredTensor(object):
self.dtype.name)
-def shape_type_conversion(fn):
- """Decorator that handles tuple/TensorShape conversion.
-
- Used in `compute_output_shape` and `build`.
-
- Arguments:
- fn: function to wrap.
-
- Returns:
- Wrapped function.
- """
-
- def wrapper(instance, input_shape):
- if input_shape is not None:
- if isinstance(input_shape, list):
- input_shape = [
- tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
- else:
- input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
- output_shape = fn(instance, input_shape)
- if output_shape is not None:
- if isinstance(output_shape, list):
- return [tensor_shape.TensorShape(x) for x in output_shape]
- return tensor_shape.TensorShape(output_shape)
-
- return wrapper
-
-
-def object_list_uid(object_list):
- """Creates a single string from object ids."""
- object_list = nest.flatten(object_list)
- return ', '.join([str(abs(id(x))) for x in object_list])
-
-
-def static_shape(x):
- """Get the static shape of a Tensor, or None if it is unavailable."""
- if x is None:
- return None
- try:
- return tuple(x.get_shape().as_list())
- except ValueError:
- return None
-
-
-def get_reachable_from_inputs(inputs, targets=None):
- """Returns the set of tensors/ops reachable from `inputs`.
-
- Stops if all targets have been found (target is optional).
-
- Only valid in Symbolic mode, not Eager mode.
-
- Args:
- inputs: List of tensors.
- targets: List of tensors.
-
- Returns:
- A set of tensors reachable from the inputs (includes the inputs themselves).
- """
- reachable = set(inputs)
- if targets:
- targets = set(targets)
- queue = inputs[:]
-
- while queue:
- x = queue.pop()
- if isinstance(x, ops.Operation):
- outputs = x.outputs[:] or []
- outputs += x._control_outputs
- elif isinstance(x, ops.Tensor):
- outputs = x.consumers()
- elif isinstance(x, tf_variables.Variable):
- outputs = [x.op]
- else:
- raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
-
- for y in outputs:
- if y not in reachable:
- reachable.add(y)
- queue.insert(0, y)
-
- if targets and targets.issubset(reachable):
- return reachable
- return reachable
-
-
def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
zero_based=False):
"""Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
@@ -1800,28 +1728,6 @@ def unique_layer_name(name, name_uid_map=None, avoid_names=None, namespace='',
return proposed_name
-def to_snake_case(name):
- intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
- insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
- # If the class is private the name starts with "_" which is not secure
- # for creating scopes. We prefix the name with "private" in this case.
- if insecure[0] != '_':
- return insecure
- return 'private' + insecure
-
-
-def is_all_none(iterable_or_element):
- if not isinstance(iterable_or_element, (list, tuple)):
- iterable = [iterable_or_element]
- else:
- iterable = iterable_or_element
- # We cannot use Python's `any` because the iterable may return Tensors.
- for element in iterable:
- if element is not None:
- return False
- return True
-
-
def have_all_keras_metadata(iterable_or_element):
if not isinstance(iterable_or_element, (list, tuple)):
iterable = [iterable_or_element]
@@ -1852,14 +1758,6 @@ def collect_previous_mask(input_tensors):
return masks
-def is_tensor_or_tensor_list(v):
- v = nest.flatten(v)
- if v and isinstance(v[0], ops.Tensor):
- return True
- else:
- return False
-
-
def get_default_graph_uid_map():
# TODO(fchollet): refactor this into backend.
graph = ops.get_default_graph()
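
The helpers removed from base_layer.py above now live in the keras utils modules, per the new imports at the top of this file's diff; callers switch roughly as in this sketch (the constant is only there to make the snippet self-contained):

    from tensorflow.python.framework import constant_op
    from tensorflow.python.keras._impl.keras.utils import generic_utils
    from tensorflow.python.keras._impl.keras.utils import tf_utils

    x = constant_op.constant([1.0, 2.0])
    generic_utils.to_snake_case('MyDenseLayer')       # 'my_dense_layer'
    generic_utils.object_list_uid([x])                # cache key, as used in network.py
    tf_utils.get_reachable_from_inputs([x])           # set containing x and its consumers
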
diff --git a/tensorflow/python/keras/_impl/keras/engine/network.py b/tensorflow/python/keras/_impl/keras/engine/network.py
index cc177c14a8..a0229be346 100644
--- a/tensorflow/python/keras/_impl/keras/engine/network.py
+++ b/tensorflow/python/keras/_impl/keras/engine/network.py
@@ -22,21 +22,26 @@ from __future__ import print_function
import copy
import json
import os
+import weakref
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
+from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
-from tensorflow.python.keras._impl.keras import backend as K
+from tensorflow.python.keras._impl.keras import backend
from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.keras._impl.keras.engine import saving
from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.keras._impl.keras.utils.io_utils import ask_to_proceed_with_overwrite
from tensorflow.python.keras._impl.keras.utils.layer_utils import print_summary as print_layer_summary
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import checkpointable
+from tensorflow.python.training import checkpointable_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
@@ -114,6 +119,13 @@ class Network(base_layer.Layer):
self._outbound_nodes = []
self._inbound_nodes = []
+ self._checkpointable_saver = checkpointable_utils.CheckpointableSaver(
+ weakref.ref(self))
+ # A zero-argument function which should be called and set back to None as
+ # soon as the network is built (only applicable to subclassed Models). Runs
+ # restore operations when graph building.
+ self._in_progress_restore_finalizer = None
+
def _init_graph_network(self, inputs, outputs, name=None):
self._uses_inputs_arg = True
# Normalize and set self.inputs, self.outputs.
@@ -126,7 +138,7 @@ class Network(base_layer.Layer):
else:
self.outputs = [outputs]
- # User-prodived argument validation.
+ # User-provided argument validation.
if context.executing_eagerly():
# Check that all inputs/outputs are DeferredTensors.
for tensor in self.inputs:
@@ -227,6 +239,8 @@ class Network(base_layer.Layer):
self._layers = layers
self._layers_by_depth = layers_by_depth
+ self._track_layers(layers)
+
# Create the node linking internal inputs to internal outputs.
base_layer.Node(
outbound_layer=self,
@@ -241,8 +255,8 @@ class Network(base_layer.Layer):
for x in self.inputs:
mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
masks.append(mask)
- mask_cache_key = (base_layer.object_list_uid(self.inputs) + '_' +
- base_layer.object_list_uid(masks))
+ mask_cache_key = (generic_utils.object_list_uid(self.inputs) + '_' +
+ generic_utils.object_list_uid(masks))
masks = []
for x in self.outputs:
mask = x._keras_mask if hasattr(x, '_keras_mask') else None # pylint: disable=protected-access
@@ -263,7 +277,7 @@ class Network(base_layer.Layer):
self.input_names.append(layer.name)
if layer.is_placeholder:
self._feed_input_names.append(layer.name)
- self._feed_input_shapes.append(K.int_shape(self.inputs[i]))
+ self._feed_input_shapes.append(backend.int_shape(self.inputs[i]))
# layer.input gives an error in eager mode
if not context.executing_eagerly():
self._feed_inputs.append(layer.input)
@@ -286,6 +300,23 @@ class Network(base_layer.Layer):
self.inputs = None
self.built = False
+ def _track_layers(self, layers):
+ """Add Checkpointable dependencies on a list of Layers."""
+ weight_layer_index = 0
+ for layer_index, layer in enumerate(layers):
+ if layer.weights:
+ # Keep a separate index for layers which have weights. This allows users
+ # to insert Layers without weights anywhere in the network without
+ # breaking checkpoints.
+ self._track_checkpointable(
+ layer, name='layer_with_weights-%d' % weight_layer_index,
+ overwrite=True)
+ weight_layer_index += 1
+ # Even if it doesn't have weights, we should still track everything in
+ # case it has/will have Checkpointable dependencies.
+ self._track_checkpointable(
+ layer, name='layer-%d' % layer_index, overwrite=True)
+
def __setattr__(self, name, value):
if isinstance(value, (base_layer.Layer, Network)):
try:
@@ -362,7 +393,7 @@ class Network(base_layer.Layer):
weights = []
for layer in self.layers:
weights += layer.weights
- return K.batch_get_value(weights)
+ return backend.batch_get_value(weights)
def set_weights(self, weights):
"""Sets the weights of the model.
@@ -378,7 +409,7 @@ class Network(base_layer.Layer):
for sw, w in zip(layer.weights, layer_weights):
tuples.append((sw, w))
weights = weights[num_param:]
- K.batch_set_value(tuples)
+ backend.batch_set_value(tuples)
def compute_mask(self, inputs, mask):
if not self._is_graph_network:
@@ -389,8 +420,8 @@ class Network(base_layer.Layer):
masks = [None for _ in range(len(inputs))]
else:
masks = generic_utils.to_list(mask)
- cache_key = (base_layer.object_list_uid(inputs)
- + '_' + base_layer.object_list_uid(masks))
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
if cache_key in self._output_mask_cache:
return self._output_mask_cache[cache_key]
else:
@@ -504,7 +535,7 @@ class Network(base_layer.Layer):
relevant_inputs += inputs
else:
relevant_inputs.append(inputs)
- reachable = base_layer.get_reachable_from_inputs(relevant_inputs, updates)
+ reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, updates)
relevant_conditional_updates = [x for x in updates if x in reachable]
unconditional_updates = [
x for x in updates if x._unconditional_update] # pylint: disable=protected-access
@@ -541,7 +572,7 @@ class Network(base_layer.Layer):
relevant_inputs += inputs
else:
relevant_inputs.append(inputs)
- reachable = base_layer.get_reachable_from_inputs(relevant_inputs, losses)
+ reachable = tf_utils.get_reachable_from_inputs(relevant_inputs, losses)
relevant_conditional_losses = [x for x in losses if x in reachable]
unconditional_losses = [
x for x in losses if x._unconditional_loss] # pylint: disable=protected-access
@@ -623,8 +654,8 @@ class Network(base_layer.Layer):
if not context.executing_eagerly():
# Try to retrieve cached outputs if the layer has already been called
# on these exact inputs.
- cache_key = (base_layer.object_list_uid(inputs)
- + '_' + base_layer.object_list_uid(masks))
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
if cache_key in self._output_tensor_cache:
# Cache hit.
return self._output_tensor_cache[cache_key]
@@ -656,7 +687,7 @@ class Network(base_layer.Layer):
': model has ' + str(len(self._input_layers)) +
' tensor inputs.')
- cache_key = base_layer.object_list_uid(input_shapes)
+ cache_key = generic_utils.object_list_uid(input_shapes)
if cache_key not in self._output_shape_cache:
# Cache miss. We have to run the network graph manually (recursive calls
# to `compute_output_shape`).
@@ -845,7 +876,7 @@ class Network(base_layer.Layer):
for x in self.outputs:
assert str(id(x)) in tensor_map, 'Could not compute output ' + str(x)
tensor, mask = tensor_map[str(id(x))]
- output_shapes.append(base_layer.static_shape(x))
+ output_shapes.append(backend.int_shape(x))
output_tensors.append(tensor)
output_masks.append(mask)
@@ -859,14 +890,14 @@ class Network(base_layer.Layer):
if not context.executing_eagerly():
# Update cache;
# keys are based on ids on input tensors and inputs masks.
- cache_key = (base_layer.object_list_uid(inputs)
- + '_' + base_layer.object_list_uid(masks))
+ cache_key = (generic_utils.object_list_uid(inputs)
+ + '_' + generic_utils.object_list_uid(masks))
self._output_tensor_cache[cache_key] = output_tensors
self._output_mask_cache[cache_key] = output_masks
if output_shapes is not None:
- input_shapes = [base_layer.static_shape(x) for x in inputs]
- cache_key = base_layer.object_list_uid(input_shapes)
+ input_shapes = [backend.int_shape(x) for x in inputs]
+ cache_key = generic_utils.object_list_uid(input_shapes)
self._output_shape_cache[cache_key] = output_shapes
return output_tensors, output_masks
@@ -1125,62 +1156,160 @@ class Network(base_layer.Layer):
from tensorflow.python.keras._impl.keras.models import save_model # pylint: disable=g-import-not-at-top
save_model(self, filepath, overwrite, include_optimizer)
- def save_weights(self, filepath, overwrite=True):
- """Dumps all layer weights to a HDF5 file.
-
- The weight file has:
- - `layer_names` (attribute), a list of strings
- (ordered names of model layers).
- - For every layer, a `group` named `layer.name`
- - For every such layer group, a group attribute `weight_names`,
- a list of strings
- (ordered names of weights tensor of the layer).
- - For every weight in the layer, a dataset
- storing the weight value, named after the weight tensor.
+ def save_weights(self, filepath, overwrite=True, save_format=None):
+ """Saves all layer weights.
+
+ Either saves in HDF5 or in TensorFlow format based on the `save_format`
+ argument.
+
+ When saving in HDF5 format, the weight file has:
+ - `layer_names` (attribute), a list of strings
+ (ordered names of model layers).
+ - For every layer, a `group` named `layer.name`
+ - For every such layer group, a group attribute `weight_names`,
+ a list of strings
+ (ordered names of weights tensor of the layer).
+ - For every weight in the layer, a dataset
+ storing the weight value, named after the weight tensor.
+
+ When saving in TensorFlow format, all objects referenced by the network are
+ saved in the same format as `tf.train.Checkpoint`, including any `Layer`
+ instances or `Optimizer` instances assigned to object attributes. For
+ networks constructed from inputs and outputs using `tf.keras.Model(inputs,
+ outputs)`, `Layer` instances used by the network are tracked/saved
+ automatically. For user-defined classes which inherit from `tf.keras.Model`,
+ `Layer` instances must be assigned to object attributes, typically in the
+ constructor. See the documentation of `tf.train.Checkpoint` and
+ `tf.keras.Model` for details.
Arguments:
- filepath: String, path to the file to save the weights to.
+ filepath: String, path to the file to save the weights to. When saving
+ in TensorFlow format, this is the prefix used for checkpoint files
+ (multiple files are generated). Note that the '.h5' suffix causes
+ weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the
target location, or provide the user with a manual prompt.
+ save_format: Either 'tf' or 'h5'. A `filepath` ending in '.h5' or
+ '.keras' will default to HDF5 if `save_format` is `None`. Otherwise
+ `None` defaults to 'tf'.
Raises:
- ImportError: If h5py is not available.
+ ImportError: If h5py is not available when attempting to save in HDF5
+ format.
+ ValueError: For invalid/unknown format arguments.
"""
- if h5py is None:
- raise ImportError('`save_weights` requires h5py.')
+ filepath_is_h5 = filepath.endswith('.h5') or filepath.endswith('.keras')
+ if save_format is None:
+ if filepath_is_h5:
+ save_format = 'h5'
+ else:
+ save_format = 'tf'
+ else:
+ user_format = save_format.lower().strip()
+ if user_format in ('tensorflow', 'tf'):
+ save_format = 'tf'
+ elif user_format in ('hdf5', 'h5', 'keras'):
+ save_format = 'h5'
+ else:
+ raise ValueError(
+ 'Unknown format "%s". Was expecting one of {"tf", "h5"}.' % (
+ save_format,))
+ if save_format == 'tf' and filepath_is_h5:
+ raise ValueError(
+ ('save_weights got save_format="tf"/"tensorflow", but the '
+ 'filepath ("%s") looks like an HDF5 file. Omit the ".h5"/".keras" '
+ 'when saving in TensorFlow format.')
+ % filepath)
+
+ if save_format == 'h5' and h5py is None:
+ raise ImportError(
+ '`save_weights` requires h5py when saving in hdf5.')
+ if save_format == 'tf':
+ check_filepath = filepath + '.index'
+ else:
+ check_filepath = filepath
# If file exists and should not be overwritten:
- if not overwrite and os.path.isfile(filepath):
- proceed = ask_to_proceed_with_overwrite(filepath)
+ if not overwrite and os.path.isfile(check_filepath):
+ proceed = ask_to_proceed_with_overwrite(check_filepath)
if not proceed:
return
- with h5py.File(filepath, 'w') as f:
- saving.save_weights_to_hdf5_group(f, self.layers)
+ if save_format == 'h5':
+ with h5py.File(filepath, 'w') as f:
+ saving.save_weights_to_hdf5_group(f, self.layers)
+ else:
+ self._checkpointable_saver.save(filepath)
def load_weights(self, filepath, by_name=False):
- """Loads all layer weights from a HDF5 save file.
-
- If `by_name` is False (default) weights are loaded
- based on the network's topology, meaning the architecture
- should be the same as when the weights were saved.
- Note that layers that don't have weights are not taken
- into account in the topological ordering, so adding or
- removing layers is fine as long as they don't have weights.
-
- If `by_name` is True, weights are loaded into layers
- only if they share the same name. This is useful
- for fine-tuning or transfer-learning models where
+ """Loads all layer weights, either from a TensorFlow or an HDF5 weight file.
+
+ If `by_name` is False weights are loaded based on the network's
+ topology. This means the architecture should be the same as when the weights
+ were saved. Note that layers that don't have weights are not taken into
+ account in the topological ordering, so adding or removing layers is fine as
+ long as they don't have weights.
+
+ If `by_name` is True, weights are loaded into layers only if they share the
+ same name. This is useful for fine-tuning or transfer-learning models where
some of the layers have changed.
+ Only topological loading (`by_name=False`) is supported when loading weights
+ from the TensorFlow format. Note that topological loading differs slightly
+ between TensorFlow and HDF5 formats for user-defined classes inheriting from
+ `tf.keras.Model`: HDF5 loads based on a flattened list of weights, while the
+ TensorFlow format loads based on the object-local names of attributes to
+ which layers are assigned in the `Model`'s constructor.
+
Arguments:
- filepath: String, path to the weights file to load.
- by_name: Boolean, whether to load weights by name
- or by topological order.
+ filepath: String, path to the weights file to load. For weight files in
+ TensorFlow format, this is the file prefix (the same as was passed
+ to `save_weights`).
+ by_name: Boolean, whether to load weights by name or by topological
+ order. Only topological loading is supported for weight files in
+ TensorFlow format.
+
+ Returns:
+ When loading a weight file in TensorFlow format, returns the same status
+ object as `tf.train.Checkpoint.restore`. When graph building, restore
+ ops are run automatically as soon as the network is built (on first call
+ for user-defined classes inheriting from `Model`, immediately if it is
+ already built).
+
+ When loading weights in HDF5 format, returns `None`.
Raises:
- ImportError: If h5py is not available.
+ ImportError: If h5py is not available and the weight file is in HDF5
+ format.
"""
+ try:
+ pywrap_tensorflow.NewCheckpointReader(filepath)
+ save_format = 'tf'
+ except errors_impl.DataLossError:
+ # The checkpoint is not readable in TensorFlow format. Try HDF5.
+ save_format = 'h5'
+ if save_format == 'tf':
+ status = self._checkpointable_saver.restore(filepath)
+ if by_name:
+ raise NotImplementedError(
+ 'Weights may only be loaded based on topology into Models when '
+ 'loading TensorFlow-formatted weights (got by_name=True to '
+ 'load_weights).')
+ if not context.executing_eagerly():
+ finalizer = status.run_restore_ops
+ if self.built:
+ finalizer()
+ else:
+ # Hold on to this status object until the network is built (for
+ # subclassed Models). Then we'll run restore ops if necessary.
+ self._in_progress_restore_finalizer = finalizer
+ return status
if h5py is None:
- raise ImportError('`load_weights` requires h5py.')
+ raise ImportError(
+ '`load_weights` requires h5py when loading weights from HDF5.')
+    if not self._is_graph_network and not self.built:
+ raise NotImplementedError(
+ 'Unable to load weights saved in HDF5 format into a subclassed '
+ 'Model which has not created its variables yet. Call the Model '
+ 'first, then load the weights.')
with h5py.File(filepath, 'r') as f:
if 'layer_names' not in f.attrs and 'model_weights' in f:
f = f['model_weights']
@@ -1189,6 +1318,14 @@ class Network(base_layer.Layer):
else:
saving.load_weights_from_hdf5_group(f, self.layers)
+ def _post_build_cleanup(self):
+ super(Network, self)._post_build_cleanup()
+ if self._in_progress_restore_finalizer is not None:
+ # Runs queued restore operations left over from load_weights when graph
+ # building.
+ self._in_progress_restore_finalizer()
+ self._in_progress_restore_finalizer = None
+
def _updated_config(self):
"""Util shared between different serialization methods.
@@ -1202,7 +1339,7 @@ class Network(base_layer.Layer):
'class_name': self.__class__.__name__,
'config': config,
'keras_version': keras_version,
- 'backend': K.backend()
+ 'backend': backend.backend()
}
return model_config
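
A usage sketch of the save_format behaviour documented above (paths are illustrative; eager execution or an active, initialized session is assumed, and the HDF5 branch needs h5py):

    import tensorflow as tf

    model = tf.keras.models.Sequential(
        [tf.keras.layers.Dense(1, input_shape=(4,))])

    model.save_weights('/tmp/ckpt', save_format='tf')  # TF checkpoint files: /tmp/ckpt.index, ...
    status = model.load_weights('/tmp/ckpt')            # tf.train.Checkpoint-style status object

    model.save_weights('/tmp/weights.h5')               # '.h5' suffix defaults to HDF5
    model.load_weights('/tmp/weights.h5')               # HDF5 load returns None
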
diff --git a/tensorflow/python/keras/_impl/keras/engine/saving_test.py b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
index 3b1578cddf..edd296a281 100644
--- a/tensorflow/python/keras/_impl/keras/engine/saving_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/saving_test.py
@@ -24,7 +24,15 @@ import tempfile
import numpy as np
+from tensorflow.python.eager import context
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import test_util
from tensorflow.python.keras._impl import keras
+from tensorflow.python.keras._impl.keras.engine import training
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
from tensorflow.python.platform import test
from tensorflow.python.training import training as training_module
@@ -55,12 +63,18 @@ class TestWeightSavingAndLoading(test.TestCase):
with self.assertRaises(ValueError):
model.set_weights(weights[::-1])
- if h5py is None:
- return # Skip rest of test if H5py isn't available.
-
temp_dir = self.get_temp_dir()
self.addCleanup(shutil.rmtree, temp_dir)
+ no_extension_path = os.path.join(temp_dir, 'test')
+ model.save_weights(no_extension_path, save_format='tf')
+ model.load_weights(no_extension_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
+ if h5py is None:
+ return # Skip rest of test if H5py isn't available.
+
h5_path = os.path.join(temp_dir, 'test.h5')
model.save_weights(h5_path)
model.load_weights(h5_path)
@@ -71,6 +85,11 @@ class TestWeightSavingAndLoading(test.TestCase):
y = model.predict(x)
self.assertAllClose(ref_y, y)
+ model.save_weights(no_extension_path, save_format='hdf5')
+ model.load_weights(no_extension_path)
+ y = model.predict(x)
+ self.assertAllClose(ref_y, y)
+
def test_weight_preprocessing(self):
input_dim = 3
output_dim = 3
@@ -457,5 +476,194 @@ class TestWholeModelSaving(test.TestCase):
os.remove(fname)
+class SubclassedModel(training.Model):
+
+ def __init__(self):
+ super(SubclassedModel, self).__init__()
+ self.x_layer = keras.layers.Dense(3)
+ self.b_layer = keras.layers.Dense(1)
+
+ def call(self, a):
+ return self.b_layer(self.x_layer(a))
+
+
+class TestWeightSavingAndLoadingTFFormat(test.TestCase):
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_tensorflow_format_overwrite(self):
+ with self.test_session() as session:
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ model(x) # pylint: disable=not-callable
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ model.save_weights(prefix, save_format='tensorflow')
+ model.save_weights(prefix, save_format='tensorflow', overwrite=True)
+ with self.assertRaises(EOFError):
+ # Indirectly tests that the user is prompted
+ model.save_weights(prefix, save_format='tensorflow', overwrite=False)
+
+ def test_no_graph_pollution(self):
+ with context.graph_mode():
+ graph = ops.Graph()
+ with graph.as_default(), self.test_session(graph) as session:
+ model = SubclassedModel()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ model(x) # pylint: disable=not-callable
+ session.run([v.initializer for v in model.variables])
+ model.save_weights(prefix, save_format='tensorflow')
+ op_count = len(graph.get_operations())
+ model.save_weights(prefix, save_format='tensorflow')
+ self.assertEqual(len(graph.get_operations()), op_count)
+
+ model.load_weights(prefix)
+ op_count = len(graph.get_operations())
+ model.load_weights(prefix)
+ self.assertEqual(len(graph.get_operations()), op_count)
+
+ def _weight_loading_test_template(self, make_model_fn):
+ with self.test_session() as session:
+ model = make_model_fn()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ ref_y_tensor = model(x)
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ ref_y = self.evaluate(ref_y_tensor)
+ model.save_weights(prefix, save_format='tf')
+ for v in model.variables:
+ self.evaluate(
+ v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
+
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ model.load_weights(prefix)
+ y = self.evaluate(model(x))
+ self.assertAllClose(ref_y, y)
+
+ # Test restore-on-create if this is a subclassed Model (graph Networks
+ # will have already created their variables).
+ load_model = make_model_fn()
+ load_model.load_weights(prefix)
+ restore_on_create_y_tensor = load_model(x)
+ restore_on_create_y = self.evaluate(restore_on_create_y_tensor)
+ self.assertAllClose(ref_y, restore_on_create_y)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model(self):
+ def _make_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3)(a)
+ b = keras.layers.Dense(1)(x)
+ return keras.models.Model(a, b)
+
+ self._weight_loading_test_template(_make_graph_model)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_subclassed_model(self):
+ self._weight_loading_test_template(SubclassedModel)
+
+ def _new_layer_weight_loading_test_template(
+ self, first_model_fn, second_model_fn, restore_init_fn):
+ with self.test_session() as session:
+ model = first_model_fn()
+ temp_dir = self.get_temp_dir()
+ prefix = os.path.join(temp_dir, 'ckpt')
+
+ x = constant_op.constant(np.random.random((3, 2)), dtype=dtypes.float32)
+ executing_eagerly = context.executing_eagerly()
+ ref_y_tensor = model(x)
+ if not executing_eagerly:
+ session.run([v.initializer for v in model.variables])
+ ref_y = self.evaluate(ref_y_tensor)
+ model.save_weights(prefix)
+ for v in model.variables:
+ self.evaluate(
+ v.assign(random_ops.random_normal(shape=array_ops.shape(v))))
+
+ self.addCleanup(shutil.rmtree, temp_dir)
+
+ second_model = second_model_fn()
+ second_model.load_weights(prefix)
+ second_model(x)
+ self.evaluate(restore_init_fn(second_model))
+ second_model.save_weights(prefix)
+ # Check that the second model's checkpoint loads into the original model
+ model.load_weights(prefix)
+ y = self.evaluate(model(x))
+ self.assertAllClose(ref_y, y)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model_added_layer(self):
+ def _save_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ b = keras.layers.Dense(1, name='second')(x)
+ return keras.models.Model(a, b)
+ def _restore_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ y = keras.layers.Dense(1, name='second')(x)
+ b = keras.layers.Dense(3, name='secondjr')(y)
+ return keras.models.Model(a, b)
+ def _restore_init_fn(restore_model):
+ return [v.initializer for v in restore_model.layers[-1].variables]
+
+ self._new_layer_weight_loading_test_template(
+ _save_graph_model, _restore_graph_model,
+ _restore_init_fn)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_graph_model_added_no_weight_layer(self):
+ def _save_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ b = keras.layers.Dense(1, name='second')(x)
+ return keras.models.Model(a, b)
+ def _restore_graph_model():
+ a = keras.layers.Input(shape=(2,))
+ x = keras.layers.Dense(3, name='first')(a)
+ y = keras.layers.Dropout(rate=0.1)(x)
+ b = keras.layers.Dense(1, name='second')(y)
+ return keras.models.Model(a, b)
+ def _restore_init_fn(restore_model):
+ del restore_model # unused
+ return []
+
+ self._new_layer_weight_loading_test_template(
+ _save_graph_model, _restore_graph_model,
+ _restore_init_fn)
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_weight_loading_subclassed_model_added_layer(self):
+
+ class SubclassedModelRestore(training.Model):
+
+ def __init__(self):
+ super(SubclassedModelRestore, self).__init__()
+ self.x_layer = keras.layers.Dense(3)
+ self.y_layer = keras.layers.Dense(3)
+ self.b_layer = keras.layers.Dense(1)
+
+ def call(self, a):
+ return self.b_layer(self.y_layer(self.x_layer(a)))
+
+ def _restore_init_fn(restore_model):
+ return [v.initializer for v in restore_model.y_layer.variables]
+
+ self._new_layer_weight_loading_test_template(
+ SubclassedModel, SubclassedModelRestore,
+ _restore_init_fn)
+
if __name__ == '__main__':
test.main()
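
The test templates above exercise the new TensorFlow-format ('tf') weight checkpoints for Keras models. As a minimal sketch of the user-facing flow they verify (the path and layer sizes are illustrative, and the public tf.keras API is assumed to mirror the internal _impl module under test):

# Illustrative sketch only: save and restore Keras weights as a TF-format
# checkpoint, mirroring _weight_loading_test_template above.
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(2,)),
                             tf.keras.layers.Dense(1)])
x = np.random.random((3, 2)).astype(np.float32)
ref_y = model.predict(x)

prefix = '/tmp/demo_ckpt'                     # hypothetical path
model.save_weights(prefix, save_format='tf')  # TF checkpoint, not HDF5

# A freshly built model of the same architecture restores the weights; for
# subclassed models the variables may be matched lazily on first call
# ("restore-on-create"), which the subclassed-model test above checks.
restored = tf.keras.Sequential([tf.keras.layers.Dense(3, input_shape=(2,)),
                                tf.keras.layers.Dense(1)])
restored.load_weights(prefix)
np.testing.assert_allclose(ref_y, restored.predict(x), atol=1e-6)
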
diff --git a/tensorflow/python/keras/_impl/keras/engine/sequential.py b/tensorflow/python/keras/_impl/keras/engine/sequential.py
index bd13ca6713..8626626ca1 100644
--- a/tensorflow/python/keras/_impl/keras/engine/sequential.py
+++ b/tensorflow/python/keras/_impl/keras/engine/sequential.py
@@ -29,7 +29,6 @@ from tensorflow.python.keras._impl.keras.engine.input_layer import Input
from tensorflow.python.keras._impl.keras.engine.input_layer import InputLayer
from tensorflow.python.keras._impl.keras.engine.training import Model
from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import checkpointable
from tensorflow.python.util.tf_export import tf_export
@@ -193,36 +192,6 @@ class Sequential(Model):
self.build()
else:
self._layers.append(layer)
- # In implementing Checkpointable, Sequential does not track its Layers
- # normally, since they may be added and removed (in pop()). Instead, it
- # names everything on demand (gathering dependencies in
- # _checkpoint_dependencies, and looking them up in
- # _lookup_dependency). _handle_deferred_dependencies just checks whether an
- # existing checkpoint load targets this Layer, it does not create a
- # dependency on the Layer.
- self._handle_deferred_dependencies(
- name='layer-%d' % (len(self._layers) - 1), checkpointable=layer)
-
- @property
- def _checkpoint_dependencies(self):
- """For implementing Checkpointable. Layers which should be saved."""
- return super(Sequential, self)._checkpoint_dependencies + [
- checkpointable.CheckpointableReference(
- name='layer-%d' % layer_index, ref=layer)
- for layer_index, layer in enumerate(self._layers)]
-
- def _lookup_dependency(self, name):
- """For implementing Checkpointable. Looks up a Layer."""
- super_lookup = super(Sequential, self)._lookup_dependency(name=name)
- if super_lookup is not None:
- return super_lookup
- if name.startswith('layer-'):
- try:
- return self._layers[int(name[6:])]
- except IndexError:
- return None
- else:
- return None
def pop(self):
"""Removes the last layer in the model.
@@ -257,6 +226,7 @@ class Sequential(Model):
if self.inputs:
self._init_graph_network(self.inputs, self.outputs, name=self.name)
self.built = True
+ self._track_layers(self._layers)
def predict_proba(self, x, batch_size=32, verbose=0):
"""Generates class probability predictions for the input samples.
diff --git a/tensorflow/python/keras/_impl/keras/engine/topology_test.py b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
index 49cc1cd3b3..6993a04289 100644
--- a/tensorflow/python/keras/_impl/keras/engine/topology_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/topology_test.py
@@ -964,16 +964,16 @@ class GraphUtilsTest(test.TestCase):
x_5 = x_3 * pl_1
self.assertEqual(
- keras.engine.base_layer.get_reachable_from_inputs([pl_1]),
+ keras.utils.tf_utils.get_reachable_from_inputs([pl_1]),
{pl_1, x_1, x_4, x_5, x_1.op, x_4.op, x_5.op})
self.assertEqual(
- keras.engine.base_layer.get_reachable_from_inputs([pl_1, pl_2]),
+ keras.utils.tf_utils.get_reachable_from_inputs([pl_1, pl_2]),
{pl_1, pl_2, x_1, x_2, x_4, x_5, x_1.op, x_2.op, x_4.op, x_5.op})
self.assertEqual(
- keras.engine.base_layer.get_reachable_from_inputs([pl_3]),
+ keras.utils.tf_utils.get_reachable_from_inputs([pl_3]),
{pl_3, x_3, x_5, x_3.op, x_5.op})
self.assertEqual(
- keras.engine.base_layer.get_reachable_from_inputs([x_3]),
+ keras.utils.tf_utils.get_reachable_from_inputs([x_3]),
{x_3, x_5, x_5.op})
diff --git a/tensorflow/python/keras/_impl/keras/engine/training.py b/tensorflow/python/keras/_impl/keras/engine/training.py
index 146e8fdac9..5f9b3e8c7d 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training.py
@@ -584,6 +584,7 @@ class Model(Network):
updates=updates,
name='train_function',
**self._function_kwargs)
+ self._post_build_cleanup()
def _make_test_function(self):
if not hasattr(self, 'test_function'):
@@ -601,6 +602,7 @@ class Model(Network):
updates=self.state_updates + self.metrics_updates,
name='test_function',
**self._function_kwargs)
+ self._post_build_cleanup()
def _make_predict_function(self):
if not hasattr(self, 'predict_function'):
@@ -619,6 +621,7 @@ class Model(Network):
updates=self.state_updates,
name='predict_function',
**kwargs)
+ self._post_build_cleanup()
def _standardize_user_data(self,
x,
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager.py b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
index ad239d6151..34adeb7599 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager.py
@@ -96,7 +96,7 @@ def _eager_metrics_fn(model, outputs, targets):
model.metrics_names.append(metric_name)
with backend.name_scope(metric_name):
- metric_result = metric_fn(outputs[i], targets[i])
+ metric_result = metric_fn(targets[i], outputs[i])
metric_names.append(metric_name)
metric_results.append(backend.mean(metric_result))
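
The one-line fix above matters because Keras metric functions take (y_true, y_pred) in that order; passing the model outputs first silently computes the wrong value for any metric that is not symmetric in its arguments. A self-contained NumPy illustration (the mape helper below simply restates the usual mean-absolute-percentage-error formula, it is not Keras code):

# Why argument order matters: MAPE normalizes by y_true, so swapping the
# arguments changes the result.
import numpy as np

def mape(y_true, y_pred):
  return 100.0 * np.mean(np.abs(y_true - y_pred) / np.abs(y_true))

y_true = np.array([2.0, 4.0])
y_pred = np.array([1.0, 5.0])
print(mape(y_true, y_pred))  # 37.5 -- the order the fixed call now uses
print(mape(y_pred, y_true))  # 60.0 -- the old, swapped order
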
diff --git a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
index deaf1d1306..5adb3ef940 100644
--- a/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
+++ b/tensorflow/python/keras/_impl/keras/engine/training_eager_test.py
@@ -625,6 +625,7 @@ class LossWeightingTest(test.TestCase):
bad_w_np = np.random.random((10, 2, 2))
model.fit(x_np, [y_np, y_np], epochs=1, sample_weight={'1': bad_w_np})
+
class CorrectnessTest(test.TestCase):
@tf_test_util.run_in_graph_and_eager_modes()
@@ -648,6 +649,27 @@ class CorrectnessTest(test.TestCase):
self.assertEqual(
np.around(history.history['loss'][-1], decimals=4), 0.6173)
+ @tf_test_util.run_in_graph_and_eager_modes()
+ def test_metrics_correctness(self):
+ model = keras.Sequential()
+ model.add(keras.layers.Dense(3,
+ activation='relu',
+ input_dim=4,
+ kernel_initializer='ones'))
+ model.add(keras.layers.Dense(1,
+ activation='sigmoid',
+ kernel_initializer='ones'))
+ model.compile(loss='mae',
+ metrics=['acc'],
+ optimizer=RMSPropOptimizer(learning_rate=0.001))
+ x = np.ones((100, 4))
+ y = np.ones((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 1.)
+ y = np.zeros((100, 1))
+ outs = model.evaluate(x, y)
+ self.assertEqual(outs[1], 0.)
+
if __name__ == '__main__':
ops.enable_eager_execution()
test.main()
diff --git a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
index 11ca89d625..89931db3c0 100644
--- a/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
+++ b/tensorflow/python/keras/_impl/keras/layers/advanced_activations.py
@@ -25,7 +25,7 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -64,7 +64,7 @@ class LeakyReLU(Layer):
base_config = super(LeakyReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -119,7 +119,7 @@ class PReLU(Layer):
else:
self.shared_axes = list(shared_axes)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
param_shape = list(input_shape[1:])
self.param_broadcast = [False] * len(param_shape)
@@ -162,7 +162,7 @@ class PReLU(Layer):
base_config = super(PReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -201,7 +201,7 @@ class ELU(Layer):
base_config = super(ELU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -241,7 +241,7 @@ class ThresholdedReLU(Layer):
base_config = super(ThresholdedReLU, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -275,6 +275,6 @@ class Softmax(Layer):
base_config = super(Softmax, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional.py b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
index 12b965587f..9971f12773 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional.py
@@ -28,7 +28,6 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
# imports for backwards namespace compatibility
# pylint: disable=unused-import
from tensorflow.python.keras._impl.keras.layers.pooling import AveragePooling1D
@@ -39,6 +38,7 @@ from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling2D
from tensorflow.python.keras._impl.keras.layers.pooling import MaxPooling3D
# pylint: enable=unused-import
from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
@@ -1731,7 +1731,7 @@ class DepthwiseConv2D(Conv2D):
return outputs
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
diff --git a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
index 6b2a1d98fe..be25bbc043 100644
--- a/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/convolutional_recurrent.py
@@ -28,11 +28,11 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.layers.recurrent import _generate_dropout_mask
from tensorflow.python.keras._impl.keras.layers.recurrent import RNN
from tensorflow.python.keras._impl.keras.utils import conv_utils
from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.util.tf_export import tf_export
@@ -168,7 +168,7 @@ class ConvRNN2D(RNN):
self.input_spec = [InputSpec(ndim=5)]
self.states = None
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
@@ -209,7 +209,7 @@ class ConvRNN2D(RNN):
for _ in range(2)]
return output_shape
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
diff --git a/tensorflow/python/keras/_impl/keras/layers/embeddings.py b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
index 07b8726b85..2b353ac007 100644
--- a/tensorflow/python/keras/_impl/keras/layers/embeddings.py
+++ b/tensorflow/python/keras/_impl/keras/layers/embeddings.py
@@ -23,7 +23,7 @@ from tensorflow.python.keras._impl.keras import constraints
from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -114,7 +114,7 @@ class Embedding(Layer):
self.mask_zero = mask_zero
self.input_length = input_length
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
self.embeddings = self.add_weight(
shape=(self.input_dim, self.output_dim),
@@ -130,7 +130,7 @@ class Embedding(Layer):
else:
return math_ops.not_equal(inputs, 0)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self.input_length is None:
return input_shape + (self.output_dim,)
diff --git a/tensorflow/python/keras/_impl/keras/layers/local.py b/tensorflow/python/keras/_impl/keras/layers/local.py
index 13d96e9392..caae820fb3 100644
--- a/tensorflow/python/keras/_impl/keras/layers/local.py
+++ b/tensorflow/python/keras/_impl/keras/layers/local.py
@@ -25,8 +25,8 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
from tensorflow.python.keras._impl.keras.utils import conv_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.util.tf_export import tf_export
@@ -120,7 +120,7 @@ class LocallyConnected1D(Layer):
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=3)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[2]
if input_dim is None:
@@ -148,7 +148,7 @@ class LocallyConnected1D(Layer):
self.input_spec = InputSpec(ndim=3, axes={2: input_dim})
self.built = True
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
length = conv_utils.conv_output_length(input_shape[1], self.kernel_size[0],
self.padding, self.strides[0])
@@ -307,7 +307,7 @@ class LocallyConnected2D(Layer):
self.bias_constraint = constraints.get(bias_constraint)
self.input_spec = InputSpec(ndim=4)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
if self.data_format == 'channels_last':
input_row, input_col = input_shape[1:-1]
@@ -350,7 +350,7 @@ class LocallyConnected2D(Layer):
self.input_spec = InputSpec(ndim=4, axes={-1: input_filter})
self.built = True
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if self.data_format == 'channels_first':
rows = input_shape[2]
diff --git a/tensorflow/python/keras/_impl/keras/layers/merge.py b/tensorflow/python/keras/_impl/keras/layers/merge.py
index 7c87e6c067..2b6cf7c8a9 100644
--- a/tensorflow/python/keras/_impl/keras/layers/merge.py
+++ b/tensorflow/python/keras/_impl/keras/layers/merge.py
@@ -22,7 +22,7 @@ from __future__ import print_function
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine.base_layer import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
@@ -83,7 +83,7 @@ class _Merge(Layer):
output_shape.append(i)
return tuple(output_shape)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list):
@@ -181,7 +181,7 @@ class _Merge(Layer):
else:
return self._merge_function(inputs)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if input_shape[0] is None:
output_shape = None
@@ -274,7 +274,7 @@ class Subtract(_Merge):
```
"""
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
super(Subtract, self).build(input_shape)
if len(input_shape) != 2:
@@ -370,7 +370,7 @@ class Concatenate(_Merge):
self.supports_masking = True
self._reshape_required = False
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list) or len(input_shape) < 2:
@@ -392,7 +392,7 @@ class Concatenate(_Merge):
def _merge_function(self, inputs):
return K.concatenate(inputs, axis=self.axis)
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list):
raise ValueError('A `Concatenate` layer should be called '
@@ -478,7 +478,7 @@ class Dot(_Merge):
self.supports_masking = True
self._reshape_required = False
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
# Used purely for shape validation.
if not isinstance(input_shape, list) or len(input_shape) != 2:
@@ -523,7 +523,7 @@ class Dot(_Merge):
output = K.batch_dot(x1, x2, axes)
return output
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if not isinstance(input_shape, list) or len(input_shape) != 2:
raise ValueError('A `Dot` layer should be called '
diff --git a/tensorflow/python/keras/_impl/keras/layers/noise.py b/tensorflow/python/keras/_impl/keras/layers/noise.py
index 72dc7a1ff8..addac5b137 100644
--- a/tensorflow/python/keras/_impl/keras/layers/noise.py
+++ b/tensorflow/python/keras/_impl/keras/layers/noise.py
@@ -22,7 +22,7 @@ import numpy as np
from tensorflow.python.keras._impl.keras import backend as K
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.util.tf_export import tf_export
@@ -69,7 +69,7 @@ class GaussianNoise(Layer):
base_config = super(GaussianNoise, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -116,7 +116,7 @@ class GaussianDropout(Layer):
base_config = super(GaussianDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
@@ -188,6 +188,6 @@ class AlphaDropout(Layer):
base_config = super(AlphaDropout, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
return input_shape
diff --git a/tensorflow/python/keras/_impl/keras/layers/recurrent.py b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
index f53db987ff..f6d6e1391c 100644
--- a/tensorflow/python/keras/_impl/keras/layers/recurrent.py
+++ b/tensorflow/python/keras/_impl/keras/layers/recurrent.py
@@ -31,8 +31,8 @@ from tensorflow.python.keras._impl.keras import initializers
from tensorflow.python.keras._impl.keras import regularizers
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
-from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
@@ -107,7 +107,7 @@ class StackedRNNCells(Layer):
# Call the cells in order and store the returned states.
new_nested_states = []
for cell, states in zip(self.cells, nested_states):
- if has_arg(cell.call, 'constants'):
+ if generic_utils.has_arg(cell.call, 'constants'):
inputs, states = cell.call(inputs, states, constants=constants,
**kwargs)
else:
@@ -122,14 +122,14 @@ class StackedRNNCells(Layer):
states += cell_states
return inputs, states
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
if isinstance(input_shape, list):
constants_shape = input_shape[1:]
input_shape = input_shape[0]
for cell in self.cells:
if isinstance(cell, Layer):
- if has_arg(cell.call, 'constants'):
+ if generic_utils.has_arg(cell.call, 'constants'):
cell.build([input_shape] + constants_shape)
else:
cell.build(input_shape)
@@ -429,7 +429,7 @@ class RNN(Layer):
def states(self, states):
self._states = states
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
@@ -461,7 +461,7 @@ class RNN(Layer):
else:
return output_mask
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
# Note input_shape will be list of shapes of initial states and
# constants if these are passed in __call__.
@@ -609,11 +609,11 @@ class RNN(Layer):
'or `batch_shape` argument to your Input layer.')
kwargs = {}
- if has_arg(self.cell.call, 'training'):
+ if generic_utils.has_arg(self.cell.call, 'training'):
kwargs['training'] = training
if constants:
- if not has_arg(self.cell.call, 'constants'):
+ if not generic_utils.has_arg(self.cell.call, 'constants'):
raise ValueError('RNN cell does not support constants')
def step(inputs, states):
@@ -884,7 +884,7 @@ class SimpleRNNCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.units),
@@ -1287,7 +1287,7 @@ class GRUCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -1824,7 +1824,7 @@ class LSTMCell(Layer):
self._dropout_mask = None
self._recurrent_dropout_mask = None
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def build(self, input_shape):
input_dim = input_shape[-1]
self.kernel = self.add_weight(
@@ -2388,7 +2388,7 @@ class Recurrent(Layer):
self.dropout = 0
self.recurrent_dropout = 0
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
if isinstance(input_shape, list):
input_shape = input_shape[0]
diff --git a/tensorflow/python/keras/_impl/keras/layers/wrappers.py b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
index 9aee5f03b6..34a8eeeb5b 100644
--- a/tensorflow/python/keras/_impl/keras/layers/wrappers.py
+++ b/tensorflow/python/keras/_impl/keras/layers/wrappers.py
@@ -23,11 +23,10 @@ import copy
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras import backend as K
-from tensorflow.python.keras._impl.keras.engine import base_layer
from tensorflow.python.keras._impl.keras.engine import InputSpec
from tensorflow.python.keras._impl.keras.engine import Layer
-from tensorflow.python.keras._impl.keras.engine.base_layer import shape_type_conversion
-from tensorflow.python.keras._impl.keras.utils.generic_utils import has_arg
+from tensorflow.python.keras._impl.keras.utils import generic_utils
+from tensorflow.python.keras._impl.keras.utils import tf_utils
from tensorflow.python.ops import array_ops
from tensorflow.python.util.tf_export import tf_export
@@ -183,7 +182,7 @@ class TimeDistributed(Wrapper):
def call(self, inputs, training=None, mask=None):
kwargs = {}
- if has_arg(self.layer.call, 'training'):
+ if generic_utils.has_arg(self.layer.call, 'training'):
kwargs['training'] = training
uses_learning_phase = False # pylint: disable=redefined-outer-name
@@ -213,7 +212,7 @@ class TimeDistributed(Wrapper):
input_length = array_ops.shape(inputs)[1]
# Shape: (num_samples * timesteps, ...). And track the
# transformation in self._input_map.
- input_uid = base_layer.object_list_uid(inputs)
+ input_uid = generic_utils.object_list_uid(inputs)
inputs = array_ops.reshape(inputs, (-1,) + input_shape[2:])
self._input_map[input_uid] = inputs
# (num_samples * timesteps, ...)
@@ -305,7 +304,7 @@ class Bidirectional(Wrapper):
self.forward_layer.set_weights(weights[:nw // 2])
self.backward_layer.set_weights(weights[nw // 2:])
- @shape_type_conversion
+ @tf_utils.shape_type_conversion
def compute_output_shape(self, input_shape):
output_shape = tuple(self.forward_layer.compute_output_shape(
input_shape).as_list())
@@ -383,12 +382,13 @@ class Bidirectional(Wrapper):
def call(self, inputs, training=None, mask=None, initial_state=None):
kwargs = {}
- if has_arg(self.layer.call, 'training'):
+ if generic_utils.has_arg(self.layer.call, 'training'):
kwargs['training'] = training
- if has_arg(self.layer.call, 'mask'):
+ if generic_utils.has_arg(self.layer.call, 'mask'):
kwargs['mask'] = mask
- if initial_state is not None and has_arg(self.layer.call, 'initial_state'):
+ if initial_state is not None and generic_utils.has_arg(
+ self.layer.call, 'initial_state'):
forward_state = initial_state[:len(initial_state) // 2]
backward_state = initial_state[len(initial_state) // 2:]
y = self.forward_layer.call(inputs, initial_state=forward_state, **kwargs)
diff --git a/tensorflow/python/keras/_impl/keras/metrics_test.py b/tensorflow/python/keras/_impl/keras/metrics_test.py
index 9deaab0c05..13cef97812 100644
--- a/tensorflow/python/keras/_impl/keras/metrics_test.py
+++ b/tensorflow/python/keras/_impl/keras/metrics_test.py
@@ -75,74 +75,75 @@ class KerasMetricsTest(test.TestCase):
self.assertEqual(result, 0.)
def test_stateful_metrics(self):
- np.random.seed(1334)
-
- class BinaryTruePositives(keras.layers.Layer):
- """Stateful Metric to count the total true positives over all batches.
-
- Assumes predictions and targets of shape `(samples, 1)`.
-
- Arguments:
- threshold: Float, lower limit on prediction value that counts as a
- positive class prediction.
- name: String, name for the metric.
- """
-
- def __init__(self, name='true_positives', **kwargs):
- super(BinaryTruePositives, self).__init__(name=name, **kwargs)
- self.true_positives = keras.backend.variable(value=0, dtype='int32')
-
- def reset_states(self):
- keras.backend.set_value(self.true_positives, 0)
+ with self.test_session():
+ np.random.seed(1334)
- def __call__(self, y_true, y_pred):
- """Computes the number of true positives in a batch.
+ class BinaryTruePositives(keras.layers.Layer):
+ """Stateful Metric to count the total true positives over all batches.
- Args:
- y_true: Tensor, batch_wise labels
- y_pred: Tensor, batch_wise predictions
+ Assumes predictions and targets of shape `(samples, 1)`.
- Returns:
- The total number of true positives seen this epoch at the
- completion of the batch.
+ Arguments:
+ threshold: Float, lower limit on prediction value that counts as a
+ positive class prediction.
+ name: String, name for the metric.
"""
- y_true = math_ops.cast(y_true, 'int32')
- y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
- correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
- true_pos = math_ops.cast(
- math_ops.reduce_sum(correct_preds * y_true), 'int32')
- current_true_pos = self.true_positives * 1
- self.add_update(
- state_ops.assign_add(self.true_positives, true_pos),
- inputs=[y_true, y_pred])
- return current_true_pos + true_pos
-
- metric_fn = BinaryTruePositives()
- config = keras.metrics.serialize(metric_fn)
- metric_fn = keras.metrics.deserialize(
- config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
-
- # Test on simple model
- inputs = keras.Input(shape=(2,))
- outputs = keras.layers.Dense(1, activation='sigmoid')(inputs)
- model = keras.Model(inputs, outputs)
- model.compile(optimizer='sgd',
- loss='binary_crossentropy',
- metrics=['acc', metric_fn])
-
- # Test fit, evaluate
- samples = 1000
- x = np.random.random((samples, 2))
- y = np.random.randint(2, size=(samples, 1))
- model.fit(x, y, epochs=1, batch_size=10)
- outs = model.evaluate(x, y, batch_size=10)
- preds = model.predict(x)
-
- def ref_true_pos(y_true, y_pred):
- return np.sum(np.logical_and(y_pred > 0.5, y_true == 1))
-
- # Test correctness (e.g. updates should have been run)
- self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
+
+ def __init__(self, name='true_positives', **kwargs):
+ super(BinaryTruePositives, self).__init__(name=name, **kwargs)
+ self.true_positives = keras.backend.variable(value=0, dtype='int32')
+
+ def reset_states(self):
+ keras.backend.set_value(self.true_positives, 0)
+
+ def __call__(self, y_true, y_pred):
+ """Computes the number of true positives in a batch.
+
+ Args:
+ y_true: Tensor, batch_wise labels
+ y_pred: Tensor, batch_wise predictions
+
+ Returns:
+ The total number of true positives seen this epoch at the
+ completion of the batch.
+ """
+ y_true = math_ops.cast(y_true, 'int32')
+ y_pred = math_ops.cast(math_ops.round(y_pred), 'int32')
+ correct_preds = math_ops.cast(math_ops.equal(y_pred, y_true), 'int32')
+ true_pos = math_ops.cast(
+ math_ops.reduce_sum(correct_preds * y_true), 'int32')
+ current_true_pos = self.true_positives * 1
+ self.add_update(
+ state_ops.assign_add(self.true_positives, true_pos),
+ inputs=[y_true, y_pred])
+ return current_true_pos + true_pos
+
+ metric_fn = BinaryTruePositives()
+ config = keras.metrics.serialize(metric_fn)
+ metric_fn = keras.metrics.deserialize(
+ config, custom_objects={'BinaryTruePositives': BinaryTruePositives})
+
+ # Test on simple model
+ inputs = keras.Input(shape=(2,))
+ outputs = keras.layers.Dense(1, activation='sigmoid')(inputs)
+ model = keras.Model(inputs, outputs)
+ model.compile(optimizer='sgd',
+ loss='binary_crossentropy',
+ metrics=['acc', metric_fn])
+
+ # Test fit, evaluate
+ samples = 1000
+ x = np.random.random((samples, 2))
+ y = np.random.randint(2, size=(samples, 1))
+ model.fit(x, y, epochs=1, batch_size=10)
+ outs = model.evaluate(x, y, batch_size=10)
+ preds = model.predict(x)
+
+ def ref_true_pos(y_true, y_pred):
+ return np.sum(np.logical_and(y_pred > 0.5, y_true == 1))
+
+ # Test correctness (e.g. updates should have been run)
+ self.assertAllClose(outs[2], ref_true_pos(y, preds), atol=1e-5)
if __name__ == '__main__':
diff --git a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
index bc8698f235..295ad47f6b 100644
--- a/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
+++ b/tensorflow/python/keras/_impl/keras/model_subclassing_test.py
@@ -19,7 +19,6 @@ from __future__ import division
from __future__ import print_function
import os
-import tempfile
import numpy as np
import six
@@ -420,8 +419,6 @@ class ModelSubclassingTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def test_saving(self):
- if h5py is None:
- return # Skip test if models cannot be saved.
num_classes = (2, 3)
num_samples = 100
@@ -437,20 +434,30 @@ class ModelSubclassingTest(test.TestCase):
model.fit([x1, x2], [y1, y2], epochs=2, batch_size=32, verbose=0)
y_ref_1, y_ref_2 = model.predict([x1, x2])
- fd, fname = tempfile.mkstemp('.h5')
- model.save_weights(fname)
+ tf_format_name = os.path.join(self.get_temp_dir(), 'ckpt')
+ model.save_weights(tf_format_name)
+ if h5py is not None:
+ hdf5_format_name = os.path.join(self.get_temp_dir(), 'weights.h5')
+ model.save_weights(hdf5_format_name)
model = MultiIOTestModel(num_classes=num_classes, use_bn=True)
- # need to build the model before loading weights
- # (otherwise no weights to load)
- model._set_inputs([x1, x2])
- model.load_weights(fname)
+
+ if h5py is not None:
+ with self.assertRaises(ValueError):
+ model.load_weights(hdf5_format_name)
+
+ model.load_weights(tf_format_name)
y1, y2 = model.predict([x1, x2])
self.assertAllClose(y_ref_1, y1, atol=1e-5)
self.assertAllClose(y_ref_2, y2, atol=1e-5)
- os.close(fd)
- os.remove(fname)
+
+ if h5py is not None:
+ model.load_weights(hdf5_format_name)
+
+ y1, y2 = model.predict([x1, x2])
+ self.assertAllClose(y_ref_1, y1, atol=1e-5)
+ self.assertAllClose(y_ref_2, y2, atol=1e-5)
@test_util.run_in_graph_and_eager_modes()
def test_summary(self):
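
The rewritten test reflects two behaviours of weight saving for subclassed models: weights can be written as a TensorFlow-format checkpoint and loaded into a fresh instance without building it first, while loading HDF5 weights into an unbuilt subclassed model raises ValueError. A rough sketch with a hypothetical model class, assuming eager execution so variables are initialized on creation:

# Sketch (illustrative, not from the patch): TF-format weights for a
# subclassed model.
import tensorflow as tf

tf.enable_eager_execution()

class TinyModel(tf.keras.Model):  # hypothetical example model

  def __init__(self):
    super(TinyModel, self).__init__()
    self.dense = tf.keras.layers.Dense(1)

  def call(self, inputs):
    return self.dense(inputs)

model = TinyModel()
model(tf.zeros([1, 4]))                       # create the variables
prefix = '/tmp/tiny_model_ckpt'               # hypothetical path
model.save_weights(prefix, save_format='tf')

fresh = TinyModel()
fresh.load_weights(prefix)   # no need to build first; variables are
fresh(tf.zeros([1, 4]))      # matched as they are created
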
diff --git a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
index 3bbe87f92d..db184d278c 100644
--- a/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/generic_utils.py
@@ -21,6 +21,7 @@ import binascii
import codecs
import marshal
import os
+import re
import sys
import time
import types as python_types
@@ -28,6 +29,7 @@ import types as python_types
import numpy as np
import six
+from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
@@ -526,3 +528,31 @@ def to_list(x):
if isinstance(x, list):
return x
return [x]
+
+
+def object_list_uid(object_list):
+ """Creates a single string from object ids."""
+ object_list = nest.flatten(object_list)
+ return ', '.join([str(abs(id(x))) for x in object_list])
+
+
+def to_snake_case(name):
+ intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
+ insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
+ # If the class is private the name starts with "_" which is not secure
+ # for creating scopes. We prefix the name with "private" in this case.
+ if insecure[0] != '_':
+ return insecure
+ return 'private' + insecure
+
+
+def is_all_none(iterable_or_element):
+ if not isinstance(iterable_or_element, (list, tuple)):
+ iterable = [iterable_or_element]
+ else:
+ iterable = iterable_or_element
+ # We cannot use Python's `any` because the iterable may return Tensors.
+ for element in iterable:
+ if element is not None:
+ return False
+ return True
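
The three helpers added to generic_utils above are small enough to illustrate directly; the expected outputs below follow from the regexes and the None checks in the code (object_list_uid is omitted since its value depends on object ids):

# Quick illustration of the new generic_utils helpers.
from tensorflow.python.keras._impl.keras.utils import generic_utils

print(generic_utils.to_snake_case('BatchNormalization'))  # batch_normalization
print(generic_utils.to_snake_case('Conv2D'))              # conv2d
# Names starting with '_' are prefixed with 'private' so they remain usable
# as scope names.

print(generic_utils.is_all_none(None))          # True
print(generic_utils.is_all_none([None, None]))  # True
print(generic_utils.is_all_none([None, 0]))     # False
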
diff --git a/tensorflow/python/keras/_impl/keras/utils/tf_utils.py b/tensorflow/python/keras/_impl/keras/utils/tf_utils.py
index 8da5f77777..162e5b2cd6 100644
--- a/tensorflow/python/keras/_impl/keras/utils/tf_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/tf_utils.py
@@ -17,9 +17,12 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+from tensorflow.python.framework import ops
from tensorflow.python.framework import smart_cond as smart_module
+from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
+from tensorflow.python.util import nest
def smart_cond(pred, true_fn=None, false_fn=None, name=None):
@@ -72,3 +75,80 @@ def constant_value(pred):
if isinstance(pred, variables.Variable):
return None
return smart_module.smart_constant_value(pred)
+
+
+def is_tensor_or_tensor_list(v):
+ v = nest.flatten(v)
+ if v and isinstance(v[0], ops.Tensor):
+ return True
+ else:
+ return False
+
+
+def get_reachable_from_inputs(inputs, targets=None):
+ """Returns the set of tensors/ops reachable from `inputs`.
+
+ Stops if all targets have been found (target is optional).
+
+ Only valid in Symbolic mode, not Eager mode.
+
+ Args:
+ inputs: List of tensors.
+ targets: List of tensors.
+
+ Returns:
+ A set of tensors reachable from the inputs (includes the inputs themselves).
+ """
+ reachable = set(inputs)
+ if targets:
+ targets = set(targets)
+ queue = inputs[:]
+
+ while queue:
+ x = queue.pop()
+ if isinstance(x, ops.Operation):
+ outputs = x.outputs[:] or []
+ outputs += x._control_outputs # pylint: disable=protected-access
+ elif isinstance(x, ops.Tensor):
+ outputs = x.consumers()
+ elif isinstance(x, variables.Variable):
+ outputs = [x.op]
+ else:
+ raise TypeError('Expected Operation, Variable, or Tensor, got ' + str(x))
+
+ for y in outputs:
+ if y not in reachable:
+ reachable.add(y)
+ queue.insert(0, y)
+
+ if targets and targets.issubset(reachable):
+ return reachable
+ return reachable
+
+
+def shape_type_conversion(fn):
+ """Decorator that handles tuple/TensorShape conversion.
+
+ Used in `compute_output_shape` and `build`.
+
+ Arguments:
+ fn: function to wrap.
+
+ Returns:
+ Wrapped function.
+ """
+
+ def wrapper(instance, input_shape):
+ if input_shape is not None:
+ if isinstance(input_shape, list):
+ input_shape = [
+ tuple(tensor_shape.TensorShape(x).as_list()) for x in input_shape]
+ else:
+ input_shape = tuple(tensor_shape.TensorShape(input_shape).as_list())
+ output_shape = fn(instance, input_shape)
+ if output_shape is not None:
+ if isinstance(output_shape, list):
+ return [tensor_shape.TensorShape(x) for x in output_shape]
+ return tensor_shape.TensorShape(output_shape)
+
+ return wrapper
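
shape_type_conversion lets build and compute_output_shape implementations work with plain tuples while callers pass and receive TensorShape objects. A minimal sketch of how a layer would rely on the decorator (DoubleWidth is a hypothetical layer, not part of the patch):

# Hypothetical layer using the shape_type_conversion decorator added above.
from tensorflow.python.framework import tensor_shape
from tensorflow.python.keras._impl.keras.engine import Layer
from tensorflow.python.keras._impl.keras.utils import tf_utils


class DoubleWidth(Layer):  # illustrative only

  @tf_utils.shape_type_conversion
  def compute_output_shape(self, input_shape):
    # input_shape arrives as a plain tuple such as (None, 4), even when the
    # caller passed a TensorShape; the returned tuple is converted back.
    return (input_shape[0], input_shape[1] * 2)


layer = DoubleWidth()
out = layer.compute_output_shape(tensor_shape.TensorShape([None, 4]))
print(out)  # a TensorShape with dimensions (None, 8)
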
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index ba8f1fd3ca..b4ff094cdf 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -1566,6 +1566,7 @@ cuda_py_test(
"//tensorflow/python:tensor_array_grad",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
+ "//tensorflow/python/data/ops:iterator_ops",
],
grpc_enabled = True,
tags = ["no_windows"],
@@ -2903,11 +2904,8 @@ tf_py_test(
"//tensorflow/python:random_ops",
"//tensorflow/python:variables",
],
- shard_count = 10,
- tags = [
- "no_windows_gpu",
- "noasan",
- ],
+ shard_count = 20,
+ tags = ["no_windows_gpu"],
)
tf_py_test(
diff --git a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
index e27eb00818..209411cf51 100644
--- a/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
+++ b/tensorflow/python/kernel_tests/control_flow_ops_py_test.py
@@ -1135,11 +1135,10 @@ class ControlFlowTest(test.TestCase):
with self.assertRaisesRegexp(
ValueError,
- r"The shape for while_1/Merge_1:0 is not an invariant for the loop. "
- r"It enters the loop with shape \(2, 2\), but has shape \(4, 2\) "
- r"after one iteration. Provide shape invariants using either the "
- r"`shape_invariants` argument of tf.while_loop or set_shape\(\) on "
- r"the loop variables."):
+ r"Input tensor 'ones:0' enters the loop with shape \(2, 2\), but has "
+ r"shape \(4, 2\) after one iteration. To allow the shape to vary "
+ r"across iterations, use the `shape_invariants` argument of "
+ r"tf.while_loop to specify a less-specific shape."):
r = control_flow_ops.while_loop(c, b, [i, m])
def testWhileShapeInferenceSparseTensor(self):
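
The reworded error message steers users toward shape_invariants; for reference, a minimal example of the pattern it recommends, letting the first dimension of a loop variable grow across iterations (shapes are illustrative):

# Illustrative use of shape_invariants with tf.while_loop, matching the
# advice in the updated error message above.
import tensorflow as tf

i = tf.constant(0)
m = tf.ones([2, 2])

def cond(i, m):
  return i < 3

def body(i, m):
  return i + 1, tf.concat([m, m], axis=0)  # first dimension doubles

_, result = tf.while_loop(
    cond, body, [i, m],
    shape_invariants=[i.get_shape(), tf.TensorShape([None, 2])])
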
diff --git a/tensorflow/python/kernel_tests/functional_ops_test.py b/tensorflow/python/kernel_tests/functional_ops_test.py
index 34fb655035..35a274e75f 100644
--- a/tensorflow/python/kernel_tests/functional_ops_test.py
+++ b/tensorflow/python/kernel_tests/functional_ops_test.py
@@ -22,6 +22,7 @@ import numpy as np
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session
+from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
@@ -38,6 +39,7 @@ from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
import tensorflow.python.ops.tensor_array_grad # pylint: disable=unused-import
from tensorflow.python.platform import test
+from tensorflow.python.util import compat
# pylint: disable=invalid-name
@@ -70,6 +72,26 @@ class FunctionalOpsTest(test.TestCase):
initializer=10)
self.assertAllEqual(880, self.evaluate(r))
+ @test_util.run_in_graph_and_eager_modes()
+ def testFoldl_SingleInputMultiOutput(self):
+ with self.test_session():
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldl(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
+
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFoldl_MultiInputSingleOutput(self):
+ with self.test_session():
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
+
def testFoldl_Scoped(self):
with self.test_session() as sess:
with variable_scope.variable_scope("root") as varscope:
@@ -105,6 +127,26 @@ class FunctionalOpsTest(test.TestCase):
initializer=10)
self.assertAllEqual(1282, self.evaluate(r))
+ @test_util.run_in_graph_and_eager_modes()
+ def testFoldr_SingleInputMultiOutput(self):
+ with self.test_session():
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array([1, -1.0])
+ r = functional_ops.foldr(lambda a, x: a + x, elems, initializer)
+ r_value = self.evaluate(r)
+
+ self.assertAllEqual(22, r_value[0])
+ self.assertAllEqual(20, r_value[1])
+
+ @test_util.run_in_graph_and_eager_modes()
+ def testFoldr_MultiInputSingleOutput(self):
+ with self.test_session():
+ elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
+ initializer = np.array(1.0)
+ r = functional_ops.foldr(lambda a, x: a + x[0] + x[1], (elems, -elems),
+ initializer)
+ self.assertAllEqual(1, self.evaluate(r))
+
def testFoldr_Scoped(self):
with self.test_session() as sess:
with variable_scope.variable_scope("root") as varscope:
@@ -885,6 +927,110 @@ class FunctionalOpsTest(test.TestCase):
self.assertAllEqual(sess.run(bvals), [17., 16.])
+class PartitionedCallTest(test.TestCase):
+
+ def testBasicSingleDevice(self):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(x, y):
+ with ops.device("/cpu:0"):
+ a = x + x
+ b = y + y
+ return a + b
+
+ output, = self.evaluate(
+ functional_ops.partitioned_call(
+ args=[constant_op.constant(1.),
+ constant_op.constant(2.)], f=Body))
+ self.assertEqual(output, 6.)
+
+ def testBasicMultiDevice(self):
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(x, y):
+ # if x = 1, y = 2, ...
+ with ops.device("/cpu:0"):
+ # a:= 1 + 1 = 2
+ a = x + x
+ with ops.device("/cpu:1"):
+ # b:= 2 + 2 = 4
+ b = a + y
+ with ops.device("/cpu:2"):
+ # c:= 2 + 4 = 6
+ c = a + b
+ # a + b + c = 2 + 4 + 6 = 12
+ return a + b + c
+
+ with self.test_session(config=config):
+ output, = functional_ops.partitioned_call(
+ args=[constant_op.constant(1.),
+ constant_op.constant(2.)], f=Body)
+ self.assertEqual(output.eval(), 12.)
+
+ def testBasicMultiDeviceGPU(self):
+ if not test_util.is_gpu_available():
+ return
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(x, y):
+ with ops.device("/gpu:0"):
+ a = x + x
+ b = y + y
+ with ops.device("/cpu:0"):
+ c = a + b
+ return c
+
+ output, = self.evaluate(
+ functional_ops.partitioned_call(
+ args=[constant_op.constant(1.),
+ constant_op.constant(2.)], f=Body))
+ self.assertEqual(output, 6.)
+
+ def testBasicNoDeviceAnnotations(self):
+
+ @function.Defun(*[dtypes.float32] * 2)
+ def Body(x, y):
+ a = x + x
+ b = y + y
+ return a + b
+
+ output, = self.evaluate(
+ functional_ops.partitioned_call(
+ args=[constant_op.constant(1.),
+ constant_op.constant(2.)], f=Body))
+ self.assertEqual(output, 6.)
+
+ def testShardsRunOnRequestedDevices(self):
+ config = config_pb2.ConfigProto(device_count={"CPU": 3})
+
+ @function.Defun()
+ def Body():
+ # Serialize DT_RESOURCE handles as DT_STRINGs, which encode the device on
+ # which the resource was created, so that we can verify that ops were
+ # actually run on the requested devices.
+ #
+ # TODO(akshayka): Provide a cleaner, more idiomatic API for obtaining the
+ # name of the device on which a resource lives / for determining the
+ # device on which an op ran.
+ with ops.device("/cpu:0"):
+ s1 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ with ops.device("/cpu:1"):
+ s2 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ with ops.device("/cpu:2"):
+ s3 = iterator_ops.Iterator.from_structure(
+ (dtypes.float32,)).string_handle()
+ return s1, s2, s3
+
+ with self.test_session(config=config):
+ outputs = functional_ops.partitioned_call(args=[], f=Body)
+ self.assertTrue(compat.as_bytes("CPU:0") in outputs[0].eval())
+ self.assertTrue(compat.as_bytes("CPU:1") in outputs[1].eval())
+ self.assertTrue(compat.as_bytes("CPU:2") in outputs[2].eval())
+
+
if __name__ == "__main__":
test.main()
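
The new foldl/foldr tests exercise structured arguments: elems may be a tuple of tensors, and the accumulator's structure may differ from that of the elements. A condensed sketch of both cases (values match the tests above; the public tf.foldl wrapper is assumed to expose the same structured-input support as functional_ops.foldl):

# Sketch of foldl with structured elems/accumulator.
import numpy as np
import tensorflow as tf

elems = np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# Single input, vector accumulator: the scalar element is added to both
# components of the initializer, so the result is [1 + 21, -1 + 21].
r = tf.foldl(lambda a, x: a + x, elems, initializer=np.array([1.0, -1.0]))

# Tuple of inputs, scalar accumulator: x arrives as a matching tuple, and
# x[0] + x[1] is zero for every element, so the result stays 1.0.
r2 = tf.foldl(lambda a, x: a + x[0] + x[1], (elems, -elems),
              initializer=np.array(1.0))
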
diff --git a/tensorflow/python/kernel_tests/linalg/BUILD b/tensorflow/python/kernel_tests/linalg/BUILD
index 7ffa48b653..faeccc8fba 100644
--- a/tensorflow/python/kernel_tests/linalg/BUILD
+++ b/tensorflow/python/kernel_tests/linalg/BUILD
@@ -44,6 +44,26 @@ cuda_py_test(
)
cuda_py_test(
+ name = "linear_operator_circulant_test",
+ size = "medium",
+ srcs = ["linear_operator_circulant_test.py"],
+ additional_deps = [
+ "//tensorflow/python/ops/linalg",
+ "//third_party/py/numpy",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:spectral_ops_test_util",
+ "//tensorflow/python:client_testlib",
+ "//tensorflow/python:framework",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:framework_test_lib",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+ shard_count = 5,
+ tags = ["noasan"], # times out b/63678675
+)
+
+cuda_py_test(
name = "linear_operator_diag_test",
size = "medium",
srcs = ["linear_operator_diag_test.py"],
diff --git a/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
new file mode 100644
index 0000000000..e7f2f1c12b
--- /dev/null
+++ b/tensorflow/python/kernel_tests/linalg/linear_operator_circulant_test.py
@@ -0,0 +1,700 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import contextlib
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import spectral_ops_test_util
+from tensorflow.python.ops.linalg import linalg
+from tensorflow.python.ops.linalg import linear_operator_circulant
+from tensorflow.python.ops.linalg import linear_operator_test_util
+from tensorflow.python.platform import test
+
+rng = np.random.RandomState(0)
+_to_complex = linear_operator_circulant._to_complex
+
+
+class LinearOperatorCirculantBaseTest(object):
+ """Common class for circulant tests."""
+
+ @contextlib.contextmanager
+ def test_session(self, *args, **kwargs):
+ with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ yield sess
+
+ def _shape_to_spectrum_shape(self, shape):
+ # If spectrum.shape = batch_shape + [N],
+ # this creates an operator of shape batch_shape + [N, N]
+ return shape[:-1]
+
+ def _spectrum_to_circulant_1d(self, spectrum, shape, dtype):
+ """Creates a circulant matrix from a spectrum.
+
+ Intentionally done in an explicit yet inefficient way. This provides a
+ cross check to the main code that uses fancy reshapes.
+
+ Args:
+ spectrum: Float or complex `Tensor`.
+ shape: Python list. Desired shape of returned matrix.
+ dtype: Type to cast the returned matrix to.
+
+ Returns:
+ Circulant (batch) matrix of desired `dtype`.
+ """
+ spectrum = _to_complex(spectrum)
+ spectrum_shape = self._shape_to_spectrum_shape(shape)
+ domain_dimension = spectrum_shape[-1]
+ if not domain_dimension:
+ return array_ops.zeros(shape, dtype)
+
+ # Explicitly compute the action of spectrum on basis vectors.
+ matrix_rows = []
+ for m in range(domain_dimension):
+ x = np.zeros([domain_dimension])
+ # x is a basis vector.
+ x[m] = 1.0
+ fft_x = math_ops.fft(x)
+ h_convolve_x = math_ops.ifft(spectrum * fft_x)
+ matrix_rows.append(h_convolve_x)
+ matrix = array_ops.stack(matrix_rows, axis=-1)
+ return math_ops.cast(matrix, dtype)
+
+
+class LinearOperatorCirculantTestSelfAdjointOperator(
+ LinearOperatorCirculantBaseTest,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Test of LinearOperatorCirculant when operator is self-adjoint.
+
+ Real spectrum <==> Self adjoint operator.
+ Note that when the spectrum is real, the operator may still be complex.
+ """
+
+ @property
+ def _dtypes_to_test(self):
+ # This operator will always be complex because, although the spectrum is
+ # real, the matrix will not be real.
+ return [dtypes.complex64]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = build_info.shape
+ # For this test class, we are creating real spectrums.
+ # We also want the spectrum to have eigenvalues bounded away from zero.
+ #
+ # spectrum is bounded away from zero.
+ spectrum = linear_operator_test_util.random_sign_uniform(
+ shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
+ # If dtype is complex, cast spectrum to complex. The imaginary part will be
+ # zero, so the operator will still be self-adjoint.
+ spectrum = math_ops.cast(spectrum, dtype)
+
+ if use_placeholder:
+ spectrum_ph = array_ops.placeholder(dtypes.complex64)
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # it is random and we want the same value used for both mat and feed_dict.
+ spectrum = spectrum.eval()
+ operator = linalg.LinearOperatorCirculant(
+ spectrum_ph, is_self_adjoint=True, input_output_dtype=dtype)
+ feed_dict = {spectrum_ph: spectrum}
+ else:
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, is_self_adjoint=True, input_output_dtype=dtype)
+ feed_dict = None
+
+ mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
+
+ return operator, mat, feed_dict
+
+ def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
+ with self.test_session():
+ spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtypes.complex64)
+ matrix = operator.to_dense()
+ imag_matrix = math_ops.imag(matrix)
+ eps = np.finfo(np.float32).eps
+ np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
+
+
+class LinearOperatorCirculantTestHermitianSpectrum(
+ LinearOperatorCirculantBaseTest,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Test of LinearOperatorCirculant when the spectrum is Hermitian.
+
+ Hermitian spectrum <==> Real valued operator. We test both real and complex
+ dtypes here though. So in some cases the matrix will be complex but with
+ zero imaginary part.
+ """
+
+ @property
+ def _dtypes_to_test(self):
+ return [dtypes.float32, dtypes.complex64]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = build_info.shape
+ # For this test class, we are creating Hermitian spectrums.
+ # We also want the spectrum to have eigenvalues bounded away from zero.
+ #
+ # pre_spectrum is bounded away from zero.
+ pre_spectrum = linear_operator_test_util.random_uniform(
+ shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
+ pre_spectrum_c = _to_complex(pre_spectrum)
+
+ # Real{IFFT[pre_spectrum]}
+ # = IFFT[EvenPartOf[pre_spectrum]]
+ # is the IFFT of something that is also bounded away from zero.
+ # Therefore, FFT[pre_h] would be a well-conditioned spectrum.
+ pre_h = math_ops.ifft(pre_spectrum_c)
+
+ # A spectrum is Hermitian iff it is the DFT of a real convolution kernel.
+ # So we will make spectrum = FFT[h], for real valued h.
+ h = math_ops.real(pre_h)
+ h_c = _to_complex(h)
+
+ spectrum = math_ops.fft(h_c)
+
+ if use_placeholder:
+ spectrum_ph = array_ops.placeholder(dtypes.complex64)
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # it is random and we want the same value used for both mat and feed_dict.
+ spectrum = spectrum.eval()
+ operator = linalg.LinearOperatorCirculant(
+ spectrum_ph, input_output_dtype=dtype)
+ feed_dict = {spectrum_ph: spectrum}
+ else:
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtype)
+ feed_dict = None
+
+ mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
+
+ return operator, mat, feed_dict
+
+ def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
+ with self.test_session():
+ spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtypes.complex64)
+ matrix = operator.to_dense()
+ imag_matrix = math_ops.imag(matrix)
+ eps = np.finfo(np.float32).eps
+ np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
+
+
+class LinearOperatorCirculantTestNonHermitianSpectrum(
+ LinearOperatorCirculantBaseTest,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Test of LinearOperatorCirculant when the spectrum is not Hermitian.
+
+ Non-Hermitian spectrum <==> Complex valued operator.
+ We test only complex dtypes here.
+ """
+
+ @property
+ def _dtypes_to_test(self):
+ return [dtypes.complex64]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = build_info.shape
+ # Will be well conditioned enough to get accurate solves.
+ spectrum = linear_operator_test_util.random_sign_uniform(
+ shape=self._shape_to_spectrum_shape(shape),
+ dtype=dtypes.complex64,
+ minval=1.,
+ maxval=2.)
+
+ if use_placeholder:
+ spectrum_ph = array_ops.placeholder(dtypes.complex64)
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # it is random and we want the same value used for both mat and feed_dict.
+ spectrum = spectrum.eval()
+ operator = linalg.LinearOperatorCirculant(
+ spectrum_ph, input_output_dtype=dtype)
+ feed_dict = {spectrum_ph: spectrum}
+ else:
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtype)
+ feed_dict = None
+
+ mat = self._spectrum_to_circulant_1d(spectrum, shape, dtype=dtype)
+
+ return operator, mat, feed_dict
+
+ def test_simple_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
+ with self.test_session():
+ spectrum = math_ops.cast([1., 1j, -1j], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtypes.complex64)
+ matrix = operator.to_dense()
+ imag_matrix = math_ops.imag(matrix)
+ eps = np.finfo(np.float32).eps
+ np.testing.assert_allclose(0, imag_matrix.eval(), rtol=0, atol=eps * 3)
+
+ def test_simple_positive_real_spectrum_gives_self_adjoint_pos_def_oper(self):
+ with self.test_session() as sess:
+ spectrum = math_ops.cast([6., 4, 2], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtypes.complex64)
+ matrix, matrix_h = sess.run(
+ [operator.to_dense(),
+ linalg.adjoint(operator.to_dense())])
+ self.assertAllClose(matrix, matrix_h)
+ operator.assert_positive_definite().run() # Should not fail
+ operator.assert_self_adjoint().run() # Should not fail
+
+ def test_defining_operator_using_real_convolution_kernel(self):
+ with self.test_session():
+ convolution_kernel = [1., 2., 1.]
+ spectrum = math_ops.fft(
+ math_ops.cast(convolution_kernel, dtypes.complex64))
+
+ # spectrum is shape [3] ==> operator is shape [3, 3]
+ # spectrum is Hermitian ==> operator is real.
+ operator = linalg.LinearOperatorCirculant(spectrum)
+
+ # Allow for complex output so we can make sure it has zero imag part.
+ self.assertEqual(operator.dtype, dtypes.complex64)
+
+ matrix = operator.to_dense().eval()
+ np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
+
+ def test_hermitian_spectrum_gives_operator_with_zero_imag_part(self):
+ with self.test_session():
+ # Make spectrum the FFT of a real convolution kernel h. This ensures that
+ # spectrum is Hermitian.
+ h = linear_operator_test_util.random_normal(shape=(3, 4))
+ spectrum = math_ops.fft(math_ops.cast(h, dtypes.complex64))
+ operator = linalg.LinearOperatorCirculant(
+ spectrum, input_output_dtype=dtypes.complex64)
+ matrix = operator.to_dense()
+ imag_matrix = math_ops.imag(matrix)
+ eps = np.finfo(np.float32).eps
+ np.testing.assert_allclose(
+ 0, imag_matrix.eval(), rtol=0, atol=eps * 3 * 4)
+
+ def test_convolution_kernel_same_as_first_row_of_to_dense(self):
+ spectrum = [[3., 2., 1.], [2., 1.5, 1.]]
+ with self.test_session():
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ h = operator.convolution_kernel()
+ c = operator.to_dense()
+
+ self.assertAllEqual((2, 3), h.get_shape())
+ self.assertAllEqual((2, 3, 3), c.get_shape())
+ self.assertAllClose(h.eval(), c.eval()[:, :, 0])
+
+ def test_assert_non_singular_fails_for_singular_operator(self):
+ spectrum = math_ops.cast([0, 4, 2j + 2], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ with self.test_session():
+ with self.assertRaisesOpError("Singular operator"):
+ operator.assert_non_singular().run()
+
+ def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
+ spectrum = math_ops.cast([-3j, 4, 2j + 2], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ with self.test_session():
+ operator.assert_non_singular().run() # Should not fail
+
+ def test_assert_positive_definite_fails_for_non_positive_definite(self):
+ spectrum = math_ops.cast([6., 4, 2j], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ with self.test_session():
+ with self.assertRaisesOpError("Not positive definite"):
+ operator.assert_positive_definite().run()
+
+ def test_assert_positive_definite_does_not_fail_when_pos_def(self):
+ spectrum = math_ops.cast([6., 4, 2j + 2], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ with self.test_session():
+ operator.assert_positive_definite().run() # Should not fail
+
+ def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
+ spectrum = [1., 2.]
+ with self.assertRaisesRegexp(ValueError, "real.*always.*self-adjoint"):
+ linalg.LinearOperatorCirculant(spectrum, is_self_adjoint=False)
+
+ def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
+ spectrum = [1., 2.]
+ operator = linalg.LinearOperatorCirculant(spectrum)
+ self.assertTrue(operator.is_self_adjoint)
+
+
+class LinearOperatorCirculant2DBaseTest(object):
+ """Common class for 2D circulant tests."""
+
+ @contextlib.contextmanager
+ def test_session(self, *args, **kwargs):
+ with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ yield sess
+
+ @property
+ def _operator_build_infos(self):
+ build_info = linear_operator_test_util.OperatorBuildInfo
+ # non-batch operators (n, n) and batch operators.
+ return [
+ build_info((0, 0)),
+ build_info((1, 1)),
+ build_info((1, 6, 6)),
+ build_info((3, 4, 4)),
+ build_info((2, 1, 3, 3))
+ ]
+
+ def _shape_to_spectrum_shape(self, shape):
+ """Get a spectrum shape that will make an operator of desired shape."""
+ # This 2D block circulant operator takes a spectrum of shape
+ # batch_shape + [N0, N1],
+ # and creates an operator of shape
+ # batch_shape + [N0*N1, N0*N1]
+ if shape == (0, 0):
+ return (0, 0)
+ elif shape == (1, 1):
+ return (1, 1)
+ elif shape == (1, 6, 6):
+ return (1, 2, 3)
+ elif shape == (3, 4, 4):
+ return (3, 2, 2)
+ elif shape == (2, 1, 3, 3):
+ return (2, 1, 3, 1)
+ else:
+ raise ValueError("Unhandled shape: %s" % shape)
+
+ def _spectrum_to_circulant_2d(self, spectrum, shape, dtype):
+ """Creates a block circulant matrix from a spectrum.
+
+ Intentionally done in an explicit yet inefficient way. This provides a
+ cross-check against the main code, which uses fancy reshapes.
+
+ Args:
+ spectrum: Float or complex `Tensor`.
+ shape: Python list. Desired shape of returned matrix.
+ dtype: Type to cast the returned matrix to.
+
+ Returns:
+ Block circulant (batch) matrix of desired `dtype`.
+ """
+ spectrum = _to_complex(spectrum)
+ spectrum_shape = self._shape_to_spectrum_shape(shape)
+ domain_dimension = spectrum_shape[-1]
+ if not domain_dimension:
+ return array_ops.zeros(shape, dtype)
+
+ block_shape = spectrum_shape[-2:]
+
+ # Explicitly compute the action of spectrum on basis vectors.
+ matrix_rows = []
+ for n0 in range(block_shape[0]):
+ for n1 in range(block_shape[1]):
+ x = np.zeros(block_shape)
+ # x is a basis vector.
+ x[n0, n1] = 1.0
+ fft_x = math_ops.fft2d(x)
+ h_convolve_x = math_ops.ifft2d(spectrum * fft_x)
+ # We want the flat version of the action of the operator on a basis
+ # vector, not the block version.
+ h_convolve_x = array_ops.reshape(h_convolve_x, shape[:-1])
+ matrix_rows.append(h_convolve_x)
+ matrix = array_ops.stack(matrix_rows, axis=-1)
+ return math_ops.cast(matrix, dtype)
+
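`_spectrum_to_circulant_2d` builds the dense matrix column by column by acting on basis vectors; the same idea can be cross-checked in plain NumPy for a single non-batch block. A sketch with an assumed `(2, 3)` block shape, not part of the test file:

```python
import numpy as np

# Build the block-circulant matrix explicitly from basis vectors, then check it
# reproduces IFFT2[spectrum * FFT2[v]] for an arbitrary v.
spectrum = np.random.rand(2, 3) + 1.0
cols = []
for n0 in range(2):
  for n1 in range(3):
    x = np.zeros((2, 3))
    x[n0, n1] = 1.0                          # basis vector, in block form
    cols.append(np.fft.ifft2(spectrum * np.fft.fft2(x)).reshape(-1))
matrix = np.stack(cols, axis=-1)             # shape (6, 6)

v = np.random.rand(2, 3)
np.testing.assert_allclose(matrix @ v.reshape(-1),
                           np.fft.ifft2(spectrum * np.fft.fft2(v)).reshape(-1))
```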
+
+class LinearOperatorCirculant2DTestHermitianSpectrum(
+ LinearOperatorCirculant2DBaseTest,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Test of LinearOperatorCirculant2D when the spectrum is Hermitian.
+
+ Hermitian spectrum <==> Real valued operator. We test both real and complex
+ dtypes here though. So in some cases the matrix will be complex but with
+ zero imaginary part.
+ """
+
+ @property
+ def _dtypes_to_test(self):
+ return [dtypes.float32, dtypes.complex64]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = build_info.shape
+ # For this test class, we are creating Hermitian spectra.
+ # We also want the spectrum to have eigenvalues bounded away from zero.
+ #
+ # pre_spectrum is bounded away from zero.
+ pre_spectrum = linear_operator_test_util.random_uniform(
+ shape=self._shape_to_spectrum_shape(shape), minval=1., maxval=2.)
+ pre_spectrum_c = _to_complex(pre_spectrum)
+
+ # Real{IFFT[pre_spectrum]}
+ # = IFFT[EvenPartOf[pre_spectrum]]
+ # is the IFFT of something that is also bounded away from zero.
+ # Therefore, FFT[pre_h] would be a well-conditioned spectrum.
+ pre_h = math_ops.ifft2d(pre_spectrum_c)
+
+ # A spectrum is Hermitian iff it is the DFT of a real convolution kernel.
+ # So we will make spectrum = FFT[h], for real valued h.
+ h = math_ops.real(pre_h)
+ h_c = _to_complex(h)
+
+ spectrum = math_ops.fft2d(h_c)
+
+ if use_placeholder:
+ spectrum_ph = array_ops.placeholder(dtypes.complex64)
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # it is random and we want the same value used for both mat and feed_dict.
+ spectrum = spectrum.eval()
+ operator = linalg.LinearOperatorCirculant2D(
+ spectrum_ph, input_output_dtype=dtype)
+ feed_dict = {spectrum_ph: spectrum}
+ else:
+ operator = linalg.LinearOperatorCirculant2D(
+ spectrum, input_output_dtype=dtype)
+ feed_dict = None
+
+ mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
+
+ return operator, mat, feed_dict
+
+
+class LinearOperatorCirculant2DTestNonHermitianSpectrum(
+ LinearOperatorCirculant2DBaseTest,
+ linear_operator_test_util.SquareLinearOperatorDerivedClassTest):
+ """Test of LinearOperatorCirculant when the spectrum is not Hermitian.
+
+ Non-Hermitian spectrum <==> Complex valued operator.
+ We test only complex dtypes here.
+ """
+
+ @property
+ def _dtypes_to_test(self):
+ return [dtypes.complex64]
+
+ def _operator_and_mat_and_feed_dict(self, build_info, dtype, use_placeholder):
+ shape = build_info.shape
+ # Will be well conditioned enough to get accurate solves.
+ spectrum = linear_operator_test_util.random_sign_uniform(
+ shape=self._shape_to_spectrum_shape(shape),
+ dtype=dtype,
+ minval=1.,
+ maxval=2.)
+
+ if use_placeholder:
+ spectrum_ph = array_ops.placeholder(dtypes.complex64)
+ # Evaluate here because (i) you cannot feed a tensor, and (ii)
+ # it is random and we want the same value used for both mat and feed_dict.
+ spectrum = spectrum.eval()
+ operator = linalg.LinearOperatorCirculant2D(
+ spectrum_ph, input_output_dtype=dtype)
+ feed_dict = {spectrum_ph: spectrum}
+ else:
+ operator = linalg.LinearOperatorCirculant2D(
+ spectrum, input_output_dtype=dtype)
+ feed_dict = None
+
+ mat = self._spectrum_to_circulant_2d(spectrum, shape, dtype=dtype)
+
+ return operator, mat, feed_dict
+
+ def test_real_hermitian_spectrum_gives_real_symmetric_operator(self):
+ with self.test_session() as sess:
+ # This is a real and hermitian spectrum.
+ spectrum = [[1., 2., 2.], [3., 4., 4.], [3., 4., 4.]]
+ operator = linalg.LinearOperatorCirculant(spectrum)
+
+ matrix_tensor = operator.to_dense()
+ self.assertEqual(matrix_tensor.dtype,
+ linear_operator_circulant._DTYPE_COMPLEX)
+ matrix_t = array_ops.matrix_transpose(matrix_tensor)
+ imag_matrix = math_ops.imag(matrix_tensor)
+ matrix, matrix_transpose, imag_matrix = sess.run(
+ [matrix_tensor, matrix_t, imag_matrix])
+
+ np.testing.assert_allclose(0, imag_matrix, atol=1e-6)
+ self.assertAllClose(matrix, matrix_transpose, atol=0)
+
+ def test_real_spectrum_gives_self_adjoint_operator(self):
+ with self.test_session() as sess:
+ # A real spectrum gives a self-adjoint operator.
+ spectrum = linear_operator_test_util.random_normal(
+ shape=(3, 3), dtype=dtypes.float32)
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+
+ matrix_tensor = operator.to_dense()
+ self.assertEqual(matrix_tensor.dtype,
+ linear_operator_circulant._DTYPE_COMPLEX)
+ matrix_h = linalg.adjoint(matrix_tensor)
+ matrix, matrix_h = sess.run([matrix_tensor, matrix_h])
+ self.assertAllClose(matrix, matrix_h, atol=0)
+
+ def test_assert_non_singular_fails_for_singular_operator(self):
+ spectrum = math_ops.cast([[0, 4], [2j + 2, 3.]], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+ with self.test_session():
+ with self.assertRaisesOpError("Singular operator"):
+ operator.assert_non_singular().run()
+
+ def test_assert_non_singular_does_not_fail_for_non_singular_operator(self):
+ spectrum = math_ops.cast([[-3j, 4], [2j + 2, 3.]], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+ with self.test_session():
+ operator.assert_non_singular().run() # Should not fail
+
+ def test_assert_positive_definite_fails_for_non_positive_definite(self):
+ spectrum = math_ops.cast([[6., 4], [2j, 3.]], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+ with self.test_session():
+ with self.assertRaisesOpError("Not positive definite"):
+ operator.assert_positive_definite().run()
+
+ def test_assert_positive_definite_does_not_fail_when_pos_def(self):
+ spectrum = math_ops.cast([[6., 4], [2j + 2, 3.]], dtypes.complex64)
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+ with self.test_session():
+ operator.assert_positive_definite().run() # Should not fail
+
+ def test_real_spectrum_and_not_self_adjoint_hint_raises(self):
+ spectrum = [[1., 2.], [3., 4]]
+ with self.assertRaisesRegexp(ValueError, "real.*always.*self-adjoint"):
+ linalg.LinearOperatorCirculant2D(spectrum, is_self_adjoint=False)
+
+ def test_real_spectrum_auto_sets_is_self_adjoint_to_true(self):
+ spectrum = [[1., 2.], [3., 4]]
+ operator = linalg.LinearOperatorCirculant2D(spectrum)
+ self.assertTrue(operator.is_self_adjoint)
+
+ def test_invalid_dtype_raises(self):
+ spectrum = array_ops.constant(rng.rand(2, 2, 2))
+ with self.assertRaisesRegexp(TypeError, "must have dtype"):
+ linalg.LinearOperatorCirculant2D(spectrum)
+
+ def test_invalid_rank_raises(self):
+ spectrum = array_ops.constant(np.float32(rng.rand(2)))
+ with self.assertRaisesRegexp(ValueError, "must have at least 2 dimensions"):
+ linalg.LinearOperatorCirculant2D(spectrum)
+
+
+class LinearOperatorCirculant3DTest(test.TestCase):
+ """Simple test of the 3D case. See also the 1D and 2D tests."""
+
+ @contextlib.contextmanager
+ def test_session(self, *args, **kwargs):
+ with test.TestCase.test_session(self, *args, **kwargs) as sess:
+ with spectral_ops_test_util.fft_kernel_label_map():
+ yield sess
+
+ def test_real_spectrum_gives_self_adjoint_operator(self):
+ with self.test_session() as sess:
+ # A real spectrum gives a self-adjoint operator.
+ spectrum = linear_operator_test_util.random_normal(
+ shape=(2, 2, 3, 5), dtype=dtypes.float32)
+ operator = linalg.LinearOperatorCirculant3D(spectrum)
+ self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)
+
+ matrix_tensor = operator.to_dense()
+ self.assertEqual(matrix_tensor.dtype,
+ linear_operator_circulant._DTYPE_COMPLEX)
+ matrix_h = linalg.adjoint(matrix_tensor)
+
+ matrix, matrix_h = sess.run([matrix_tensor, matrix_h])
+ self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
+ self.assertAllClose(matrix, matrix_h)
+
+ def test_defining_operator_using_real_convolution_kernel(self):
+ with self.test_session():
+ convolution_kernel = linear_operator_test_util.random_normal(
+ shape=(2, 2, 3, 5), dtype=dtypes.float32)
+ # Convolution kernel is real ==> spectrum is Hermitian.
+ spectrum = math_ops.fft3d(
+ math_ops.cast(convolution_kernel, dtypes.complex64))
+
+ # spectrum is Hermitian ==> operator is real.
+ operator = linalg.LinearOperatorCirculant3D(spectrum)
+ self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), operator.shape)
+
+ # Allow for complex output so we can make sure it has zero imag part.
+ self.assertEqual(operator.dtype, dtypes.complex64)
+ matrix = operator.to_dense().eval()
+ self.assertAllEqual((2, 2 * 3 * 5, 2 * 3 * 5), matrix.shape)
+ np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
+
+ def test_defining_spd_operator_by_taking_real_part(self):
+ with self.test_session() as sess:
+ # S is real and positive.
+ s = linear_operator_test_util.random_uniform(
+ shape=(10, 2, 3, 4), dtype=dtypes.float32, minval=1., maxval=2.)
+
+ # Let S = S1 + S2, the Hermitian and anti-hermitian parts.
+ # S1 = 0.5 * (S + S^H), S2 = 0.5 * (S - S^H),
+ # where ^H is the Hermitian transpose of the function:
+ # f(n0, n1, n2)^H := ComplexConjugate[f(N0-n0, N1-n1, N2-n2)].
+ # We want to isolate S1, since
+ # S1 is Hermitian by construction
+ # S1 is real since S is
+ # S1 is positive since it is the sum of two positive kernels
+
+ # IDFT[S] = IDFT[S1] + IDFT[S2]
+ # = H1 + H2
+ # where H1 is real since S1 is Hermitian,
+ # and H2 is imaginary since S2 is anti-Hermitian.
+ ifft_s = math_ops.ifft3d(math_ops.cast(s, dtypes.complex64))
+
+ # Throw away H2, keep H1.
+ real_ifft_s = math_ops.real(ifft_s)
+
+ # This is the perfect spectrum!
+ # spectrum = DFT[H1]
+ # = S1,
+ fft_real_ifft_s = math_ops.fft3d(
+ math_ops.cast(real_ifft_s, dtypes.complex64))
+
+ # S1 is Hermitian ==> operator is real.
+ # S1 is real ==> operator is self-adjoint.
+ # S1 is positive ==> operator is positive-definite.
+ operator = linalg.LinearOperatorCirculant3D(fft_real_ifft_s)
+
+ # Allow for complex output so we can check operator has zero imag part.
+ self.assertEqual(operator.dtype, dtypes.complex64)
+ matrix, matrix_t = sess.run([
+ operator.to_dense(),
+ array_ops.matrix_transpose(operator.to_dense())
+ ])
+ operator.assert_positive_definite().run() # Should not fail.
+ np.testing.assert_allclose(0, np.imag(matrix), atol=1e-6)
+ self.assertAllClose(matrix, matrix_t)
+
+ # Just to test the theory, get S2 as well.
+ # This should create an imaginary operator.
+ # S2 is anti-Hermitian ==> operator is imaginary.
+ # S2 is real ==> operator is self-adjoint.
+ imag_ifft_s = math_ops.imag(ifft_s)
+ fft_imag_ifft_s = math_ops.fft3d(
+ 1j * math_ops.cast(imag_ifft_s, dtypes.complex64))
+ operator_imag = linalg.LinearOperatorCirculant3D(fft_imag_ifft_s)
+
+ matrix, matrix_h = sess.run([
+ operator_imag.to_dense(),
+ array_ops.matrix_transpose(math_ops.conj(operator_imag.to_dense()))
+ ])
+ self.assertAllClose(matrix, matrix_h)
+ np.testing.assert_allclose(0, np.real(matrix), atol=1e-7)
+
+
+if __name__ == "__main__":
+ test.main()
diff --git a/tensorflow/python/layers/layers.py b/tensorflow/python/layers/layers.py
index 13a8e8e39c..c5fa0d3aba 100644
--- a/tensorflow/python/layers/layers.py
+++ b/tensorflow/python/layers/layers.py
@@ -61,8 +61,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.util.all_util import remove_undocumented
-
# pylint: disable=g-bad-import-order,unused-import
# Base objects.
@@ -122,7 +120,3 @@ from tensorflow.python.layers.normalization import BatchNormalization
from tensorflow.python.layers.normalization import batch_normalization
# pylint: enable=g-bad-import-order,unused-import
-
-_allowed_symbols = []
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/lib/io/python_io.py b/tensorflow/python/lib/io/python_io.py
index b92cfe8f80..d4bc8afd1e 100644
--- a/tensorflow/python/lib/io/python_io.py
+++ b/tensorflow/python/lib/io/python_io.py
@@ -31,8 +31,3 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.lib.io.tf_record import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = []
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/array_ops.py b/tensorflow/python/ops/array_ops.py
index bbffff0483..586eaa4d5e 100644
--- a/tensorflow/python/ops/array_ops.py
+++ b/tensorflow/python/ops/array_ops.py
@@ -935,9 +935,9 @@ def stack(values, axis=0, name="stack"):
except (TypeError, ValueError):
pass # Input list contains non-constant tensors
- value_shape = ops.convert_to_tensor(values[0], name=name).get_shape()
- if value_shape.ndims is not None:
- expanded_num_dims = value_shape.ndims + 1
+ value_shape = ops.convert_to_tensor(values[0], name=name)._shape_tuple() # pylint: disable=protected-access
+ if value_shape is not None:
+ expanded_num_dims = len(value_shape) + 1
if axis < -expanded_num_dims or axis >= expanded_num_dims:
raise ValueError("axis = %d not in [%d, %d)" % (axis, -expanded_num_dims,
expanded_num_dims))
diff --git a/tensorflow/python/ops/bitwise_ops.py b/tensorflow/python/ops/bitwise_ops.py
index e8e187e68f..123380cf04 100644
--- a/tensorflow/python/ops/bitwise_ops.py
+++ b/tensorflow/python/ops/bitwise_ops.py
@@ -32,7 +32,6 @@ from tensorflow.python.framework import ops
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_bitwise_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
ops.NotDifferentiable("BitwiseAnd")
ops.NotDifferentiable("BitwiseOr")
@@ -41,5 +40,3 @@ ops.NotDifferentiable("Invert")
ops.NotDifferentiable("PopulationCount")
ops.NotDifferentiable("LeftShift")
ops.NotDifferentiable("RightShift")
-
-remove_undocumented(__name__)
diff --git a/tensorflow/python/ops/control_flow_ops.py b/tensorflow/python/ops/control_flow_ops.py
index a1bfe450c8..f1e068d514 100644
--- a/tensorflow/python/ops/control_flow_ops.py
+++ b/tensorflow/python/ops/control_flow_ops.py
@@ -622,14 +622,16 @@ def _EnforceShapeInvariant(merge_var, next_var):
m_shape = merge_var.get_shape()
n_shape = next_var.get_shape()
if not _ShapeLessThanOrEqual(n_shape, m_shape):
- # TODO(skyewm): get original loop input that caused the shape error and
- # report its name instead of the merge node's.
+ enter = merge_var.op.inputs[0].op
+ assert util.IsLoopEnter(enter)
+ input_t = enter.inputs[0]
+ assert input_t.shape == m_shape
raise ValueError(
- "The shape for %s is not an invariant for the loop. It enters "
- "the loop with shape %s, but has shape %s after one iteration. "
- "Provide shape invariants using either the `shape_invariants` "
- "argument of tf.while_loop or set_shape() on the loop variables." %
- (merge_var.name, m_shape, n_shape))
+ "Input tensor '%s' enters the loop with shape %s, but has shape %s "
+ "after one iteration. To allow the shape to vary across iterations, "
+ "use the `shape_invariants` argument of tf.while_loop to specify a "
+ "less-specific shape." %
+ (input_t.name, input_t.shape, n_shape))
else:
if not isinstance(var, (ops.IndexedSlices, sparse_tensor.SparseTensor)):
raise TypeError("Type %s not supported" % type(var))
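The rewritten error message points users at `shape_invariants`; the fix it suggests looks roughly like the sketch below (illustrative loop variables, not part of this change):

```python
import tensorflow as tf

i0 = tf.constant(0)
x0 = tf.zeros([0], dtype=tf.float32)  # first dimension grows every iteration

_, x = tf.while_loop(
    cond=lambda i, x: i < 3,
    body=lambda i, x: (i + 1, tf.concat([x, [tf.cast(i, tf.float32)]], 0)),
    loop_vars=[i0, x0],
    # Declare that x's first dimension may vary across iterations.
    shape_invariants=[i0.get_shape(), tf.TensorShape([None])])
```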
diff --git a/tensorflow/python/ops/cudnn_rnn_grad.py b/tensorflow/python/ops/cudnn_rnn_grad.py
index 97331bb5b5..c618c470f2 100644
--- a/tensorflow/python/ops/cudnn_rnn_grad.py
+++ b/tensorflow/python/ops/cudnn_rnn_grad.py
@@ -26,7 +26,7 @@ def _cudnn_rnn_backward(op, *grads):
"""Gradients for the CudnnRNN op."""
if not op.get_attr("is_training"):
raise ValueError(
- "CudnnRNN must set is_training to True to be used in gradients")
+ "To use CudnnRNN in gradients, is_training must be set to True.")
return gen_cudnn_rnn_ops.cudnn_rnn_backprop(
input=op.inputs[0],
input_h=op.inputs[1],
@@ -45,3 +45,29 @@ def _cudnn_rnn_backward(op, *grads):
rnn_mode=op.get_attr("rnn_mode"),
input_mode=op.get_attr("input_mode"),
direction=op.get_attr("direction"))
+
+
+@ops.RegisterGradient("CudnnRNNV2")
+def _cudnn_rnn_backward_v2(op, *grad):
+ if not op.get_attr("is_training"):
+ raise ValueError(
+ "To use CudnnRNNV2 in gradients, is_training must be set to True.")
+ return gen_cudnn_rnn_ops.cudnn_rnn_backprop_v2(
+ input=op.inputs[0],
+ input_h=op.inputs[1],
+ input_c=op.inputs[2],
+ params=op.inputs[3],
+ output=op.outputs[0],
+ output_h=op.outputs[1],
+ output_c=op.outputs[2],
+ output_backprop=grad[0],
+ output_h_backprop=grad[1],
+ output_c_backprop=grad[2],
+ reserve_space=op.outputs[3],
+ host_reserved=op.outputs[4],
+ dropout=op.get_attr("dropout"),
+ seed=op.get_attr("seed"),
+ seed2=op.get_attr("seed2"),
+ rnn_mode=op.get_attr("rnn_mode"),
+ input_mode=op.get_attr("input_mode"),
+ direction=op.get_attr("direction"))
diff --git a/tensorflow/python/ops/distributions/bijector.py b/tensorflow/python/ops/distributions/bijector.py
index 84bd0a20da..94a77a205a 100644
--- a/tensorflow/python/ops/distributions/bijector.py
+++ b/tensorflow/python/ops/distributions/bijector.py
@@ -23,8 +23,3 @@ from __future__ import print_function
from tensorflow.python.ops.distributions.bijector_impl import Bijector
# pylint: enable=wildcard-import,unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = ["Bijector"]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/distributions/distributions.py b/tensorflow/python/ops/distributions/distributions.py
index 7c4b8697d8..59ed455e43 100644
--- a/tensorflow/python/ops/distributions/distributions.py
+++ b/tensorflow/python/ops/distributions/distributions.py
@@ -35,29 +35,3 @@ from tensorflow.python.ops.distributions.student_t import StudentT
from tensorflow.python.ops.distributions.uniform import Uniform
# pylint: enable=wildcard-import,unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-
-_allowed_symbols = [
- "Bernoulli",
- "Beta",
- "Categorical",
- "DirichletMultinomial",
- "Dirichlet",
- "Distribution",
- "ReparameterizationType",
- "FULLY_REPARAMETERIZED",
- "NOT_REPARAMETERIZED",
- "Exponential",
- "Gamma",
- "RegisterKL",
- "kl_divergence",
- "Laplace",
- "Multinomial",
- "Normal",
- "StudentT",
- "Uniform",
-]
-
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/distributions/transformed_distribution.py b/tensorflow/python/ops/distributions/transformed_distribution.py
index 6aa6ec40d9..bc321900dc 100644
--- a/tensorflow/python/ops/distributions/transformed_distribution.py
+++ b/tensorflow/python/ops/distributions/transformed_distribution.py
@@ -19,8 +19,6 @@ from __future__ import print_function
import numpy as np
-# Bijectors must be directly imported because `remove_undocumented` prevents
-# individual file imports.
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
diff --git a/tensorflow/python/ops/embedding_ops.py b/tensorflow/python/ops/embedding_ops.py
index 9e46739bc1..6f2a34c731 100644
--- a/tensorflow/python/ops/embedding_ops.py
+++ b/tensorflow/python/ops/embedding_ops.py
@@ -331,8 +331,8 @@ def embedding_lookup_sparse(params,
representing sharded embedding tensors. Alternatively, a
`PartitionedVariable`, created by partitioning along dimension 0. Each
element must be appropriately sized for the given `partition_strategy`.
- sp_ids: N x M `SparseTensor` of int64 ids (typically from FeatureValueToId),
- where N is typically batch size and M is arbitrary.
+ sp_ids: N x M `SparseTensor` of int64 ids where N is typically batch size
+ and M is arbitrary.
sp_weights: either a `SparseTensor` of float / double weights, or `None` to
indicate all weights should be taken to be 1. If specified, `sp_weights`
must have exactly the same shape and indices as `sp_ids`.
diff --git a/tensorflow/python/ops/functional_ops.py b/tensorflow/python/ops/functional_ops.py
index 161f6f3659..765a2ef993 100644
--- a/tensorflow/python/ops/functional_ops.py
+++ b/tensorflow/python/ops/functional_ops.py
@@ -65,10 +65,20 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
+ This method also allows multi-arity `elems` and output of `fn`. If `elems`
+ is a (possibly nested) list or tuple of tensors, then each of these tensors
+ must have a matching first (unpack) dimension. The signature of `fn` may
+ match the structure of `elems`. That is, if `elems` is
+ `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
+ `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
+
Args:
fn: The callable to be performed.
- elems: A tensor to be unpacked on dimension 0.
- initializer: (optional) The initial value for the accumulator.
+ elems: A tensor or (possibly nested) sequence of tensors, each of which
+ will be unpacked along its first dimension. The nested sequence
+ of the resulting slices will be the first argument to `fn`.
+ initializer: (optional) A tensor or (possibly nested) sequence of tensors,
+ as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables support for back propagation.
@@ -76,8 +86,9 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
name: (optional) Name prefix for the returned tensors.
Returns:
- A tensor resulting from applying `fn` consecutively to the list of tensors
- unpacked from `elems`, from first to last.
+ A tensor or (possibly nested) sequence of tensors, resulting from applying
+ `fn` consecutively to the list of tensors unpacked from `elems`, from first
+ to last.
Raises:
TypeError: if `fn` is not callable.
@@ -92,6 +103,11 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
+ def create_ta(elem):
+ return tensor_array_ops.TensorArray(
+ dtype=elem.dtype, size=n, dynamic_size=False,
+ infer_shape=True).unstack(elem)
+
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldl", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
@@ -107,24 +123,26 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
varscope.set_caching_device(lambda op: op.device)
varscope_caching_device_was_none = True
- # Convert elems to tensor array.
- elems = ops.convert_to_tensor(elems, name="elems")
- n = array_ops.shape(elems)[0]
- elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
- dynamic_size=False,
- infer_shape=True)
- elems_ta = elems_ta.unstack(elems)
+ # Convert elems to tensor array. n may be known statically.
+ elems_flat = [
+ ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
+ ]
+ n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
+
+ elems_ta = nest.map_structure(create_ta, elems)
if initializer is None:
- a = elems_ta.read(0)
+ a = nest.map_structure(lambda elem: elem.read(0), elems_ta)
i = constant_op.constant(1)
else:
- a = ops.convert_to_tensor(initializer)
+ a = initializer
i = constant_op.constant(0)
def compute(i, a):
- a = fn(a, elems_ta.read(i))
+ elem_i = nest.map_structure(lambda elem: elem.read(i), elems_ta)
+ a = fn(a, elem_i)
return [i + 1, a]
+
_, r_a = control_flow_ops.while_loop(
lambda i, a: i < n, compute, [i, a],
parallel_iterations=parallel_iterations,
@@ -135,6 +153,7 @@ def foldl(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
# supported in Eager
if in_graph_mode and varscope_caching_device_was_none:
varscope.set_caching_device(None)
+
return r_a
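A minimal sketch of the multi-arity `elems` support documented above, assuming graph mode (TF 1.x style); the tensors and the reducing lambda are illustrative, not part of the change:

```python
import tensorflow as tf

xs = tf.constant([1., 2., 3.])
ys = tf.constant([10., 20., 30.])

# elems is a tuple of tensors sharing the first (unpack) dimension; the
# accumulator is a scalar, so fn maps (scalar, (x_i, y_i)) -> scalar.
total = tf.foldl(lambda acc, elem: acc + elem[0] * elem[1],
                 elems=(xs, ys),
                 initializer=tf.constant(0.))

with tf.Session() as sess:
  print(sess.run(total))  # 1*10 + 2*20 + 3*30 = 140.0
```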
@@ -153,10 +172,20 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
Suppose that `elems` is unpacked into `values`, a list of tensors. The shape
of the result tensor is `fn(initializer, values[0]).shape`.
+ This method also allows multi-arity `elems` and output of `fn`. If `elems`
+ is a (possibly nested) list or tuple of tensors, then each of these tensors
+ must have a matching first (unpack) dimension. The signature of `fn` may
+ match the structure of `elems`. That is, if `elems` is
+ `(t1, [t2, t3, [t4, t5]])`, then an appropriate signature for `fn` is:
+ `fn = lambda (t1, [t2, t3, [t4, t5]]):`.
+
Args:
fn: The callable to be performed.
- elems: A tensor that is unpacked into a sequence of tensors to apply `fn`.
- initializer: (optional) The initial value for the accumulator.
+ elems: A tensor or (possibly nested) sequence of tensors, each of which
+ will be unpacked along its first dimension. The nested sequence
+ of the resulting slices will be the first argument to `fn`.
+ initializer: (optional) A tensor or (possibly nested) sequence of tensors,
+ as the initial value for the accumulator.
parallel_iterations: (optional) The number of iterations allowed to run
in parallel.
back_prop: (optional) True enables support for back propagation.
@@ -164,8 +193,9 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
name: (optional) Name prefix for the returned tensors.
Returns:
- A tensor resulting from applying `fn` consecutively to the list of tensors
- unpacked from `elems`, from last to first.
+ A tensor or (possibly nested) sequence of tensors, resulting from applying
+ `fn` consecutively to the list of tensors unpacked from `elems`, from last
+ to first.
Raises:
TypeError: if `fn` is not callable.
@@ -180,6 +210,11 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
if not callable(fn):
raise TypeError("fn must be callable.")
+ def create_ta(elem):
+ return tensor_array_ops.TensorArray(
+ dtype=elem.dtype, size=n, dynamic_size=False,
+ infer_shape=True).unstack(elem)
+
in_graph_mode = not context.executing_eagerly()
with ops.name_scope(name, "foldr", [elems]):
# TODO(akshayka): Remove the in_graph_mode check once caching devices are
@@ -195,26 +230,30 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
varscope.set_caching_device(lambda op: op.device)
varscope_caching_device_was_none = True
- # Convert elems to tensor array.
- elems = ops.convert_to_tensor(elems, name="elems")
- n = array_ops.shape(elems)[0]
- elems_ta = tensor_array_ops.TensorArray(dtype=elems.dtype, size=n,
- dynamic_size=False,
- infer_shape=True)
- elems_ta = elems_ta.unstack(elems)
+ # Convert elems to tensor array. n may be known statically.
+ elems_flat = [
+ ops.convert_to_tensor(elem, name="elem") for elem in nest.flatten(elems)
+ ]
+ n = elems_flat[0].shape[0].value or array_ops.shape(elems_flat[0])[0]
+
+ elems_ta = nest.map_structure(create_ta, elems)
if initializer is None:
i = n - 1
- a = elems_ta.read(i)
+ a = nest.map_structure(lambda elem: elem.read(i), elems_ta)
else:
i = n
- a = ops.convert_to_tensor(initializer)
+ a = initializer
+
def compute(i, a):
i -= 1
- a = fn(a, elems_ta.read(i))
- return [i, a]
+ elem = nest.map_structure(lambda elem: elem.read(i), elems_ta)
+ a_out = fn(a, elem)
+ return [i, a_out]
+
_, r_a = control_flow_ops.while_loop(
- lambda i, a: i > 0, compute, [i, a],
+ lambda i, a: i > 0,
+ compute, [i, a],
parallel_iterations=parallel_iterations,
back_prop=back_prop,
swap_memory=swap_memory)
@@ -223,6 +262,7 @@ def foldr(fn, elems, initializer=None, parallel_iterations=10, back_prop=True,
# supported in Eager
if in_graph_mode and varscope_caching_device_was_none:
varscope.set_caching_device(None)
+
return r_a
@@ -887,6 +927,9 @@ def For(start,
output_attr.list.i.extend(hostmem)
ret[0].op._set_attr("_output_hostmem", output_attr) # pylint: disable=protected-access
return ret
+# pylint: enable=invalid-name,protected-access
-# pylint: enable=invalid-name,protected-access
+def partitioned_call(args, f):
+ return gen_functional_ops.partitioned_call(
+ args=args, Tout=[o.type for o in f.definition.signature.output_arg], f=f)
diff --git a/tensorflow/python/ops/gradients.py b/tensorflow/python/ops/gradients.py
index 2668e8f60c..9fa8e27d5c 100644
--- a/tensorflow/python/ops/gradients.py
+++ b/tensorflow/python/ops/gradients.py
@@ -25,14 +25,4 @@ from tensorflow.python.ops.gradients_impl import AggregationMethod
from tensorflow.python.ops.gradients_impl import gradients
from tensorflow.python.ops.gradients_impl import hessians
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-_allowed_symbols = [
- # TODO(drpng): find a good place to reference this.
- "AggregationMethod",
- "GradientTape",
- "custom_gradient",
- "gradients", # tf.gradients.gradients.
- "hessians", # tf.gradients.hessians
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/image_ops.py b/tensorflow/python/ops/image_ops.py
index 68be9ccdd6..3d40c39181 100644
--- a/tensorflow/python/ops/image_ops.py
+++ b/tensorflow/python/ops/image_ops.py
@@ -91,13 +91,3 @@ from tensorflow.python.ops.image_ops_impl import *
from tensorflow.python.ops.image_ops_impl import _Check3DImage
from tensorflow.python.ops.image_ops_impl import _ImageDimensions
# pylint: enable=unused-import
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- # ResizeMethod is not documented, but is documented in functions
- # that use it.
- 'ResizeMethod',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/linalg/linalg.py b/tensorflow/python/ops/linalg/linalg.py
index 14319025ff..d73c21cdc0 100644
--- a/tensorflow/python/ops/linalg/linalg.py
+++ b/tensorflow/python/ops/linalg/linalg.py
@@ -22,6 +22,7 @@ from __future__ import print_function
# pylint: disable=wildcard-import,unused-import
from tensorflow.python.ops.linalg.linalg_impl import *
from tensorflow.python.ops.linalg.linear_operator import *
+from tensorflow.python.ops.linalg.linear_operator_circulant import *
from tensorflow.python.ops.linalg.linear_operator_composition import *
from tensorflow.python.ops.linalg.linear_operator_diag import *
from tensorflow.python.ops.linalg.linear_operator_full_matrix import *
diff --git a/tensorflow/python/ops/linalg/linear_operator_circulant.py b/tensorflow/python/ops/linalg/linear_operator_circulant.py
new file mode 100644
index 0000000000..c367ed25ad
--- /dev/null
+++ b/tensorflow/python/ops/linalg/linear_operator_circulant.py
@@ -0,0 +1,1074 @@
+# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""`LinearOperator` coming from a [[nested] block] circulant matrix."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops.distributions import util as distribution_util
+from tensorflow.python.ops.linalg import linalg_impl as linalg
+from tensorflow.python.ops.linalg import linear_operator
+from tensorflow.python.ops.linalg import linear_operator_util
+from tensorflow.python.util.tf_export import tf_export
+
+__all__ = [
+ "LinearOperatorCirculant",
+ "LinearOperatorCirculant2D",
+ "LinearOperatorCirculant3D",
+]
+
+# Different FFT Ops will be used for different block depths.
+_FFT_OP = {1: math_ops.fft, 2: math_ops.fft2d, 3: math_ops.fft3d}
+_IFFT_OP = {1: math_ops.ifft, 2: math_ops.ifft2d, 3: math_ops.ifft3d}
+
+# This is the only dtype allowed with fft ops.
+# TODO(langmore) Add other types once available.
+_DTYPE_COMPLEX = dtypes.complex64
+
+
+# TODO(langmore) Add transformations that create common spectrums, e.g.
+# starting with the convolution kernel
+# start with half a spectrum, and create a Hermitian one.
+# common filters.
+# TODO(langmore) Support rectangular Toeplitz matrices.
+class _BaseLinearOperatorCirculant(linear_operator.LinearOperator):
+ """Base class for circulant operators. Not user facing.
+
+ `LinearOperator` acting like a [batch] [[nested] block] circulant matrix.
+ """
+
+ def __init__(self,
+ spectrum,
+ block_depth,
+ input_output_dtype=_DTYPE_COMPLEX,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=True,
+ name="LinearOperatorCirculant"):
+ r"""Initialize an `_BaseLinearOperatorCirculant`.
+
+ Args:
+ spectrum: Shape `[B1,...,Bb, N]` `Tensor`. Allowed dtypes are
+ `float32`, `complex64`. The type can be different from `input_output_dtype`.
+ block_depth: Python integer, either 1, 2, or 3. Will be 1 for circulant,
+ 2 for block circulant, and 3 for nested block circulant.
+ input_output_dtype: `dtype` for input/output. Must be either
+ `float32` or `complex64`.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. If `spectrum` is real, this will always be true.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the quadratic form `x^H A x` has positive real part for all
+ nonzero `x`. Note that we do not require the operator to be
+ self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix\
+ #Extension_for_non_symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name to prepend to all ops created by this class.
+
+ Raises:
+ ValueError: If `block_depth` is not an allowed value.
+ TypeError: If `spectrum` is not an allowed type.
+ """
+
+ allowed_block_depths = [1, 2, 3]
+
+ self._name = name
+
+ if block_depth not in allowed_block_depths:
+ raise ValueError("Expected block_depth to be in %s. Found: %s." %
+ (allowed_block_depths, block_depth))
+ self._block_depth = block_depth
+
+ with ops.name_scope(name, values=[spectrum]):
+ self._spectrum = self._check_spectrum_and_return_tensor(spectrum)
+
+ # Check and auto-set hints.
+ if not self.spectrum.dtype.is_complex:
+ if is_self_adjoint is False:
+ raise ValueError(
+ "A real spectrum always corresponds to a self-adjoint operator.")
+ is_self_adjoint = True
+
+ if is_square is False:
+ raise ValueError(
+ "A [[nested] block] circulant operator is always square.")
+ is_square = True
+
+ # If spectrum.shape = [s0, s1, s2], and block_depth = 2,
+ # block_shape = [s1, s2]
+ s_shape = array_ops.shape(self.spectrum)
+ self._block_shape_tensor = s_shape[-self.block_depth:]
+
+ # Add common variants of spectrum to the graph.
+ self._spectrum_complex = _to_complex(self.spectrum)
+ self._abs_spectrum = math_ops.abs(self.spectrum)
+ self._conj_spectrum = math_ops.conj(self._spectrum_complex)
+
+ super(_BaseLinearOperatorCirculant, self).__init__(
+ dtype=dtypes.as_dtype(input_output_dtype),
+ graph_parents=[self.spectrum],
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+ def _check_spectrum_and_return_tensor(self, spectrum):
+ """Static check of spectrum. Then return `Tensor` version."""
+ spectrum = ops.convert_to_tensor(spectrum, name="spectrum")
+
+ allowed_dtypes = [dtypes.float32, dtypes.complex64]
+ if spectrum.dtype not in allowed_dtypes:
+ raise TypeError("Argument spectrum must have dtype in %s. Found: %s" %
+ (allowed_dtypes, spectrum.dtype))
+ if spectrum.get_shape().ndims is not None:
+ if spectrum.get_shape().ndims < self.block_depth:
+ raise ValueError(
+ "Argument spectrum must have at least %d dimensions. Found: %s" %
+ (self.block_depth, spectrum))
+ return spectrum
+
+ @property
+ def block_depth(self):
+ """Depth of recursively defined circulant blocks defining this `Operator`.
+
+ With `A` the dense representation of this `Operator`,
+
+ `block_depth = 1` means `A` is symmetric circulant. For example,
+
+ ```
+ A = |x y z y|
+ |y x y z|
+ |z y x y|
+ |y z y x|
+ ```
+
+ `block_depth = 2` means `A` is block symmetric circulant with symmetric
+ circulant blocks. For example, with `X`, `Y`, `Z` symmetric circulant,
+
+ ```
+ A = |X Y Z Y|
+ |Y X Y Z|
+ |Z Y X Y|
+ |Y Z Y X|
+ ```
+
+ `block_depth = 3` means `A` is block symmetric circulant with block
+ symmetric circulant blocks.
+
+ Returns:
+ Python `integer`.
+ """
+ return self._block_depth
+
+ def block_shape_tensor(self):
+ """Shape of the block dimensions of `self.spectrum`."""
+ return self._block_shape_tensor
+
+ @property
+ def block_shape(self):
+ return self.spectrum.get_shape()[-self.block_depth:]
+
+ @property
+ def spectrum(self):
+ return self._spectrum
+
+ def _vectorize_then_blockify(self, matrix):
+ """Shape batch matrix to batch vector, then blockify trailing dimensions."""
+ # Suppose
+ # matrix.shape = [m0, m1, m2, m3],
+ # and matrix is a matrix because the final two dimensions are matrix dims.
+ # self.block_depth = 2,
+ # self.block_shape = [b0, b1] (note b0 * b1 = m2).
+ # We will reshape matrix to
+ # [m3, m0, m1, b0, b1].
+
+ # Vectorize: Reshape to batch vector.
+ # [m0, m1, m2, m3] --> [m3, m0, m1, m2]
+ # This is called "vectorize" because we have taken the final two matrix dims
+ # and turned this into a size m3 batch of vectors.
+ vec = distribution_util.rotate_transpose(matrix, shift=1)
+
+ # Blockify: Reshape the trailing dimension into the block shape.
+ # [m3, m0, m1, m2] --> [m3, m0, m1, b0, b1]
+ if (vec.get_shape().is_fully_defined() and
+ self.block_shape.is_fully_defined()):
+ # vec_leading_shape = [m3, m0, m1],
+ # the parts of vec that will not be blockified.
+ vec_leading_shape = vec.get_shape()[:-1]
+ final_shape = vec_leading_shape.concatenate(self.block_shape)
+ else:
+ vec_leading_shape = array_ops.shape(vec)[:-1]
+ final_shape = array_ops.concat(
+ (vec_leading_shape, self.block_shape_tensor()), 0)
+ return array_ops.reshape(vec, final_shape)
+
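The reshape described in the comments can be traced with a quick NumPy sketch; the shapes are illustrative, and `np.moveaxis` plays the role of `rotate_transpose(shift=1)` here:

```python
import numpy as np

m0, m1, m2, m3 = 2, 3, 6, 5
b0, b1 = 2, 3                                 # block_shape, with b0 * b1 == m2
matrix = np.zeros((m0, m1, m2, m3))

vec = np.moveaxis(matrix, -1, 0)              # [m0, m1, m2, m3] -> [m3, m0, m1, m2]
blocked = vec.reshape(m3, m0, m1, b0, b1)     # blockify the trailing dimension
assert blocked.shape == (5, 2, 3, 2, 3)
```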
+ def _unblockify_then_matricize(self, vec):
+ """Flatten the block dimensions then reshape to a batch matrix."""
+ # Suppose
+ # vec.shape = [v0, v1, v2, v3],
+ # self.block_depth = 2.
+ # Then
+ # leading shape = [v0, v1]
+ # block shape = [v2, v3].
+ # We will reshape vec to
+ # [v1, v2*v3, v0].
+
+ # Un-blockify: Flatten block dimensions. Reshape
+ # [v0, v1, v2, v3] --> [v0, v1, v2*v3].
+ if vec.get_shape().is_fully_defined():
+ # vec_shape = [v0, v1, v2, v3]
+ vec_shape = vec.get_shape().as_list()
+ # vec_leading_shape = [v0, v1]
+ vec_leading_shape = vec_shape[:-self.block_depth]
+ # vec_block_shape = [v2, v3]
+ vec_block_shape = vec_shape[-self.block_depth:]
+ # flat_shape = [v0, v1, v2*v3]
+ flat_shape = vec_leading_shape + [np.prod(vec_block_shape)]
+ else:
+ vec_shape = array_ops.shape(vec)
+ vec_leading_shape = vec_shape[:-self.block_depth]
+ vec_block_shape = vec_shape[-self.block_depth:]
+ flat_shape = array_ops.concat(
+ (vec_leading_shape, [math_ops.reduce_prod(vec_block_shape)]), 0)
+ vec_flat = array_ops.reshape(vec, flat_shape)
+
+ # Matricize: Reshape to batch matrix.
+ # [v0, v1, v2*v3] --> [v1, v2*v3, v0],
+ # representing a shape [v1] batch of [v2*v3, v0] matrices.
+ matrix = distribution_util.rotate_transpose(vec_flat, shift=-1)
+ return matrix
+
+ def _fft(self, x):
+ """FFT along the last self.block_depth dimensions of x.
+
+ Args:
+ x: `Tensor` with floating or complex `dtype`.
+ Should be in the form returned by self._vectorize_then_blockify.
+
+ Returns:
+ `Tensor` with `dtype` `complex64`.
+ """
+ x_complex = _to_complex(x)
+ return _FFT_OP[self.block_depth](x_complex)
+
+ def _ifft(self, x):
+ """IFFT along the last self.block_depth dimensions of x.
+
+ Args:
+ x: `Tensor` with floating or complex dtype. Should be in the form
+ returned by self._vectorize_then_blockify.
+
+ Returns:
+ `Tensor` with `dtype` `complex64`.
+ """
+ x_complex = _to_complex(x)
+ return _IFFT_OP[self.block_depth](x_complex)
+
+ def convolution_kernel(self, name="convolution_kernel"):
+ """Convolution kernel corresponding to `self.spectrum`.
+
+ The `D` dimensional DFT of this kernel is the frequency domain spectrum of
+ this operator.
+
+ Args:
+ name: A name to give this `Op`.
+
+ Returns:
+ `Tensor` with `dtype` `self.dtype`.
+ """
+ with self._name_scope(name):
+ h = self._ifft(self._spectrum_complex)
+ return math_ops.cast(h, self.dtype)
+
+ def _shape(self):
+ s_shape = self._spectrum.get_shape()
+ # Suppose spectrum.shape = [a, b, c, d]
+ # block_depth = 2
+ # Then:
+ # batch_shape = [a, b]
+ # N = c*d
+ # and we want to return
+ # [a, b, c*d, c*d]
+ batch_shape = s_shape[:-self.block_depth]
+ # trailing_dims = [c, d]
+ trailing_dims = s_shape[-self.block_depth:]
+ if trailing_dims.is_fully_defined():
+ n = np.prod(trailing_dims.as_list())
+ else:
+ n = None
+ n_x_n = tensor_shape.TensorShape([n, n])
+ return batch_shape.concatenate(n_x_n)
+
+ def _shape_tensor(self):
+ # See self.shape for explanation of steps
+ s_shape = array_ops.shape(self._spectrum)
+ batch_shape = s_shape[:-self.block_depth]
+ trailing_dims = s_shape[-self.block_depth:]
+ n = math_ops.reduce_prod(trailing_dims)
+ n_x_n = [n, n]
+ return array_ops.concat((batch_shape, n_x_n), 0)
+
+ def assert_hermitian_spectrum(self, name="assert_hermitian_spectrum"):
+ """Returns an `Op` that asserts this operator has Hermitian spectrum.
+
+ This operator corresponds to a real-valued matrix if and only if its
+ spectrum is Hermitian.
+
+ Args:
+ name: A name to give this `Op`.
+
+ Returns:
+ An `Op` that asserts this operator has Hermitian spectrum.
+ """
+ eps = np.finfo(self.dtype.real_dtype.as_numpy_dtype).eps
+ with self._name_scope(name):
+ # Assume linear accumulation of error.
+ max_err = eps * self.domain_dimension_tensor()
+ imag_convolution_kernel = math_ops.imag(self.convolution_kernel())
+ return check_ops.assert_less(
+ math_ops.abs(imag_convolution_kernel),
+ max_err,
+ message="Spectrum was not Hermitian")
+
+ def _assert_non_singular(self):
+ return linear_operator_util.assert_no_entries_with_modulus_zero(
+ self.spectrum,
+ message="Singular operator: Spectrum contained zero values.")
+
+ def _assert_positive_definite(self):
+ # This operator has the action Ax = F^H D F x,
+ # where D is the diagonal matrix with self.spectrum on the diag. Therefore,
+ # <x, Ax> = <Fx, D Fx>.
+ # Since F is bijective, the condition for positive definiteness is the same as
+ # for a diagonal matrix, i.e. the real part of the spectrum is positive.
+ message = (
+ "Not positive definite: Real part of spectrum was not all positive.")
+ return check_ops.assert_positive(
+ math_ops.real(self.spectrum), message=message)
+
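The reasoning in `_assert_positive_definite` can be checked numerically with an explicit unnormalized DFT matrix; the spectrum below, whose real part is positive, is only an illustrative value:

```python
import numpy as np

spectrum = np.array([6. + 1j, 4. - 2j, 2. + 0.5j])   # Re(spectrum) > 0
n = spectrum.size
f = np.fft.fft(np.eye(n))                            # F @ x == fft(x)
a = f.conj().T @ np.diag(spectrum) @ f / n           # A = F^{-1} diag(spectrum) F

x = np.random.rand(n) + 1j * np.random.rand(n)
# <x, Ax> = <Fx, diag(spectrum) Fx> / n, so its real part is positive.
assert (x.conj() @ (a @ x)).real > 0
```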
+ def _assert_self_adjoint(self):
+ # Recall the correspondence between a real spectrum and self-adjointness;
+ # see the class docstring.
+ return linear_operator_util.assert_zero_imag_part(
+ self.spectrum,
+ message=(
+ "Not self-adjoint: The spectrum contained non-zero imaginary part."
+ ))
+
+ def _broadcast_batch_dims(self, x, spectrum):
+ """Broadcast batch dims of batch matrix `x` and spectrum."""
+ # spectrum.shape = batch_shape + block_shape
+ # First make spectrum a batch matrix with
+ # spectrum.shape = batch_shape + [prod(block_shape), 1]
+ spec_mat = array_ops.reshape(
+ spectrum, array_ops.concat(
+ (self.batch_shape_tensor(), [-1, 1]), axis=0))
+ # Second, broadcast, possibly requiring an addition of array of zeros.
+ x, spec_mat = linear_operator_util.broadcast_matrix_batch_dims((x,
+ spec_mat))
+ # Third, put the block shape back into spectrum.
+ batch_shape = array_ops.shape(x)[:-2]
+ spectrum = array_ops.reshape(
+ spec_mat,
+ array_ops.concat((batch_shape, self.block_shape_tensor()), axis=0))
+
+ return x, spectrum
+
+ def _matmul(self, x, adjoint=False, adjoint_arg=False):
+ x = linalg.adjoint(x) if adjoint_arg else x
+ # With F the matrix of a DFT, and F^{-1}, F^H the inverse and Hermitian
+ # transpose, one can show that F^{-1} = F^{H} is the IDFT matrix. Therefore
+ # matmul(x) = F^{-1} diag(spectrum) F x,
+ # = F^{H} diag(spectrum) F x,
+ # so that
+ # matmul(x, adjoint=True) = F^{H} diag(conj(spectrum)) F x.
+ spectrum = self._conj_spectrum if adjoint else self._spectrum_complex
+
+ x, spectrum = self._broadcast_batch_dims(x, spectrum)
+
+ x_vb = self._vectorize_then_blockify(x)
+ fft_x_vb = self._fft(x_vb)
+ block_vector_result = self._ifft(spectrum * fft_x_vb)
+ y = self._unblockify_then_matricize(block_vector_result)
+
+ return math_ops.cast(y, self.dtype)
+
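The FFT-based matmul above agrees with multiplication by the dense circulant matrix `A_{mn} = h_{m-n mod N}`; a NumPy cross-check with an illustrative kernel, not part of this module:

```python
import numpy as np

h = np.array([4., 1., 0., 1.])                       # real convolution kernel
spectrum = np.fft.fft(h)
n = h.size
a = np.array([[h[(m - k) % n] for k in range(n)]     # A_{mk} = h_{(m-k) mod N}
              for m in range(n)])

u = np.random.rand(n)
np.testing.assert_allclose(a @ u,
                           np.fft.ifft(spectrum * np.fft.fft(u)).real)
```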
+ def _determinant(self):
+ reduction_indices = [-(i + 1) for i in range(self.block_depth)]
+ det = math_ops.reduce_prod(
+ self.spectrum, reduction_indices=reduction_indices)
+ return math_ops.cast(det, self.dtype)
+
+ def _log_abs_determinant(self):
+ reduction_indices = [-(i + 1) for i in range(self.block_depth)]
+ lad = math_ops.reduce_sum(
+ math_ops.log(self._abs_spectrum), reduction_indices=reduction_indices)
+ return math_ops.cast(lad, self.dtype)
+
+ def _solve(self, rhs, adjoint=False, adjoint_arg=False):
+ rhs = linalg.adjoint(rhs) if adjoint_arg else rhs
+ spectrum = self._conj_spectrum if adjoint else self._spectrum_complex
+
+ rhs, spectrum = self._broadcast_batch_dims(rhs, spectrum)
+
+ rhs_vb = self._vectorize_then_blockify(rhs)
+ fft_rhs_vb = self._fft(rhs_vb)
+ solution_vb = self._ifft(fft_rhs_vb / spectrum)
+ x = self._unblockify_then_matricize(solution_vb)
+ return math_ops.cast(x, self.dtype)
+
+ def _diag_part(self):
+ # Get ones in shape of diag, which is [B1,...,Bb, N]
+ # Also get the size of the diag, "N".
+ if self.shape.is_fully_defined():
+ diag_shape = self.shape[:-1]
+ diag_size = self.domain_dimension.value
+ else:
+ diag_shape = self.shape_tensor()[:-1]
+ diag_size = self.domain_dimension_tensor()
+ ones_diag = array_ops.ones(diag_shape, dtype=self.dtype)
+
+ # As proved in comments in self._trace, the value on the diag is constant,
+ # repeated N times. This value is the trace divided by N.
+
+ # The handling of self.shape = (0, 0) is tricky, and is the reason we choose
+ # to compute trace and use that to compute diag_part, rather than computing
+ # the value on the diagonal ("diag_value") directly. Both result in a 0/0,
+ # but in different places, and the current method gives the right result in
+ # the end.
+
+ # Here, if self.shape = (0, 0), then self.trace() = 0., and then
+ # diag_value = 0. / 0. = NaN.
+ diag_value = self.trace() / math_ops.cast(diag_size, self.dtype)
+
+ # If self.shape = (0, 0), then ones_diag = [] (empty tensor), and then
+ # the following line is NaN * [] = [], as needed.
+ return diag_value[..., array_ops.newaxis] * ones_diag
+
+ def _trace(self):
+ # The diagonal of the [[nested] block] circulant operator is the mean of
+ # the spectrum.
+ # Proof: For the [0,...,0] element, this follows from the IDFT formula.
+ # Then the result follows since all diagonal elements are the same.
+
+ # Therefore, the trace is the sum of the spectrum.
+
+ # Get shape of diag along with the axis over which to reduce the spectrum.
+ # We will reduce the spectrum over all block indices.
+ if self.spectrum.get_shape().is_fully_defined():
+ spec_rank = self.spectrum.get_shape().ndims
+ axis = np.arange(spec_rank - self.block_depth, spec_rank, dtype=np.int32)
+ else:
+ spec_rank = array_ops.rank(self.spectrum)
+ axis = math_ops.range(spec_rank - self.block_depth, spec_rank)
+
+ # Real diag part "re_d".
+ # Suppose spectrum.shape = [B1,...,Bb, N1, N2]
+ # self.shape = [B1,...,Bb, N, N], with N1 * N2 = N.
+ # re_d_value.shape = [B1,...,Bb]
+ re_d_value = math_ops.reduce_sum(math_ops.real(self.spectrum), axis=axis)
+
+ if not self.dtype.is_complex:
+ return math_ops.cast(re_d_value, self.dtype)
+
+ # Imaginary part, "im_d".
+ if self.is_self_adjoint:
+ im_d_value = 0.
+ else:
+ im_d_value = math_ops.reduce_sum(math_ops.imag(self.spectrum), axis=axis)
+
+ return math_ops.cast(math_ops.complex(re_d_value, im_d_value), self.dtype)
+
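The `_trace` argument (every diagonal entry equals the mean of the spectrum, so the trace is its sum) checks out numerically for a 1-D circulant; `h` is an illustrative kernel, not part of this module:

```python
import numpy as np

h = np.random.rand(5)
n = h.size
a = np.array([[h[(m - k) % n] for k in range(n)] for m in range(n)])
spectrum = np.fft.fft(h)

np.testing.assert_allclose(np.trace(a), spectrum.sum().real)     # trace == sum
np.testing.assert_allclose(np.diag(a), spectrum.sum().real / n)  # constant diag
```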
+
+@tf_export("linalg.LinearOperatorCirculant")
+class LinearOperatorCirculant(_BaseLinearOperatorCirculant):
+ """`LinearOperator` acting like a circulant matrix.
+
+ This operator acts like a circulant matrix `A` with
+ shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
+ batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, :, :]` is
+ an `N x N` matrix. This matrix `A` is not materialized, but for
+ purposes of broadcasting this shape will be relevant.
+
+ #### Description in terms of circulant matrices
+
+ Circulant means the entries of `A` are generated by a single vector, the
+ convolution kernel `h`: `A_{mn} := h_{m-n mod N}`. With `h = [w, x, y, z]`,
+
+ ```
+ A = |w z y x|
+ |x w z y|
+ |y x w z|
+ |z y x w|
+ ```
+
+ This means that the result of matrix multiplication `v = Au` has its `Lth`
+ column given by the circular convolution of `h` with the `Lth` column of `u`.
+
+ See http://ee.stanford.edu/~gray/toeplitz.pdf
+
+ #### Description in terms of the frequency spectrum
+
+ There is an equivalent description in terms of the [batch] spectrum `H` and
+ Fourier transforms. Here we consider `A.shape = [N, N]` and ignore batch
+ dimensions. Define the discrete Fourier transform (DFT) and its inverse by
+
+ ```
+ DFT[ h[n] ] = H[k] := sum_{n = 0}^{N - 1} h_n e^{-i 2pi k n / N}
+ IDFT[ H[k] ] = h[n] = N^{-1} sum_{k = 0}^{N - 1} H_k e^{i 2pi k n / N}
+ ```
+
+ From these definitions, we see that
+
+ ```
+ H[0] = sum_{n = 0}^{N - 1} h_n
+ H[1] = "the first positive frequency"
+ H[N - 1] = "the first negative frequency"
+ ```
+
+ Loosely speaking, with `*` element-wise multiplication, matrix multiplication
+ is equal to the action of a Fourier multiplier: `A u = IDFT[ H * DFT[u] ]`.
+ Precisely speaking, given `[N, R]` matrix `u`, let `DFT[u]` be the `[N, R]`
+ matrix with `rth` column equal to the DFT of the `rth` column of `u`.
+ Define the `IDFT` similarly.
+ Matrix multiplication may be expressed columnwise:
+
+ ```(A u)_r = IDFT[ H * (DFT[u])_r ]```
+
+ #### Operator properties deduced from the spectrum
+
+ Letting `u` be the `kth` Euclidean basis vector, and `U = IDFT[u]`,
+ the above formulas show that `A U = H_k * U`. We conclude that the elements
+ of `H` are the eigenvalues of this operator. Therefore
+
+ * This operator is positive definite if and only if `Real{H} > 0`.
+
+ A general property of Fourier transforms is the correspondence between
+ Hermitian functions and real valued transforms.
+
+ Suppose `H.shape = [B1,...,Bb, N]`. We say that `H` is a Hermitian spectrum
+ if, with `%` meaning modulus division,
+
+ ```H[..., n % N] = ComplexConjugate[ H[..., (-n) % N] ]```
+
+ * This operator corresponds to a real matrix if and only if `H` is Hermitian.
+ * This operator is self-adjoint if and only if `H` is real.
+
+ See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.
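+
+ For a one-dimensional spectrum, the Hermitian condition above can be checked
+ numerically with a short sketch like the following (NumPy, for illustration
+ only):
+
+ ```python
+ import numpy as np
+
+ def is_hermitian_spectrum(H, tol=1e-6):
+   """True if H[n] == conj(H[(-n) % N]) for all n."""
+   H = np.asarray(H)
+   return np.allclose(H, np.conj(np.roll(H[::-1], 1)), atol=tol)
+
+ is_hermitian_spectrum([1, 1j, -1j])   # True: corresponds to a real matrix
+ is_hermitian_spectrum([1, 1j, 1j])    # False
+ ```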
+
+ #### Example of a self-adjoint positive definite operator
+
+ ```python
+ # spectrum is real ==> operator is self-adjoint
+ # spectrum is positive ==> operator is positive definite
+ spectrum = [6., 4, 2]
+
+ operator = LinearOperatorCirculant(spectrum)
+
+ # IFFT[spectrum]
+ operator.convolution_kernel()
+ ==> [4 + 0j, 1 + 0.58j, 1 - 0.58j]
+
+ operator.to_dense()
+ ==> [[4 + 0.0j, 1 - 0.6j, 1 + 0.6j],
+ [1 + 0.6j, 4 + 0.0j, 1 - 0.6j],
+ [1 - 0.6j, 1 + 0.6j, 4 + 0.0j]]
+ ```
+
+ #### Example of defining in terms of a real convolution kernel
+
+ ```python
+ # convolution_kernel is real ==> spectrum is Hermitian.
+ convolution_kernel = [1., 2., 1.]
+ spectrum = tf.fft(tf.cast(convolution_kernel, tf.complex64))
+
+ # spectrum is Hermitian ==> operator is real.
+ # spectrum is shape [3] ==> operator is shape [3, 3]
+ # We force the input/output type to be real, which allows this to operate
+ # like a real matrix.
+ operator = LinearOperatorCirculant(spectrum, input_output_dtype=tf.float32)
+
+ operator.to_dense()
+ ==> [[ 1, 1, 2],
+ [ 2, 1, 1],
+ [ 1, 2, 1]]
+ ```
+
+ #### Example of Hermitian spectrum
+
+ ```python
+ # spectrum is shape [3] ==> operator is shape [3, 3]
+ # spectrum is Hermitian ==> operator is real.
+ spectrum = [1, 1j, -1j]
+
+ operator = LinearOperatorCirculant(spectrum)
+
+ operator.to_dense()
+ ==> [[ 0.33 + 0j, 0.91 + 0j, -0.24 + 0j],
+ [-0.24 + 0j, 0.33 + 0j, 0.91 + 0j],
+ [ 0.91 + 0j, -0.24 + 0j, 0.33 + 0j]]
+ ```
+
+ #### Example of forcing real `dtype` when spectrum is Hermitian
+
+ ```python
+ # spectrum is shape [4] ==> operator is shape [4, 4]
+ # spectrum is real ==> operator is self-adjoint
+ # spectrum is Hermitian ==> operator is real
+ # spectrum has positive real part ==> operator is positive-definite.
+ spectrum = [6., 4, 2, 4]
+
+ # Force the input dtype to be float32.
+ # Cast the output to float32. This is fine because the operator will be
+ # real due to Hermitian spectrum.
+ operator = LinearOperatorCirculant(spectrum, input_output_dtype=tf.float32)
+
+ operator.shape
+ ==> [4, 4]
+
+ operator.to_dense()
+ ==> [[4, 1, 0, 1],
+ [1, 4, 1, 0],
+ [0, 1, 4, 1],
+ [1, 0, 1, 4]]
+
+ # convolution_kernel = tf.ifft(spectrum)
+ operator.convolution_kernel()
+ ==> [4, 1, 0, 1]
+ ```
+
+ #### Performance
+
+ Suppose `operator` is a `LinearOperatorCirculant` of shape `[N, N]`,
+ and `x.shape = [N, R]`. Then
+
+ * `operator.matmul(x)` is `O(R*N*Log[N])`
+ * `operator.solve(x)` is `O(R*N*Log[N])`
+ * `operator.determinant()` involves a size `N` `reduce_prod`.
+
+ If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+ `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ #### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint, positive_definite, square`.
+ These have the following meaning:
+
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ spectrum,
+ input_output_dtype=_DTYPE_COMPLEX,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=True,
+ name="LinearOperatorCirculant"):
+ r"""Initialize an `LinearOperatorCirculant`.
+
+ This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
+ by providing `spectrum`, a `[B1,...,Bb, N]` `Tensor`.
+
+ If `input_output_dtype = DTYPE`:
+
+ * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
+ * Values returned by all methods, such as `matmul` or `determinant` will be
+ cast to `DTYPE`.
+
+ Note that if the spectrum is not Hermitian, then this operator corresponds
+ to a complex matrix with non-zero imaginary part. In this case, setting
+ `input_output_dtype` to a real type will forcibly cast the output to be
+ real, resulting in incorrect results!
+
+ If on the other hand the spectrum is Hermitian, then this operator
+ corresponds to a real-valued matrix, and setting `input_output_dtype` to
+ a real type is fine.
+
+ Args:
+ spectrum: Shape `[B1,...,Bb, N]` `Tensor`. Allowed dtypes are
+ `float32`, `complex64`. Type can be different from `input_output_dtype`.
+ input_output_dtype: `dtype` for input/output. Must be either
+ `float32` or `complex64`.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. If `spectrum` is real, this will always be true.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the quadratic form `x^H A x` has positive real part for all
+ nonzero `x`. Note that we do not require the operator to be
+ self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix\
+ #Extension_for_non_symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name to prepend to all ops created by this class.
+ """
+ super(LinearOperatorCirculant, self).__init__(
+ spectrum,
+ block_depth=1,
+ input_output_dtype=input_output_dtype,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+
+@tf_export("linalg.LinearOperatorCirculant2D")
+class LinearOperatorCirculant2D(_BaseLinearOperatorCirculant):
+ """`LinearOperator` acting like a block circulant matrix.
+
+ This operator acts like a block circulant matrix `A` with
+ shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
+ batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
+ an `N x N` matrix. This matrix `A` is not materialized, but for
+ purposes of broadcasting this shape will be relevant.
+
+ #### Description in terms of block circulant matrices
+
+ If `A` is block circulant, with block sizes `N0, N1` (`N0 * N1 = N`):
+ `A` has a block circulant structure, composed of `N0 x N0` blocks, with each
+ block an `N1 x N1` circulant matrix.
+
+ For example, with `W`, `X`, `Y`, `Z` each circulant,
+
+ ```
+ A = |W Z Y X|
+ |X W Z Y|
+ |Y X W Z|
+ |Z Y X W|
+ ```
+
+ Note that `A` itself will not in general be circulant.
+
+ #### Description in terms of the frequency spectrum
+
+ There is an equivalent description in terms of the [batch] spectrum `H` and
+ Fourier transforms. Here we consider `A.shape = [N, N]` and ignore batch
+ dimensions.
+
+ If `H.shape = [N0, N1]` (with `N0 * N1 = N`):
+ Loosely speaking, matrix multiplication is equal to the action of a
+ Fourier multiplier: `A u = IDFT2[ H * DFT2[u] ]`.
+ Precisely speaking, given `[N, R]` matrix `u`, let `DFT2[u]` be the
+ `[N0, N1, R]` `Tensor` defined by re-shaping `u` to `[N0, N1, R]` and taking
+ a two dimensional DFT across the first two dimensions. Let `IDFT2` be the
+ inverse of `DFT2`. Matrix multiplication may be expressed columnwise:
+
+ ```(A u)_r = IDFT2[ H * (DFT2[u])_r ]```
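+
+ As a sketch of what this reshaping and 2-D transform look like (NumPy only,
+ with illustrative shapes, not the operator's actual implementation):
+
+ ```python
+ import numpy as np
+
+ N0, N1, R = 2, 3, 4
+ N = N0 * N1
+ H = np.fft.fft2(np.random.randn(N0, N1))             # a [N0, N1] spectrum
+ u = np.random.randn(N, R)                            # [N, R] input
+
+ U = np.fft.fft2(u.reshape(N0, N1, R), axes=(0, 1))   # DFT2 of each column
+ Au = np.fft.ifft2(H[..., None] * U, axes=(0, 1)).reshape(N, R)
+ ```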
+
+ #### Operator properties deduced from the spectrum.
+
+ * This operator is positive definite if and only if `Real{H} > 0`.
+
+ A general property of Fourier transforms is the correspondence between
+ Hermitian functions and real valued transforms.
+
+ Suppose `H.shape = [B1,...,Bb, N0, N1]`. We say that `H` is a Hermitian
+ spectrum if, with `%` indicating modulus division,
+
+ ```
+ H[..., n0 % N0, n1 % N1] = ComplexConjugate[ H[..., (-n0) % N0, (-n1) % N1] ].
+ ```
+
+ * This operator corresponds to a real matrix if and only if `H` is Hermitian.
+ * This operator is self-adjoint if and only if `H` is real.
+
+ See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.
+
+ #### Example of a self-adjoint positive definite operator
+
+ ```python
+ # spectrum is real ==> operator is self-adjoint
+ # spectrum is positive ==> operator is positive definite
+ spectrum = [[1., 2., 3.],
+ [4., 5., 6.],
+ [7., 8., 9.]]
+
+ operator = LinearOperatorCirculant2D(spectrum)
+
+ # IFFT[spectrum]
+ operator.convolution_kernel()
+ ==> [[5.0+0.0j, -0.5-.3j, -0.5+.3j],
+ [-1.5-.9j, 0, 0],
+ [-1.5+.9j, 0, 0]]
+
+ operator.to_dense()
+ ==> Complex, self-adjoint 9 x 9 matrix.
+ ```
+
+ #### Example of defining in terms of a real convolution kernel
+
+ ```python
+ # convolution_kernel is real ==> spectrum is Hermitian.
+ convolution_kernel = [[1., 2., 1.], [5., -1., 1.]]
+ spectrum = tf.fft2d(tf.cast(convolution_kernel, tf.complex64))
+
+ # spectrum is shape [2, 3] ==> operator is shape [6, 6]
+ # spectrum is Hermitian ==> operator is real.
+ operator = LinearOperatorCirculant2D(spectrum, input_output_dtype=tf.float32)
+ ```
+
+ #### Performance
+
+ Suppose `operator` is a `LinearOperatorCirculant` of shape `[N, N]`,
+ and `x.shape = [N, R]`. Then
+
+ * `operator.matmul(x)` is `O(R*N*Log[N])`
+ * `operator.solve(x)` is `O(R*N*Log[N])`
+ * `operator.determinant()` involves a size `N` `reduce_prod`.
+
+ If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+ `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ #### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint, positive_definite, square`.
+ These have the following meaning:
+
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ spectrum,
+ input_output_dtype=_DTYPE_COMPLEX,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=True,
+ name="LinearOperatorCirculant2D"):
+ r"""Initialize an `LinearOperatorCirculant2D`.
+
+ This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
+ by providing `spectrum`, a `[B1,...,Bb, N0, N1]` `Tensor` with `N0*N1 = N`.
+
+ If `input_output_dtype = DTYPE`:
+
+ * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
+ * Values returned by all methods, such as `matmul` or `determinant` will be
+ cast to `DTYPE`.
+
+ Note that if the spectrum is not Hermitian, then this operator corresponds
+ to a complex matrix with non-zero imaginary part. In this case, setting
+ `input_output_dtype` to a real type will forcibly cast the output to be
+ real, resulting in incorrect results!
+
+ If on the other hand the spectrum is Hermitian, then this operator
+ corresponds to a real-valued matrix, and setting `input_output_dtype` to
+ a real type is fine.
+
+ Args:
+ spectrum: Shape `[B1,...,Bb, N0, N1]` `Tensor`. Allowed dtypes are
+ `float32`, `complex64`. Type can be different from `input_output_dtype`.
+ input_output_dtype: `dtype` for input/output. Must be either
+ `float32` or `complex64`.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. If `spectrum` is real, this will always be true.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the quadratic form `x^H A x` has positive real part for all
+ nonzero `x`. Note that we do not require the operator to be
+ self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix\
+ #Extension_for_non_symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name to prepend to all ops created by this class.
+ """
+ super(LinearOperatorCirculant2D, self).__init__(
+ spectrum,
+ block_depth=2,
+ input_output_dtype=input_output_dtype,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+
+@tf_export("linalg.LinearOperatorCirculant3D")
+class LinearOperatorCirculant3D(_BaseLinearOperatorCirculant):
+ """`LinearOperator` acting like a nested block circulant matrix.
+
+ This operator acts like a block circulant matrix `A` with
+ shape `[B1,...,Bb, N, N]` for some `b >= 0`. The first `b` indices index a
+ batch member. For every batch index `(i1,...,ib)`, `A[i1,...,ib, : :]` is
+ an `N x N` matrix. This matrix `A` is not materialized, but for
+ purposes of broadcasting this shape will be relevant.
+
+ #### Description in terms of block circulant matrices
+
+ If `A` is nested block circulant, with block sizes `N0, N1, N2`
+ (`N0 * N1 * N2 = N`):
+ `A` has a block structure, composed of `N0 x N0` blocks, with each block
+ itself a block circulant matrix built from `N1 x N1` blocks, each of which
+ is an `N2 x N2` circulant matrix.
+
+ For example, with `W`, `X`, `Y`, `Z` each block circulant,
+
+ ```
+ A = |W Z Y X|
+ |X W Z Y|
+ |Y X W Z|
+ |Z Y X W|
+ ```
+
+ Note that `A` itself will not in general be circulant.
+
+ #### Description in terms of the frequency spectrum
+
+ There is an equivalent description in terms of the [batch] spectrum `H` and
+ Fourier transforms. Here we consider `A.shape = [N, N]` and ignore batch
+ dimensions.
+
+ If `H.shape = [N0, N1, N2]` (with `N0 * N1 * N2 = N`):
+ Loosely speaking, matrix multiplication is equal to the action of a
+ Fourier multiplier: `A u = IDFT3[ H * DFT3[u] ]`.
+ Precisely speaking, given `[N, R]` matrix `u`, let `DFT3[u]` be the
+ `[N0, N1, N2, R]` `Tensor` defined by re-shaping `u` to `[N0, N1, N2, R]` and
+ taking a three dimensional DFT across the first three dimensions. Let `IDFT3`
+ be the inverse of `DFT3`. Matrix multiplication may be expressed columnwise:
+
+ ```(A u)_r = IDFT3[ H * (DFT3[u])_r ]```
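+
+ The three-dimensional case follows the same pattern; a minimal NumPy sketch
+ (illustrative shapes, not the actual implementation):
+
+ ```python
+ import numpy as np
+
+ N0, N1, N2, R = 2, 3, 2, 4
+ N = N0 * N1 * N2
+ H = np.fft.fftn(np.random.randn(N0, N1, N2), axes=(0, 1, 2))
+ u = np.random.randn(N, R)
+
+ U = np.fft.fftn(u.reshape(N0, N1, N2, R), axes=(0, 1, 2))
+ Au = np.fft.ifftn(H[..., None] * U, axes=(0, 1, 2)).reshape(N, R)
+ ```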
+
+ #### Operator properties deduced from the spectrum.
+
+ * This operator is positive definite if and only if `Real{H} > 0`.
+
+ A general property of Fourier transforms is the correspondence between
+ Hermitian functions and real valued transforms.
+
+ Suppose `H.shape = [B1,...,Bb, N0, N1, N2]`. We say that `H` is a Hermitian
+ spectrum if, with `%` meaning modulus division,
+
+ ```
+ H[..., n0 % N0, n1 % N1, n2 % N2]
+ = ComplexConjugate[ H[..., (-n0) % N0, (-n1) % N1, (-n2) % N2] ].
+ ```
+
+ * This operator corresponds to a real matrix if and only if `H` is Hermitian.
+ * This operator is self-adjoint if and only if `H` is real.
+
+ See e.g. "Discrete-Time Signal Processing", Oppenheim and Schafer.
+
+ #### Examples
+
+ See `LinearOperatorCirculant` and `LinearOperatorCirculant2D` for examples.
+
+ #### Performance
+
+ Suppose `operator` is a `LinearOperatorCirculant` of shape `[N, N]`,
+ and `x.shape = [N, R]`. Then
+
+ * `operator.matmul(x)` is `O(R*N*Log[N])`
+ * `operator.solve(x)` is `O(R*N*Log[N])`
+ * `operator.determinant()` involves a size `N` `reduce_prod`.
+
+ If instead `operator` and `x` have shape `[B1,...,Bb, N, N]` and
+ `[B1,...,Bb, N, R]`, every operation increases in complexity by `B1*...*Bb`.
+
+ #### Matrix property hints
+
+ This `LinearOperator` is initialized with boolean flags of the form `is_X`,
+ for `X = non_singular, self_adjoint, positive_definite, square`.
+ These have the following meaning:
+
+ * If `is_X == True`, callers should expect the operator to have the
+ property `X`. This is a promise that should be fulfilled, but is *not* a
+ runtime assert. For example, finite floating point precision may result
+ in these promises being violated.
+ * If `is_X == False`, callers should expect the operator to not have `X`.
+ * If `is_X == None` (the default), callers should have no expectation either
+ way.
+ """
+
+ def __init__(self,
+ spectrum,
+ input_output_dtype=_DTYPE_COMPLEX,
+ is_non_singular=None,
+ is_self_adjoint=None,
+ is_positive_definite=None,
+ is_square=True,
+ name="LinearOperatorCirculant3D"):
+ """Initialize an `LinearOperatorCirculant`.
+
+ This `LinearOperator` is initialized to have shape `[B1,...,Bb, N, N]`
+ by providing `spectrum`, a `[B1,...,Bb, N0, N1, N2]` `Tensor`
+ with `N0*N1*N2 = N`.
+
+ If `input_output_dtype = DTYPE`:
+
+ * Arguments to methods such as `matmul` or `solve` must be `DTYPE`.
+ * Values returned by all methods, such as `matmul` or `determinant` will be
+ cast to `DTYPE`.
+
+ Note that if the spectrum is not Hermitian, then this operator corresponds
+ to a complex matrix with non-zero imaginary part. In this case, setting
+ `input_output_dtype` to a real type will forcibly cast the output to be
+ real, resulting in incorrect results!
+
+ If on the other hand the spectrum is Hermitian, then this operator
+ corresponds to a real-valued matrix, and setting `input_output_dtype` to
+ a real type is fine.
+
+ Args:
+ spectrum: Shape `[B1,...,Bb, N0, N1, N2]` `Tensor`. Allowed dtypes are
+ `float32`, `complex64`. Type can be different from `input_output_dtype`.
+ input_output_dtype: `dtype` for input/output. Must be either
+ `float32` or `complex64`.
+ is_non_singular: Expect that this operator is non-singular.
+ is_self_adjoint: Expect that this operator is equal to its hermitian
+ transpose. If `spectrum` is real, this will always be true.
+ is_positive_definite: Expect that this operator is positive definite,
+ meaning the real part of all eigenvalues is positive. We do not require
+ the operator to be self-adjoint to be positive-definite. See:
+ https://en.wikipedia.org/wiki/Positive-definite_matrix
+ #Extension_for_non_symmetric_matrices
+ is_square: Expect that this operator acts like square [batch] matrices.
+ name: A name to prepend to all ops created by this class.
+ """
+ super(LinearOperatorCirculant3D, self).__init__(
+ spectrum,
+ block_depth=3,
+ input_output_dtype=input_output_dtype,
+ is_non_singular=is_non_singular,
+ is_self_adjoint=is_self_adjoint,
+ is_positive_definite=is_positive_definite,
+ is_square=is_square,
+ name=name)
+
+
+def _to_complex(x):
+ return math_ops.cast(x, _DTYPE_COMPLEX)
diff --git a/tensorflow/python/ops/lookup_ops.py b/tensorflow/python/ops/lookup_ops.py
index 6f043f60e6..0e547689cc 100644
--- a/tensorflow/python/ops/lookup_ops.py
+++ b/tensorflow/python/ops/lookup_ops.py
@@ -277,7 +277,27 @@ class HashTable(InitializableLookupTableBase):
name=scope)
super(HashTable, self).__init__(table_ref, default_value, initializer)
+ self._value_shape = self._default_value.get_shape()
+ def export(self, name=None):
+ """Returns tensors of all keys and values in the table.
+
+ Args:
+ name: A name for the operation (optional).
+
+ Returns:
+ A pair of tensors with the first tensor containing all keys and the
+ second tensors containing all values in the table.
+ """
+ with ops.name_scope(name, "%s_Export" % self._name,
+ [self._table_ref]) as name:
+ with ops.colocate_with(self._table_ref):
+ exported_keys, exported_values = gen_lookup_ops.lookup_table_export_v2(
+ self._table_ref, self._key_dtype, self._value_dtype, name=name)
+
+ exported_values.set_shape(exported_keys.get_shape().concatenate(
+ self._value_shape))
+ return exported_keys, exported_values
class TableInitializerBase(object):
"""Base class for lookup table initializers."""
diff --git a/tensorflow/python/ops/losses/losses.py b/tensorflow/python/ops/losses/losses.py
index 8532c19ad6..81ee01a41a 100644
--- a/tensorflow/python/ops/losses/losses.py
+++ b/tensorflow/python/ops/losses/losses.py
@@ -35,16 +35,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys
-
-from tensorflow.python.ops.losses import util
# pylint: disable=wildcard-import
from tensorflow.python.ops.losses.losses_impl import *
from tensorflow.python.ops.losses.util import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = []
-
-remove_undocumented(__name__, _allowed_symbols,
- [sys.modules[__name__], util])
diff --git a/tensorflow/python/ops/manip_ops.py b/tensorflow/python/ops/manip_ops.py
index 6d335cdc21..373585395b 100644
--- a/tensorflow/python/ops/manip_ops.py
+++ b/tensorflow/python/ops/manip_ops.py
@@ -22,7 +22,6 @@ from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_manip_ops as _gen_manip_ops
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -34,7 +33,3 @@ def roll(input, shift, axis): # pylint: disable=redefined-builtin
roll.__doc__ = _gen_manip_ops.roll.__doc__
# pylint: enable=protected-access
-
-_allowed_symbols = ['roll']
-
-remove_undocumented(__name__, allowed_exception_list=_allowed_symbols)
diff --git a/tensorflow/python/ops/math_ops.py b/tensorflow/python/ops/math_ops.py
index 30ac001c25..b81b1e792e 100644
--- a/tensorflow/python/ops/math_ops.py
+++ b/tensorflow/python/ops/math_ops.py
@@ -973,7 +973,9 @@ def _OverrideBinaryOperatorHelper(func, op_name, clazz_object=ops.Tensor):
def binary_op_wrapper(x, y):
with ops.name_scope(None, op_name, [x, y]) as name:
- if not isinstance(y, sparse_tensor.SparseTensor):
+ if isinstance(x, ops.Tensor) and isinstance(y, ops.Tensor):
+ return func(x, y, name=name)
+ elif not isinstance(y, sparse_tensor.SparseTensor):
try:
y = ops.convert_to_tensor(y, dtype=x.dtype.base_dtype, name="y")
except TypeError:
diff --git a/tensorflow/python/ops/metrics.py b/tensorflow/python/ops/metrics.py
index 7e75542aec..d1a8249154 100644
--- a/tensorflow/python/ops/metrics.py
+++ b/tensorflow/python/ops/metrics.py
@@ -58,8 +58,3 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.ops.metrics_impl import *
# pylint: enable=wildcard-import
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = []
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/nn.py b/tensorflow/python/ops/nn.py
index 1d0d9a52a1..25e4add569 100644
--- a/tensorflow/python/ops/nn.py
+++ b/tensorflow/python/ops/nn.py
@@ -117,7 +117,6 @@ from tensorflow.python.ops import nn_ops as _nn_ops
from tensorflow.python.ops.math_ops import sigmoid
from tensorflow.python.ops.math_ops import tanh
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
# Bring more nn-associated functionality into this package.
# go/tf-wildcard-import
@@ -128,22 +127,3 @@ from tensorflow.python.ops.nn_ops import *
from tensorflow.python.ops.candidate_sampling_ops import *
from tensorflow.python.ops.embedding_ops import *
# pylint: enable=wildcard-import,unused-import
-
-
-# TODO(cwhipkey): sigmoid and tanh should not be exposed from tf.nn.
-_allowed_symbols = [
- "zero_fraction", # documented in training.py
- # Modules whitelisted for reference through tf.nn.
- # TODO(cwhipkey): migrate callers to use the submodule directly.
- # Symbols whitelisted for export without documentation.
- # TODO(cwhipkey): review these and move to contrib or expose through
- # documentation.
- "all_candidate_sampler", # Excluded in gen_docs_combined.
- "lrn", # Excluded in gen_docs_combined.
- "relu_layer", # Excluded in gen_docs_combined.
- "xw_plus_b", # Excluded in gen_docs_combined.
- "rnn_cell", # rnn_cell is a submodule of tf.nn.
-]
-
-remove_undocumented(__name__, _allowed_symbols,
- [_sys.modules[__name__], _ctc_ops, _nn_ops, _nn_grad])
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 4d26b2f46e..1e953f658f 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -24,7 +24,6 @@ from tensorflow.core.framework import variable_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import tape
-from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
@@ -47,13 +46,11 @@ def get_resource_handle_data(graph_op):
assert ops._USE_C_SHAPES # pylint: disable=protected-access
assert type(graph_op) == ops.Tensor # pylint: disable=unidiomatic-typecheck
- with c_api_util.tf_buffer() as buf:
- pywrap_tensorflow.TFE_GetResourceHandleShapeAndType(
- graph_op.graph._c_graph, graph_op._as_tf_output(), buf) # pylint: disable=protected-access
- data = pywrap_tensorflow.TF_GetBuffer(buf)
+ handle_data = pywrap_tensorflow.GetResourceHandleShapeAndType(
+ graph_op.graph._c_graph, graph_op._as_tf_output()) # pylint: disable=protected-access
return cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData.FromString(
- compat.as_bytes(data))
+ compat.as_bytes(handle_data))
def _eager_safe_variable_handle(shape, dtype, shared_name, name, graph_mode):
diff --git a/tensorflow/python/ops/rnn_cell.py b/tensorflow/python/ops/rnn_cell.py
index c0dac8fb01..3d26ffb7ae 100644
--- a/tensorflow/python/ops/rnn_cell.py
+++ b/tensorflow/python/ops/rnn_cell.py
@@ -44,8 +44,3 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.ops.rnn_cell_impl import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = []
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/sdca_ops.py b/tensorflow/python/ops/sdca_ops.py
index 8b7e5abbc2..24ea68892a 100644
--- a/tensorflow/python/ops/sdca_ops.py
+++ b/tensorflow/python/ops/sdca_ops.py
@@ -31,11 +31,6 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops.gen_sdca_ops import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
ops.NotDifferentiable("SdcaFprint")
ops.NotDifferentiable("SdcaOptimizer")
ops.NotDifferentiable("SdcaShrinkL1")
-
-
-remove_undocumented(__name__)
diff --git a/tensorflow/python/ops/sets.py b/tensorflow/python/ops/sets.py
index ea4677befe..54d6e1db41 100644
--- a/tensorflow/python/ops/sets.py
+++ b/tensorflow/python/ops/sets.py
@@ -28,8 +28,3 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.ops.sets_impl import *
# pylint: enable=wildcard-import
-
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = []
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/ops/spectral_ops.py b/tensorflow/python/ops/spectral_ops.py
index a579688276..4a4ca693dc 100644
--- a/tensorflow/python/ops/spectral_ops.py
+++ b/tensorflow/python/ops/spectral_ops.py
@@ -40,7 +40,6 @@ from tensorflow.python.framework import tensor_util as _tensor_util
from tensorflow.python.ops import array_ops as _array_ops
from tensorflow.python.ops import gen_spectral_ops
from tensorflow.python.ops import math_ops as _math_ops
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -249,5 +248,3 @@ def dct(input, type=2, n=None, axis=-1, norm=None, name=None): # pylint: disabl
dct2 *= weights
return dct2
-
-remove_undocumented(__name__)
diff --git a/tensorflow/python/ops/standard_ops.py b/tensorflow/python/ops/standard_ops.py
index f71f98aa12..a2d24711e2 100644
--- a/tensorflow/python/ops/standard_ops.py
+++ b/tensorflow/python/ops/standard_ops.py
@@ -33,7 +33,6 @@ from tensorflow.python.ops import sparse_grad
from tensorflow.python.ops import spectral_grad
from tensorflow.python.ops import state_grad
from tensorflow.python.ops import tensor_array_grad
-from tensorflow.python.util.all_util import remove_undocumented
# go/tf-wildcard-import
@@ -99,212 +98,3 @@ from tensorflow.python.ops.variables import *
# pylint: enable=wildcard-import
# pylint: enable=g-bad-import-order
-#### For use in remove_undocumented below:
-from tensorflow.python.framework import constant_op as _constant_op
-from tensorflow.python.ops import array_ops as _array_ops
-from tensorflow.python.ops import check_ops as _check_ops
-from tensorflow.python.ops import clip_ops as _clip_ops
-from tensorflow.python.ops import confusion_matrix as _confusion_matrix
-from tensorflow.python.ops import control_flow_ops as _control_flow_ops
-from tensorflow.python.ops import data_flow_ops as _data_flow_ops
-from tensorflow.python.ops import functional_ops as _functional_ops
-from tensorflow.python.ops import gradients as _gradients
-from tensorflow.python.ops import histogram_ops as _histogram_ops
-from tensorflow.python.ops import init_ops as _init_ops
-from tensorflow.python.ops import io_ops as _io_ops
-from tensorflow.python.ops import linalg_ops as _linalg_ops
-from tensorflow.python.ops import logging_ops as _logging_ops
-from tensorflow.python.ops import manip_ops as _manip_ops
-from tensorflow.python.ops import math_ops as _math_ops
-from tensorflow.python.ops import numerics as _numerics
-from tensorflow.python.ops import parsing_ops as _parsing_ops
-from tensorflow.python.ops import partitioned_variables as _partitioned_variables
-from tensorflow.python.ops import random_ops as _random_ops
-from tensorflow.python.ops import script_ops as _script_ops
-from tensorflow.python.ops import session_ops as _session_ops
-from tensorflow.python.ops import sparse_ops as _sparse_ops
-from tensorflow.python.ops import special_math_ops as _special_math_ops
-from tensorflow.python.ops import state_ops as _state_ops
-from tensorflow.python.ops import string_ops as _string_ops
-from tensorflow.python.ops import template as _template
-from tensorflow.python.ops import tensor_array_ops as _tensor_array_ops
-from tensorflow.python.ops import variable_scope as _variable_scope
-from tensorflow.python.ops import variables as _variables
-
-
-_allowed_symbols_math_ops = [
- # TODO(drpng): decide if we want to reference these in the documentation.
- "reduced_shape",
- "sparse_segment_mean_grad",
- "sparse_segment_sqrt_n_grad",
-
- # Legacy: will be removed.
- "arg_max",
- "arg_min",
- "lin_space",
- "sparse_matmul", # Use tf.matmul.
- # Deprecated (see versions.h):
- "batch_fft",
- "batch_fft2d",
- "batch_fft3d",
- "batch_ifft",
- "batch_ifft2d",
- "batch_ifft3d",
- "mul", # use tf.multiply instead.
- "neg", # use tf.negative instead.
- "sub", # use tf.subtract instead.
-
- # These are documented in nn.
- # We are not importing nn because it would create a circular dependency.
- "sigmoid",
- "log_sigmoid",
- "tanh",
-]
-
-_allowed_symbols_array_ops = [
- # TODO(drpng): make sure they are documented.
- # Scalars:
- "NEW_AXIS",
- "SHRINK_AXIS",
- "newaxis",
-
- # Documented in training.py.
- # I do not import train, to avoid circular dependencies.
- # TODO(drpng): this is defined in gen_array_ops, clearly not the right
- # place.
- "stop_gradient",
-
- # See gen_docs_combined for tf.copy documentation.
- "copy",
-
- ## TODO(drpng): make them inaccessible directly.
- ## TODO(drpng): Below, to-doc means that we need to find an appropriate
- ## documentation section to reference.
- ## For re-exporting to tf.*:
- "constant",
- "edit_distance", # to-doc
- # From gen_array_ops:
- "copy_host", # to-doc
- "immutable_const", # to-doc
- "invert_permutation", # to-doc
- "quantize_and_dequantize", # to-doc
-
- # TODO(drpng): legacy symbols to be removed.
- "batch_matrix_diag",
- "batch_matrix_band_part",
- "batch_matrix_diag_part",
- "batch_matrix_set_diag",
-]
-
-_allowed_symbols_partitioned_variables = [
- "PartitionedVariable", # Requires doc link.
- # Legacy.
- "create_partitioned_variables",
- "variable_axis_size_partitioner",
- "min_max_variable_partitioner",
- "fixed_size_partitioner",
-]
-
-_allowed_symbols_control_flow_ops = [
- # TODO(drpng): Find a place in the documentation to reference these or
- # remove.
- "control_trigger",
- "loop_cond",
- "merge",
- "switch",
-]
-
-_allowed_symbols_functional_ops = [
- "nest", # Used by legacy code.
-]
-
-_allowed_symbols_gradients = [
- # Documented in training.py:
- # Not importing training.py to avoid complex graph dependencies.
- "AggregationMethod",
- "GradientTape",
- "custom_gradient",
- "gradients", # tf.gradients = gradients.gradients
- "hessians",
-]
-
-_allowed_symbols_clip_ops = [
- # Documented in training.py:
- # Not importing training.py to avoid complex graph dependencies.
- "clip_by_average_norm",
- "clip_by_global_norm",
- "clip_by_norm",
- "clip_by_value",
- "global_norm",
-]
-
-_allowed_symbols_logging_ops = [
- # Documented in training.py.
- # We are not importing training.py to avoid complex dependencies.
- "audio_summary",
- "histogram_summary",
- "image_summary",
- "merge_all_summaries",
- "merge_summary",
- "scalar_summary",
-
- # TODO(drpng): link in training.py if it should be documented.
- "get_summary_op",
-]
-
-_allowed_symbols_variable_scope_ops = [
- "get_local_variable", # Documented in framework package.
-]
-
-_allowed_symbols_misc = [
- "deserialize_many_sparse",
- "parse_single_sequence_example",
- "serialize_many_sparse",
- "serialize_sparse",
- "confusion_matrix",
-]
-
-_allowed_symbols = (_allowed_symbols_array_ops +
- _allowed_symbols_clip_ops +
- _allowed_symbols_control_flow_ops +
- _allowed_symbols_functional_ops +
- _allowed_symbols_gradients +
- _allowed_symbols_logging_ops +
- _allowed_symbols_math_ops +
- _allowed_symbols_variable_scope_ops +
- _allowed_symbols_misc +
- _allowed_symbols_partitioned_variables)
-
-remove_undocumented(__name__, _allowed_symbols, [
- _sys.modules[__name__],
- _array_ops,
- _check_ops,
- _clip_ops,
- _confusion_matrix,
- _control_flow_ops,
- _constant_op,
- _data_flow_ops,
- _functional_ops,
- _gradients,
- _histogram_ops,
- _init_ops,
- _io_ops,
- _linalg_ops,
- _logging_ops,
- _manip_ops,
- _math_ops,
- _numerics,
- _parsing_ops,
- _partitioned_variables,
- _random_ops,
- _script_ops,
- _session_ops,
- _sparse_ops,
- _special_math_ops,
- _state_ops,
- _string_ops,
- _template,
- _tensor_array_ops,
- _variable_scope,
- _variables,
-])
diff --git a/tensorflow/python/ops/summary_ops_v2.py b/tensorflow/python/ops/summary_ops_v2.py
index 12f361c513..b80f84eb7c 100644
--- a/tensorflow/python/ops/summary_ops_v2.py
+++ b/tensorflow/python/ops/summary_ops_v2.py
@@ -74,10 +74,12 @@ def record_summaries_every_n_global_steps(n, global_step=None):
global_step = training_util.get_or_create_global_step()
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
- with ops.device("cpu:0"):
- collection_ref[:] = [math_ops.equal(global_step % n, 0)]
- yield
- collection_ref[:] = old
+ try:
+ with ops.device("cpu:0"):
+ collection_ref[:] = [math_ops.equal(global_step % n, 0)]
+ yield
+ finally:
+ collection_ref[:] = old
@tf_contextlib.contextmanager
@@ -85,9 +87,11 @@ def always_record_summaries():
"""Sets the should_record_summaries Tensor to always true."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
- collection_ref[:] = [True]
- yield
- collection_ref[:] = old
+ try:
+ collection_ref[:] = [True]
+ yield
+ finally:
+ collection_ref[:] = old
@tf_contextlib.contextmanager
@@ -95,9 +99,11 @@ def never_record_summaries():
"""Sets the should_record_summaries Tensor to always false."""
collection_ref = ops.get_collection_ref(_SHOULD_RECORD_SUMMARIES_NAME)
old = collection_ref[:]
- collection_ref[:] = [False]
- yield
- collection_ref[:] = old
+ try:
+ collection_ref[:] = [False]
+ yield
+ finally:
+ collection_ref[:] = old
class SummaryWriter(object):
@@ -127,12 +133,16 @@ class SummaryWriter(object):
yield self
else:
old = context.context().summary_writer_resource
- context.context().summary_writer_resource = self._resource
- yield self
- # Flushes the summary writer in eager mode or in graph functions, but not
- # in legacy graph mode (you're on your own there).
- self.flush()
- context.context().summary_writer_resource = old
+ try:
+ context.context().summary_writer_resource = self._resource
+ yield self
+ # Flushes the summary writer in eager mode or in graph functions, but
+ # not in legacy graph mode (you're on your own there).
+ with ops.device("cpu:0"):
+ gen_summary_ops.flush_summary_writer(self._resource)
+ finally:
+ context.context().summary_writer_resource = old
+
def init(self):
"""Operation to initialize the summary writer resource."""
diff --git a/tensorflow/python/platform/app.py b/tensorflow/python/platform/app.py
index cce64c0cca..4c91bc3652 100644
--- a/tensorflow/python/platform/app.py
+++ b/tensorflow/python/platform/app.py
@@ -22,7 +22,6 @@ import errno as _errno
import sys as _sys
from tensorflow.python.platform import flags
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -125,11 +124,3 @@ def run(main=None, argv=None):
# to the final program.
_sys.exit(main(argv))
-
-_allowed_symbols = [
- 'run',
- # Allowed submodule.
- 'flags',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/platform/gfile.py b/tensorflow/python/platform/gfile.py
index 315889e9aa..fd697d70bf 100644
--- a/tensorflow/python/platform/gfile.py
+++ b/tensorflow/python/platform/gfile.py
@@ -33,7 +33,6 @@ from tensorflow.python.lib.io.file_io import rename as Rename
from tensorflow.python.lib.io.file_io import stat as Stat
from tensorflow.python.lib.io.file_io import walk as Walk
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -56,24 +55,3 @@ class FastGFile(_FileIO):
# Does not alias to Open so that we use our version of GFile to strip
# 'b' mode.
Open = GFile
-
-# TODO(drpng): Find the right place to document these.
-_allowed_symbols = [
- 'Copy',
- 'DeleteRecursively',
- 'Exists',
- 'FastGFile',
- 'GFile',
- 'Glob',
- 'IsDirectory',
- 'ListDirectory',
- 'Open',
- 'MakeDirs',
- 'MkDir',
- 'Remove',
- 'Rename',
- 'Stat',
- 'Walk',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/platform/resource_loader.py b/tensorflow/python/platform/resource_loader.py
index 8f7b12e2b2..650a1fd851 100644
--- a/tensorflow/python/platform/resource_loader.py
+++ b/tensorflow/python/platform/resource_loader.py
@@ -28,7 +28,6 @@ import os as _os
import sys as _sys
from tensorflow.python.util import tf_inspect as _inspect
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -129,7 +128,3 @@ def get_path_to_datafile(path):
def readahead_file_path(path, readahead='128M'): # pylint: disable=unused-argument
"""Readahead files not implemented; simply returns given path."""
return path
-
-
-_allowed_symbols = []
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/platform/sysconfig.py b/tensorflow/python/platform/sysconfig.py
index fdd2b903fc..56759d1b8e 100644
--- a/tensorflow/python/platform/sysconfig.py
+++ b/tensorflow/python/platform/sysconfig.py
@@ -28,7 +28,6 @@ import os.path as _os_path
from tensorflow.python.framework.versions import CXX11_ABI_FLAG as _CXX11_ABI_FLAG
from tensorflow.python.framework.versions import MONOLITHIC_BUILD as _MONOLITHIC_BUILD
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -84,6 +83,3 @@ def get_link_flags():
flags.append('-L%s' % get_lib())
flags.append('-ltensorflow_framework')
return flags
-
-_allowed_symbols = []
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/platform/test.py b/tensorflow/python/platform/test.py
index 1660791feb..0a0fe68be5 100644
--- a/tensorflow/python/platform/test.py
+++ b/tensorflow/python/platform/test.py
@@ -42,7 +42,6 @@ from __future__ import print_function
# pylint: disable=g-bad-import-order
from tensorflow.python.framework import test_util as _test_util
from tensorflow.python.platform import googletest as _googletest
-from tensorflow.python.util.all_util import remove_undocumented
# pylint: disable=unused-import
from tensorflow.python.framework.test_util import assert_equal_graph_def
@@ -108,13 +107,3 @@ def test_src_dir_path(relative_path):
def is_built_with_cuda():
"""Returns whether TensorFlow was built with CUDA (GPU) support."""
return _test_util.IsGoogleCudaEnabled()
-
-
-_allowed_symbols = [
- # We piggy-back googletest documentation.
- 'Benchmark',
- 'mock',
- 'StubOutForTesting',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/platform/tf_logging.py b/tensorflow/python/platform/tf_logging.py
index 22aabfd712..5962d2f220 100644
--- a/tensorflow/python/platform/tf_logging.py
+++ b/tensorflow/python/platform/tf_logging.py
@@ -34,7 +34,6 @@ import threading
import six
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -287,35 +286,8 @@ def _get_thread_id():
_log_prefix = google2_log_prefix
-# Controls which methods from pyglib.logging are available within the project.
-# Do not add methods here without also adding to platform/tf_logging.py.
-_allowed_symbols = [
- 'DEBUG',
- 'ERROR',
- 'FATAL',
- 'INFO',
- 'TaskLevelStatusMessage',
- 'WARN',
- 'debug',
- 'error',
- 'fatal',
- 'flush',
- 'get_verbosity',
- 'info',
- 'log',
- 'log_if',
- 'log_every_n',
- 'log_first_n',
- 'set_verbosity',
- 'vlog',
- 'warn',
- 'warning',
-]
-
tf_export('logging.DEBUG').export_constant(__name__, 'DEBUG')
tf_export('logging.ERROR').export_constant(__name__, 'ERROR')
tf_export('logging.FATAL').export_constant(__name__, 'FATAL')
tf_export('logging.INFO').export_constant(__name__, 'INFO')
tf_export('logging.WARN').export_constant(__name__, 'WARN')
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/profiler/profiler.py b/tensorflow/python/profiler/profiler.py
index fa7f30b236..efbdd1ba68 100644
--- a/tensorflow/python/profiler/profiler.py
+++ b/tensorflow/python/profiler/profiler.py
@@ -30,7 +30,6 @@ from tensorflow.python.profiler.model_analyzer import Profiler
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
from tensorflow.python.profiler.tfprof_logger import write_op_log
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -54,11 +53,3 @@ tf_export('profiler.GraphNodeProto')(GraphNodeProto)
tf_export('profiler.MultiGraphNodeProto')(MultiGraphNodeProto)
tf_export('profiler.AdviceProto')(AdviceProto)
tf_export('profiler.OpLogProto')(OpLogProto)
-
-remove_undocumented(__name__, _allowed_symbols, [
- Profiler,
- profile,
- ProfileOptionBuilder,
- advise,
- write_op_log,
-])
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 0982a67dee..5ee55301df 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -59,8 +59,6 @@ limitations under the License.
%rename("%s") TFE_ContextOptionsSetAsync;
%rename("%s") TFE_DeleteContextOptions;
%rename("%s") TFE_Py_TensorShapeSlice;
-%rename("%s") TFE_GetResourceHandleShapeAndType;
-%rename("%s") TFE_SetResourceHandleShapeAndType;
%{
#include "tensorflow/python/eager/pywrap_tfe.h"
diff --git a/tensorflow/python/saved_model/builder.py b/tensorflow/python/saved_model/builder.py
index 766b0a3579..be49c70c60 100644
--- a/tensorflow/python/saved_model/builder.py
+++ b/tensorflow/python/saved_model/builder.py
@@ -26,10 +26,3 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.saved_model.builder_impl import SavedModelBuilder
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-
-_allowed_symbols = [
- "SavedModelBuilder",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/constants.py b/tensorflow/python/saved_model/constants.py
index ec49a0539f..34206c6f6d 100644
--- a/tensorflow/python/saved_model/constants.py
+++ b/tensorflow/python/saved_model/constants.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
# Subdirectory name containing the asset files.
@@ -66,17 +65,3 @@ tf_export("saved_model.constants.VARIABLES_DIRECTORY").export_constant(
VARIABLES_FILENAME = "variables"
tf_export("saved_model.constants.VARIABLES_FILENAME").export_constant(
__name__, "VARIABLES_FILENAME")
-
-
-_allowed_symbols = [
- "ASSETS_DIRECTORY",
- "ASSETS_KEY",
- "LEGACY_INIT_OP_KEY",
- "MAIN_OP_KEY",
- "SAVED_MODEL_SCHEMA_VERSION",
- "SAVED_MODEL_FILENAME_PB",
- "SAVED_MODEL_FILENAME_PBTXT",
- "VARIABLES_DIRECTORY",
- "VARIABLES_FILENAME",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/loader.py b/tensorflow/python/saved_model/loader.py
index 0a7f516287..334298c232 100644
--- a/tensorflow/python/saved_model/loader.py
+++ b/tensorflow/python/saved_model/loader.py
@@ -67,11 +67,3 @@ from __future__ import print_function
from tensorflow.python.saved_model.loader_impl import load
from tensorflow.python.saved_model.loader_impl import maybe_saved_model_directory
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-
-_allowed_symbols = [
- "load",
- "maybe_saved_model_directory",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/main_op.py b/tensorflow/python/saved_model/main_op.py
index 04cadeab66..18d11b900c 100644
--- a/tensorflow/python/saved_model/main_op.py
+++ b/tensorflow/python/saved_model/main_op.py
@@ -26,10 +26,3 @@ from __future__ import print_function
from tensorflow.python.saved_model.main_op_impl import main_op
from tensorflow.python.saved_model.main_op_impl import main_op_with_restore
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = [
- "main_op",
- "main_op_with_restore",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/saved_model.py b/tensorflow/python/saved_model/saved_model.py
index caabd7bc30..6702c99607 100644
--- a/tensorflow/python/saved_model/saved_model.py
+++ b/tensorflow/python/saved_model/saved_model.py
@@ -34,18 +34,3 @@ from tensorflow.python.saved_model import utils
from tensorflow.python.saved_model.simple_save import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-
-_allowed_symbols = [
- "builder",
- "constants",
- "loader",
- "main_op",
- "signature_constants",
- "signature_def_utils",
- "simple_save",
- "tag_constants",
- "utils",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/signature_constants.py b/tensorflow/python/saved_model/signature_constants.py
index 6461fe8a7e..819f351291 100644
--- a/tensorflow/python/saved_model/signature_constants.py
+++ b/tensorflow/python/saved_model/signature_constants.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -95,19 +94,3 @@ tf_export("saved_model.signature_constants.REGRESS_OUTPUTS").export_constant(
__name__, "REGRESS_OUTPUTS")
################################################################################
-
-
-_allowed_symbols = [
- "DEFAULT_SERVING_SIGNATURE_DEF_KEY",
- "CLASSIFY_INPUTS",
- "CLASSIFY_METHOD_NAME",
- "CLASSIFY_OUTPUT_CLASSES",
- "CLASSIFY_OUTPUT_SCORES",
- "PREDICT_INPUTS",
- "PREDICT_METHOD_NAME",
- "PREDICT_OUTPUTS",
- "REGRESS_INPUTS",
- "REGRESS_METHOD_NAME",
- "REGRESS_OUTPUTS",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/tag_constants.py b/tensorflow/python/saved_model/tag_constants.py
index d164e2c23f..5a797da791 100644
--- a/tensorflow/python/saved_model/tag_constants.py
+++ b/tensorflow/python/saved_model/tag_constants.py
@@ -19,7 +19,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -40,11 +39,3 @@ tf_export("saved_model.tag_constants.GPU").export_constant(__name__, "GPU")
# Tag for the `tpu` graph.
TPU = "tpu"
tf_export("saved_model.tag_constants.TPU").export_constant(__name__, "TPU")
-
-_allowed_symbols = [
- "SERVING",
- "TRAINING",
- "GPU",
- "TPU"
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/saved_model/utils.py b/tensorflow/python/saved_model/utils.py
index 8e750d8708..27c3554909 100644
--- a/tensorflow/python/saved_model/utils.py
+++ b/tensorflow/python/saved_model/utils.py
@@ -24,7 +24,3 @@ from __future__ import print_function
from tensorflow.python.saved_model.utils_impl import build_tensor_info
from tensorflow.python.saved_model.utils_impl import get_tensor_from_tensor_info
# pylint: enable=unused-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-_allowed_symbols = ["build_tensor_info", "get_tensor_from_tensor_info"]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/summary/summary.py b/tensorflow/python/summary/summary.py
index 1286ed6703..969cbe7d35 100644
--- a/tensorflow/python/summary/summary.py
+++ b/tensorflow/python/summary/summary.py
@@ -74,7 +74,6 @@ from tensorflow.python.summary.writer.writer_cache import FileWriterCache
# pylint: enable=unused-import
from tensorflow.python.util import compat as _compat
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -361,10 +360,3 @@ def get_summary_description(node_def):
summary_description = SummaryDescription()
_json_format.Parse(description_str, summary_description)
return summary_description
-
-
-_allowed_symbols = [
- 'Summary', 'SummaryDescription', 'Event', 'TaggedRunMetadata', 'SessionLog',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/tools/BUILD b/tensorflow/python/tools/BUILD
index 84d20f8e36..6c34b6aaf3 100644
--- a/tensorflow/python/tools/BUILD
+++ b/tensorflow/python/tools/BUILD
@@ -38,9 +38,9 @@ py_library(
deps = [
":saved_model_utils",
"//tensorflow/core:protos_all_py",
- "//tensorflow/python", # TODO(b/34059704): remove when fixed
"//tensorflow/python:client",
"//tensorflow/python:framework",
+ "//tensorflow/python:no_contrib", # TODO(b/34059704): remove when fixed
"//tensorflow/python:parsing_ops",
"//tensorflow/python:platform",
"//tensorflow/python:training",
diff --git a/tensorflow/python/training/basic_session_run_hooks.py b/tensorflow/python/training/basic_session_run_hooks.py
index 3651291bdf..47339e057f 100644
--- a/tensorflow/python/training/basic_session_run_hooks.py
+++ b/tensorflow/python/training/basic_session_run_hooks.py
@@ -434,23 +434,27 @@ class CheckpointSaverHook(session_run_hook.SessionRunHook):
for l in self._listeners:
l.begin()
- def before_run(self, run_context): # pylint: disable=unused-argument
- if self._timer.last_triggered_step() is None:
- # We do write graph and saver_def at the first call of before_run.
- # We cannot do this in begin, since we let other hooks to change graph and
- # add variables in begin. Graph is finalized after all begin calls.
- training_util.write_graph(
- ops.get_default_graph().as_graph_def(add_shapes=True),
- self._checkpoint_dir,
- "graph.pbtxt")
- saver_def = self._get_saver().saver_def if self._get_saver() else None
- graph = ops.get_default_graph()
- meta_graph_def = meta_graph.create_meta_graph_def(
- graph_def=graph.as_graph_def(add_shapes=True),
- saver_def=saver_def)
- self._summary_writer.add_graph(graph)
- self._summary_writer.add_meta_graph(meta_graph_def)
+ def after_create_session(self, session, coord):
+ global_step = session.run(self._global_step_tensor)
+ # We do write graph and saver_def at the first call of before_run.
+ # We cannot do this in begin, since we let other hooks to change graph and
+ # add variables in begin. Graph is finalized after all begin calls.
+ training_util.write_graph(
+ ops.get_default_graph().as_graph_def(add_shapes=True),
+ self._checkpoint_dir,
+ "graph.pbtxt")
+ saver_def = self._get_saver().saver_def if self._get_saver() else None
+ graph = ops.get_default_graph()
+ meta_graph_def = meta_graph.create_meta_graph_def(
+ graph_def=graph.as_graph_def(add_shapes=True),
+ saver_def=saver_def)
+ self._summary_writer.add_graph(graph)
+ self._summary_writer.add_meta_graph(meta_graph_def)
+ # The checkpoint saved here is the state at step "global_step".
+ self._save(session, global_step)
+ self._timer.update_last_triggered_step(global_step)
+ def before_run(self, run_context): # pylint: disable=unused-argument
return SessionRunArgs(self._global_step_tensor)
def after_run(self, run_context, run_values):
diff --git a/tensorflow/python/training/basic_session_run_hooks_test.py b/tensorflow/python/training/basic_session_run_hooks_test.py
index 25962f6bf7..31898562f8 100644
--- a/tensorflow/python/training/basic_session_run_hooks_test.py
+++ b/tensorflow/python/training/basic_session_run_hooks_test.py
@@ -466,8 +466,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener_counts)
@@ -490,8 +490,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener_counts)
@@ -523,8 +523,8 @@ class CheckpointSaverHookTest(test.TestCase):
self.assertEqual(2, global_step_val)
self.assertEqual({
'begin': 1,
- 'before_save': 2,
- 'after_save': 2,
+ 'before_save': 3,
+ 'after_save': 3,
'end': 1
}, listener1_counts)
self.assertEqual(listener1_counts, listener2_counts)
@@ -706,6 +706,7 @@ class CheckpointSaverHookTest(test.TestCase):
with session_lib.Session() as sess:
sess.run(self.scaffold.init_op)
mon_sess = monitored_session._HookedSession(sess, [hook])
+ hook.after_create_session(sess, None)
mon_sess.run(self.train_op)
summary_writer.assert_summaries(
test_case=self,
@@ -718,6 +719,31 @@ class CheckpointSaverHookTest(test.TestCase):
fake_summary_writer.FakeSummaryWriter.uninstall()
+ def test_save_checkpoint_before_first_train_step(self):
+ with self.graph.as_default():
+ hook = basic_session_run_hooks.CheckpointSaverHook(
+ self.model_dir, save_steps=2, scaffold=self.scaffold)
+ hook.begin()
+ self.scaffold.finalize()
+ with session_lib.Session() as sess:
+ mon_sess = monitored_session._HookedSession(sess, [hook])
+ sess.run(self.scaffold.init_op)
+ hook.after_create_session(sess, None)
+ # Verifies that checkpoint is saved at step 0.
+ self.assertEqual(0,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+ # Verifies that no checkpoint is saved after one training step.
+ mon_sess.run(self.train_op)
+ self.assertEqual(0,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+ # Verifies that checkpoint is saved after save_steps.
+ mon_sess.run(self.train_op)
+ self.assertEqual(2,
+ checkpoint_utils.load_variable(self.model_dir,
+ self.global_step.name))
+
class CheckpointSaverHookMultiStepTest(test.TestCase):
diff --git a/tensorflow/python/training/checkpointable.py b/tensorflow/python/training/checkpointable.py
index 0b8473742c..05afd37ccd 100644
--- a/tensorflow/python/training/checkpointable.py
+++ b/tensorflow/python/training/checkpointable.py
@@ -24,6 +24,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_io_ops as io_ops
+from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
@@ -119,6 +120,7 @@ class _CheckpointPosition(object):
AssertionError: If another object is already bound to the `Object` proto.
"""
checkpoint = self.checkpoint
+ checkpoint.all_python_objects.add(checkpointable)
current_assignment = checkpoint.object_by_proto_id.get(self._proto_id, None)
if current_assignment is None:
checkpoint.object_by_proto_id[self._proto_id] = checkpointable
@@ -157,12 +159,12 @@ class _CheckpointPosition(object):
# consistent (if the dependency DAG is not a tree then there are
# multiple paths to the same object).
if current_assignment is not checkpointable:
- raise AssertionError(
- ("Unable to load the checkpoint into this object graph. Either "
- "the Checkpointable object references in the Python program "
- "have changed in an incompatible way, or the checkpoint was "
- "generated in an incompatible program.\n\nTwo checkpoint "
- "references resolved to different objects (%s and %s).")
+ logging.warning(
+ ("Inconsistent references when loading the checkpoint into this "
+ "object graph. Either the Checkpointable object references in the "
+ "Python program have changed in an incompatible way, or the "
+ "checkpoint was generated in an incompatible program.\n\nTwo "
+ "checkpoint references resolved to different objects (%s and %s).")
% (current_assignment, checkpointable))
return False # Not a new assignment
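With this change, two checkpoint references that resolve to different Python objects no longer raise during restore itself; a warning is logged and the mismatch is surfaced later through the load status (see the assert_consumed changes in checkpointable_utils.py below). A condensed sketch of the updated test, assuming eager execution so no session plumbing is needed:

    import tensorflow as tf
    from tensorflow.python.training import checkpointable
    from tensorflow.python.training import checkpointable_utils

    tf.enable_eager_execution()

    shared = checkpointable.Checkpointable()
    checkpointable_utils.add_variable(shared, name="var", initializer=0.)
    save_root = checkpointable.Checkpointable()
    save_root.dep_one = shared
    save_root.dep_two = shared          # two paths to the same object
    save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
        "/tmp/obj_ckpt/demo")

    load_root = checkpointable.Checkpointable()
    status = checkpointable_utils.CheckpointableSaver(load_root).restore(save_path)
    load_root.dep_one = checkpointable.Checkpointable()
    load_root.dep_two = checkpointable.Checkpointable()  # distinct objects now
    # A warning ("Two checkpoint references resolved to different objects") is
    # logged instead of an AssertionError; the failure shows up here instead:
    status.assert_consumed()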
diff --git a/tensorflow/python/training/checkpointable_utils.py b/tensorflow/python/training/checkpointable_utils.py
index 4769e15120..9cdd53cbf9 100644
--- a/tensorflow/python/training/checkpointable_utils.py
+++ b/tensorflow/python/training/checkpointable_utils.py
@@ -84,6 +84,11 @@ class _CheckpointRestoreCoordinator(object):
# (as objects with deferred dependencies will generally have references to
# this object).
self.object_by_proto_id = weakref.WeakValueDictionary()
+ # A set of all Python objects we've seen as dependencies, even if we didn't
+ # use them (for example because of inconsistent references when
+ # loading). Used to make status assertions fail when loading checkpoints
+ # that don't quite match.
+ self.all_python_objects = weakref.WeakSet()
self.save_path = save_path
self.dtype_map = dtype_map
# When graph building, contains a list of ops to run to restore objects from
@@ -336,19 +341,19 @@ def _serialize_object_graph(root_checkpointable):
slot_variables=slot_variables)
-def gather_initializers(root_checkpointable):
- """Traverse the object graph and find initialization ops.
+def list_objects(root_checkpointable):
+ """Traverse the object graph and list all accessible objects.
Looks for `Checkpointable` objects which are dependencies of
- `root_checkpointable` and which have an `initializer` property. Includes
- initializers for slot variables only if the variable they are slotting for and
- the optimizer are dependencies of `root_checkpointable` (i.e. if they would be
- saved with a checkpoint).
+ `root_checkpointable`. Includes slot variables only if the variable they are
+ slotting for and the optimizer are dependencies of `root_checkpointable`
+ (i.e. if they would be saved with a checkpoint).
Args:
- root_checkpointable: A `Checkpointable` object to gather initializers for.
+ root_checkpointable: A `Checkpointable` object whose dependencies should be
+ flattened.
Returns:
- A list of initialization ops.
+ A flat list of objects.
"""
# TODO(allenl): Extract out gathering logic so the naming logic doesn't have
# to run.
@@ -363,6 +368,24 @@ def gather_initializers(root_checkpointable):
checkpointable_objects=checkpointable_objects,
node_ids=node_ids,
object_names=object_names)
+ return checkpointable_objects
+
+
+def gather_initializers(root_checkpointable):
+ """Traverse the object graph and find initialization ops.
+
+ Looks for `Checkpointable` objects which are dependencies of
+ `root_checkpointable` and which have an `initializer` property. Includes
+ initializers for slot variables only if the variable they are slotting for and
+ the optimizer are dependencies of `root_checkpointable` (i.e. if they would be
+ saved with a checkpoint).
+
+ Args:
+ root_checkpointable: A `Checkpointable` object to gather initializers for.
+ Returns:
+ A list of initialization ops.
+ """
+ checkpointable_objects = list_objects(root_checkpointable)
return [c.initializer for c in checkpointable_objects
if hasattr(c, "initializer") and c.initializer is not None]
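The refactoring above pulls the object-graph traversal out of gather_initializers into the new list_objects helper, so other call sites (initialize_or_restore below) can reuse the flattened dependency list. A small sketch of the relationship, with `root` standing in for any Checkpointable (for example a tf.train.Checkpoint):

    from tensorflow.python.training import checkpointable_utils

    objects = checkpointable_utils.list_objects(root)   # every dependency of root
    initializers = [c.initializer for c in objects
                    if getattr(c, "initializer", None) is not None]
    # ...which is what gather_initializers(root) now returns.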
@@ -414,9 +437,10 @@ class CheckpointLoadStatus(_LoadStatus):
See `Saver.restore` for usage examples.
"""
- def __init__(self, checkpoint, feed_dict):
+ def __init__(self, checkpoint, feed_dict, root_checkpointable):
self._checkpoint = checkpoint
self._feed_dict = feed_dict
+ self._root_checkpointable = root_checkpointable
def assert_consumed(self):
"""Asserts that all objects in the checkpoint have been created/matched.
@@ -446,6 +470,16 @@ class CheckpointLoadStatus(_LoadStatus):
("Unused attributes in these objects (the attributes exist in the "
"checkpoint but not in the objects): %s") % (
self._checkpoint.unused_attributes.items(),))
+ for checkpointable_object in list_objects(self._root_checkpointable):
+ self._checkpoint.all_python_objects.add(checkpointable_object)
+ unused_python_objects = (
+ set(self._checkpoint.all_python_objects)
+ - set(self._checkpoint.object_by_proto_id.values()))
+ if unused_python_objects:
+ raise AssertionError(
+ ("Some Python objects were not bound to checkpointed values, likely "
+ "due to changes in the Python program: %s")
+ % (unused_python_objects,))
return self
def run_restore_ops(self, session=None):
@@ -457,17 +491,35 @@ class CheckpointLoadStatus(_LoadStatus):
session.run(self._checkpoint.restore_ops, feed_dict=self._feed_dict)
def initialize_or_restore(self, session=None):
- """Alias for `run_restore_ops`.
+ """Run operations to initialize or restore objects in the dependency graph.
+
+ Any objects in the dependency graph which have initializers but are not in
+ the checkpoint will have those initializers run, unless those variables are
+ being restored by a later call to `tf.train.Checkpoint.restore()`.
This method has a sibling in `InitializationOnlyStatus` which instead
initializes variables. That type is returned if no checkpoint is specified
in `Saver.restore`.
Args:
- session: The session to run restore ops in. If `None`, uses the default
- session.
+ session: The session to run init/restore ops in. If `None`, uses the
+ default session.
"""
+ if context.executing_eagerly():
+ return # Initialization and restoration ops are run eagerly
+ if session is None:
+ session = ops.get_default_session()
+ all_objects = list_objects(self._root_checkpointable)
+ already_initialized_objects = set(
+ self._checkpoint.object_by_proto_id.values())
+ initializers_for_non_restored_variables = [
+ c.initializer for c in all_objects
+ if hasattr(c, "initializer")
+ and c not in already_initialized_objects
+ and (getattr(c, "_update_uid", self._checkpoint.restore_uid - 1)
+ < self._checkpoint.restore_uid)]
self.run_restore_ops(session=session)
+ session.run(initializers_for_non_restored_variables)
class InitializationOnlyStatus(_LoadStatus):
@@ -480,7 +532,8 @@ class InitializationOnlyStatus(_LoadStatus):
otherwise.
"""
- def __init__(self, root_checkpointable):
+ def __init__(self, root_checkpointable, restore_uid):
+ self._restore_uid = restore_uid
self._root_checkpointable = root_checkpointable
def assert_consumed(self):
@@ -504,8 +557,9 @@ class InitializationOnlyStatus(_LoadStatus):
def initialize_or_restore(self, session=None):
"""Runs initialization ops for variables.
- Only objects which would be saved by `Saver.save` will be initialized. See
- `gather_initializers` for details.
+ Objects which would be saved by `Saver.save` will be initialized, unless
+ those variables are being restored by a later call to
+ `tf.train.Checkpoint.restore()`.
This method does nothing when executing eagerly (initializers get run
eagerly).
@@ -518,7 +572,13 @@ class InitializationOnlyStatus(_LoadStatus):
return # run eagerly
if session is None:
session = ops.get_default_session()
- session.run(gather_initializers(self._root_checkpointable))
+ checkpointable_objects = list_objects(self._root_checkpointable)
+ initializers = [
+ c.initializer for c in checkpointable_objects
+ if hasattr(c, "initializer") and c.initializer is not None
+ and (getattr(c, "_update_uid", self._restore_uid - 1)
+ < self._restore_uid)]
+ session.run(initializers)
_DEPRECATED_RESTORE_INSTRUCTIONS = (
@@ -616,11 +676,10 @@ class CheckpointableSaver(object):
# Allow passing in a weak reference to avoid reference cycles when
# `Checkpointable` objects save themselves.
self._root_checkpointable_ref = root_checkpointable
- if not context.executing_eagerly():
- with ops.device("/cpu:0"):
- self._file_prefix_placeholder = constant_op.constant("model")
- else:
- self._file_prefix_placeholder = None
+ # The file prefix placeholder is created lazily when graph building (and not
+ # at all when executing eagerly) to avoid creating ops in the constructor
+ # (when they may never be necessary).
+ self._file_prefix_placeholder = None
# Op caching for save
self._object_graph_feed_tensor = None
@@ -775,9 +834,12 @@ class CheckpointableSaver(object):
object is returned which runs restore ops from a name-based saver.
"""
if save_path is None:
- return InitializationOnlyStatus(self._root_checkpointable)
+ return InitializationOnlyStatus(self._root_checkpointable, ops.uid())
in_graph_mode = not context.executing_eagerly()
if in_graph_mode:
+ if self._file_prefix_placeholder is None:
+ with ops.device("/cpu:0"):
+ self._file_prefix_placeholder = constant_op.constant("model")
file_prefix_tensor = self._file_prefix_placeholder
file_prefix_feed_dict = {self._file_prefix_placeholder: save_path}
else:
@@ -819,7 +881,9 @@ class CheckpointableSaver(object):
checkpointable_lib._CheckpointPosition( # pylint: disable=protected-access
checkpoint=checkpoint, proto_id=0).restore(self._root_checkpointable)
load_status = CheckpointLoadStatus(
- checkpoint, feed_dict=file_prefix_feed_dict)
+ checkpoint,
+ root_checkpointable=self._root_checkpointable,
+ feed_dict=file_prefix_feed_dict)
return load_status
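Taken together, the CheckpointLoadStatus changes above make initialize_or_restore run restore ops for variables found in the checkpoint and plain initializers for everything else, while skipping variables that a later restore() call will populate (tracked via _update_uid against the restore UID). A usage sketch in graph mode; the model, optimizer and train_op names are placeholders, not part of this patch:

    import tensorflow as tf

    checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
    # If no checkpoint exists yet, restore(None) returns an initialization-only
    # status that exposes the same initialize_or_restore method.
    status = checkpoint.restore(tf.train.latest_checkpoint("/tmp/ckpt_dir"))
    with tf.Session() as sess:
      # Restores whatever the checkpoint contains and initializes the rest,
      # e.g. optimizer slots created after the checkpoint was written.
      status.initialize_or_restore(sess)
      sess.run(train_op)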
diff --git a/tensorflow/python/training/checkpointable_utils_test.py b/tensorflow/python/training/checkpointable_utils_test.py
index 29fcdb70b4..40dfeb28d5 100644
--- a/tensorflow/python/training/checkpointable_utils_test.py
+++ b/tensorflow/python/training/checkpointable_utils_test.py
@@ -808,13 +808,16 @@ class CheckpointingTests(test.TestCase):
save_path = checkpointable_utils.CheckpointableSaver(save_root).save(
os.path.join(checkpoint_directory, "ckpt"))
load_root = checkpointable.Checkpointable()
- checkpointable_utils.CheckpointableSaver(load_root).restore(save_path)
+ status = checkpointable_utils.CheckpointableSaver(load_root).restore(
+ save_path)
load_root.dep_one = checkpointable.Checkpointable()
load_root.dep_two = checkpointable.Checkpointable()
load_root.dep_one.dep_three = checkpointable.Checkpointable()
- with self.assertRaisesRegexp(AssertionError,
- "resolved to different objects"):
- load_root.dep_two.dep_three = checkpointable.Checkpointable()
+ load_root.dep_two.dep_three = checkpointable.Checkpointable()
+ checkpointable_utils.add_variable(
+ load_root.dep_one.dep_three, name="var", initializer=0.)
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
@test_util.run_in_graph_and_eager_modes()
def testObjectsCombined(self):
@@ -1114,6 +1117,84 @@ class CheckpointingTests(test.TestCase):
self.assertAllEqual([1., 2., 3., 4., 5.],
self.evaluate(deferred_second_dense.bias))
+ @test_util.run_in_graph_and_eager_modes()
+ def test_initialize_if_not_restoring(self):
+ checkpoint_directory = self.get_temp_dir()
+ checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
+ optimizer_only_prefix = os.path.join(checkpoint_directory, "opt")
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001)
+ root = checkpointable_utils.Checkpoint(
+ model=model, # Do not save the optimizer with the checkpoint.
+ global_step=training_util.get_or_create_global_step())
+ optimizer_checkpoint = checkpointable_utils.Checkpoint(
+ optimizer=optimizer)
+
+ checkpoint_path = saver_lib.latest_checkpoint(checkpoint_directory)
+ status = root.restore(save_path=checkpoint_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ self.evaluate([v.initializer for v in optimizer.variables()])
+ train_fn()
+ model_save_path = root.save(file_prefix=checkpoint_prefix)
+ self.evaluate(optimizer.variables()[0].assign(42.))
+ optimizer_save_path = optimizer_checkpoint.save(optimizer_only_prefix)
+
+ # Restore into a graph with the optimizer
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001)
+ root = checkpointable_utils.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ status = root.restore(save_path=model_save_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ status.initialize_or_restore()
+ train_fn()
+ with self.assertRaises(AssertionError):
+ status.assert_consumed()
+
+ # Make sure initialization doesn't clobber later restores
+ with ops.Graph().as_default(), self.test_session(
+ graph=ops.get_default_graph()), test_util.device(use_gpu=True):
+ model = MyModel()
+ optimizer = adam.AdamOptimizer(0.001, beta1=1.0)
+ root = checkpointable_utils.Checkpoint(
+ optimizer=optimizer, model=model,
+ global_step=training_util.get_or_create_global_step())
+ opt_root = checkpointable_utils.Checkpoint(
+ optimizer=optimizer)
+ status = root.restore(save_path=model_save_path)
+ init_only_optimizer_status = opt_root.restore(save_path=None)
+ optimizer_status = opt_root.restore(save_path=optimizer_save_path)
+ input_value = constant_op.constant([[3.]])
+ train_fn = functools.partial(
+ optimizer.minimize,
+ functools.partial(model, input_value),
+ global_step=root.global_step)
+ if not context.executing_eagerly():
+ train_fn = functools.partial(self.evaluate, train_fn())
+ optimizer_status.run_restore_ops()
+ status.initialize_or_restore()
+ init_only_optimizer_status.initialize_or_restore()
+ train_fn()
+ self.assertEqual(42., self.evaluate(optimizer.variables()[0]))
+
class TemplateTests(test.TestCase):
@@ -1276,9 +1357,7 @@ class CheckpointCompatibilityTests(test.TestCase):
with save_graph.as_default(), self.test_session(
graph=save_graph) as session:
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(
- session=session, file_prefix=checkpoint_prefix)
+ save_path = root.save(session=session, file_prefix=checkpoint_prefix)
with context.eager_mode():
root = self._initialized_model()
self._set_sentinels(root)
@@ -1290,8 +1369,7 @@ class CheckpointCompatibilityTests(test.TestCase):
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
with context.eager_mode():
root = self._initialized_model()
- object_saver = checkpointable_utils.CheckpointableSaver(root)
- save_path = object_saver.save(file_prefix=checkpoint_prefix)
+ save_path = root.save(file_prefix=checkpoint_prefix)
with context.graph_mode():
save_graph = ops.Graph()
with save_graph.as_default(), self.test_session(
diff --git a/tensorflow/python/training/optimizer.py b/tensorflow/python/training/optimizer.py
index f126d3847b..66914bacf3 100644
--- a/tensorflow/python/training/optimizer.py
+++ b/tensorflow/python/training/optimizer.py
@@ -170,19 +170,6 @@ class _DenseResourceVariableProcessor(_OptimizableVariable):
return update_op
-class _StreamingModelPortProcessor(_OptimizableVariable):
- """Processor for streaming ModelPorts."""
-
- def __init__(self, v):
- self._v = v
-
- def target(self):
- return self._v
-
- def update_op(self, optimizer, g):
- return g
-
-
class _TensorProcessor(_OptimizableVariable):
"""Processor for ordinary Tensors.
@@ -216,8 +203,6 @@ def _get_processor(v):
return _DenseResourceVariableProcessor(v)
if isinstance(v, variables.Variable):
return _RefVariableProcessor(v)
- if v.op.type == "SubmodelPort":
- return _StreamingModelPortProcessor(v)
if isinstance(v, ops.Tensor):
return _TensorProcessor(v)
raise NotImplementedError("Trying to optimize unsupported type ", v)
diff --git a/tensorflow/python/training/queue_runner.py b/tensorflow/python/training/queue_runner.py
index 42559d1e62..92207d97cd 100644
--- a/tensorflow/python/training/queue_runner.py
+++ b/tensorflow/python/training/queue_runner.py
@@ -22,13 +22,3 @@ from __future__ import print_function
# pylint: disable=wildcard-import
from tensorflow.python.training.queue_runner_impl import *
# pylint: enable=wildcard-import
-from tensorflow.python.util.all_util import remove_undocumented
-
-
-_allowed_symbols = [
- # Documented in training.py:
- "QueueRunner",
- "add_queue_runner",
- "start_queue_runners",
-]
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/training/training.py b/tensorflow/python/training/training.py
index d7e5078be7..4ae7f84510 100644
--- a/tensorflow/python/training/training.py
+++ b/tensorflow/python/training/training.py
@@ -105,13 +105,6 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
-import sys as _sys
-
-from tensorflow.python.ops import io_ops as _io_ops
-from tensorflow.python.ops import sdca_ops as _sdca_ops
-from tensorflow.python.ops import state_ops as _state_ops
-from tensorflow.python.util.all_util import remove_undocumented
-
# pylint: disable=g-bad-import-order,unused-import
from tensorflow.python.ops.sdca_ops import sdca_optimizer
from tensorflow.python.ops.sdca_ops import sdca_fprint
@@ -215,39 +208,6 @@ from tensorflow.core.protobuf.tensorflow_server_pb2 import ServerDef
from tensorflow.python.training.server_lib import ClusterSpec
from tensorflow.python.training.server_lib import Server
-# Symbols whitelisted for export without documentation.
-_allowed_symbols = [
- # TODO(cwhipkey): review these and move to contrib or expose through
- # documentation.
- "generate_checkpoint_state_proto", # Used internally by saver.
- "checkpoint_exists", # Only used in test?
- "get_checkpoint_mtimes", # Only used in test?
-
- # Legacy: remove.
- "do_quantize_training_on_graphdef", # At least use grah_def, not graphdef.
- # No uses within tensorflow.
- "queue_runner", # Use tf.train.start_queue_runner etc directly.
- # This is also imported internally.
-
- # TODO(drpng): document these. The reference in howtos/distributed does
- # not link.
- "SyncReplicasOptimizer",
- # Protobufs:
- "BytesList", # from example_pb2.
- "ClusterDef",
- "Example", # from example_pb2
- "Feature", # from example_pb2
- "Features", # from example_pb2
- "FeatureList", # from example_pb2
- "FeatureLists", # from example_pb2
- "FloatList", # from example_pb2.
- "Int64List", # from example_pb2.
- "JobDef",
- "SaverDef", # From saver_pb2.
- "SequenceExample", # from example_pb2.
- "ServerDef",
-]
-
# pylint: disable=undefined-variable
tf_export("train.BytesList")(BytesList)
tf_export("train.ClusterDef")(ClusterDef)
@@ -263,9 +223,3 @@ tf_export("train.SaverDef")(SaverDef)
tf_export("train.SequenceExample")(SequenceExample)
tf_export("train.ServerDef")(ServerDef)
# pylint: enable=undefined-variable
-
-# Include extra modules for docstrings because:
-# * Input methods in tf.train are documented in io_ops.
-# * Saver methods in tf.train are documented in state_ops.
-remove_undocumented(__name__, _allowed_symbols,
- [_sys.modules[__name__], _io_ops, _sdca_ops, _state_ops])
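With the remove_undocumented whitelist gone, everything this module exposes under the public tf.train namespace goes through explicit tf_export registrations like the tf_export("train.BytesList")(BytesList) calls kept above. A tiny sketch of the mechanism, where "train.my_symbol" is a made-up name used only for illustration:

    from tensorflow.python.util.tf_export import tf_export

    @tf_export("train.my_symbol")   # registers the symbol as tf.train.my_symbol
    def my_symbol():
      return 42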
diff --git a/tensorflow/python/util/compat.py b/tensorflow/python/util/compat.py
index 3358ffe526..1aba7584d1 100644
--- a/tensorflow/python/util/compat.py
+++ b/tensorflow/python/util/compat.py
@@ -40,7 +40,6 @@ import numbers as _numbers
import numpy as _np
import six as _six
-from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export
@@ -142,13 +141,3 @@ tf_export('compat.complex_types').export_constant(__name__, 'complex_types')
bytes_or_text_types = (bytes, _six.text_type)
tf_export('compat.bytes_or_text_types').export_constant(__name__,
'bytes_or_text_types')
-
-_allowed_symbols = [
- 'as_str',
- 'bytes_or_text_types',
- 'complex_types',
- 'integral_types',
- 'real_types',
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index 5622431bc9..1104768ae8 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -36,7 +36,6 @@ import collections as _collections
import six as _six
from tensorflow.python import pywrap_tensorflow as _pywrap_tensorflow
-from tensorflow.python.util.all_util import remove_undocumented
def _sorted(dict_):
@@ -758,21 +757,3 @@ def flatten_with_joined_string_paths(structure, separator="/"):
_pywrap_tensorflow.RegisterSequenceClass(_collections.Sequence)
-
-
-_allowed_symbols = [
- "assert_same_structure",
- "is_sequence",
- "flatten",
- "flatten_dict_items",
- "pack_sequence_as",
- "map_structure",
- "assert_shallow_structure",
- "flatten_up_to",
- "map_structure_up_to",
- "get_traverse_shallow_structure",
- "yield_flat_paths",
- "flatten_with_joined_string_paths",
-]
-
-remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/stream_executor/BUILD b/tensorflow/stream_executor/BUILD
index 80fc9ff292..c68cda0100 100644
--- a/tensorflow/stream_executor/BUILD
+++ b/tensorflow/stream_executor/BUILD
@@ -35,6 +35,7 @@ cc_library(
deps = [
"//tensorflow/compiler/xla:statusor",
"//tensorflow/core:lib",
+ "//tensorflow/core:ptr_util",
"@local_config_cuda//cuda:cuda_headers",
],
alwayslink = 1,
@@ -46,6 +47,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core:lib",
+ "//tensorflow/core:ptr_util",
"//tensorflow/compiler/xla:statusor",
"@local_config_cuda//cuda:cuda_headers",
] + if_static([":stream_executor_impl"]),
diff --git a/tensorflow/stream_executor/blas.h b/tensorflow/stream_executor/blas.h
index 6e62b85728..be0b0bf5fb 100644
--- a/tensorflow/stream_executor/blas.h
+++ b/tensorflow/stream_executor/blas.h
@@ -41,9 +41,10 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_BLAS_H_
#include <complex>
-#include "tensorflow/stream_executor/platform/port.h"
+#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
+#include "tensorflow/stream_executor/platform/port.h"
namespace Eigen {
struct half;
@@ -1032,43 +1033,49 @@ class BlasSupport {
// creating a new Stream for each attempt.
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
- const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int32> *c,
+ uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
+ const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
+ int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int32> *c,
int ldc, ComputationType computation_type, AlgorithmType algorithm,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, const Eigen::half &alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
const DeviceMemory<Eigen::half> &a, int lda,
- const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
- DeviceMemory<Eigen::half> *c, int ldc, ComputationType computation_type,
- AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+ int ldc, ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
int ldc, ComputationType computation_type, AlgorithmType algorithm,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc, ComputationType computation_type,
- AlgorithmType algorithm, ProfileResult *output_profile_result) = 0;
+ uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+ int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+ int ldc, ComputationType computation_type, AlgorithmType algorithm,
+ ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<float> alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<float>> &beta,
+ DeviceMemory<std::complex<float>> *c, int ldc,
ComputationType computation_type, AlgorithmType algorithm,
ProfileResult *output_profile_result) = 0;
virtual bool DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<double> alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<double>> &beta,
+ DeviceMemory<std::complex<double>> *c, int ldc,
ComputationType computation_type, AlgorithmType algorithm,
ProfileResult *output_profile_result) = 0;
@@ -1886,49 +1893,57 @@ class BlasSupport {
override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, \
- int lda, const DeviceMemory<int8> &b, int ldb, int beta, \
- DeviceMemory<int> *c, int ldc, blas::ComputationType computation_type, \
+ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha, \
+ const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, \
+ int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, \
+ int ldc, blas::ComputationType computation_type, \
blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, const Eigen::half &alpha, \
+ uint64 m, uint64 n, uint64 k, \
+ const HostOrDeviceScalar<Eigen::half> &alpha, \
const DeviceMemory<Eigen::half> &a, int lda, \
- const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta, \
+ const DeviceMemory<Eigen::half> &b, int ldb, \
+ const HostOrDeviceScalar<Eigen::half> &beta, \
DeviceMemory<Eigen::half> *c, int ldc, \
blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, \
- int lda, const DeviceMemory<float> &b, int ldb, float beta, \
- DeviceMemory<float> *c, int ldc, blas::ComputationType computation_type, \
+ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha, \
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b, \
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c, \
+ int ldc, blas::ComputationType computation_type, \
blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, double alpha, \
+ uint64 m, uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha, \
const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b, \
- int ldb, double beta, DeviceMemory<double> *c, int ldc, \
+ int ldb, const HostOrDeviceScalar<double> &beta, \
+ DeviceMemory<double> *c, int ldc, \
blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, std::complex<float> alpha, \
+ uint64 m, uint64 n, uint64 k, \
+ const HostOrDeviceScalar<std::complex<float>> &alpha, \
const DeviceMemory<std::complex<float>> &a, int lda, \
const DeviceMemory<std::complex<float>> &b, int ldb, \
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc, \
+ const HostOrDeviceScalar<std::complex<float>> &beta, \
+ DeviceMemory<std::complex<float>> *c, int ldc, \
blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmWithAlgorithm( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
- uint64 m, uint64 n, uint64 k, std::complex<double> alpha, \
+ uint64 m, uint64 n, uint64 k, \
+ const HostOrDeviceScalar<std::complex<double>> &alpha, \
const DeviceMemory<std::complex<double>> &a, int lda, \
const DeviceMemory<std::complex<double>> &b, int ldb, \
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, \
- int ldc, blas::ComputationType computation_type, \
- blas::AlgorithmType algorithm, \
+ const HostOrDeviceScalar<std::complex<double>> &beta, \
+ DeviceMemory<std::complex<double>> *c, int ldc, \
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm, \
blas::ProfileResult *output_profile_result) override; \
bool DoBlasGemmBatched( \
Stream *stream, blas::Transpose transa, blas::Transpose transb, \
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.cc b/tensorflow/stream_executor/cuda/cuda_blas.cc
index 007c0f1c86..3c1353aee3 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.cc
+++ b/tensorflow/stream_executor/cuda/cuda_blas.cc
@@ -2156,10 +2156,11 @@ static bool TensorOpsAvailable(int cc_major) {
template <typename InT, typename OutT, typename CompT>
bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a, int lda,
- const DeviceMemory<InT> &b, int ldb, const CompT &beta,
- DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
+ const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
+ const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
// CUDA < version 8 and GPUs < sm_50 don't support cublasGemmEx.
#if CUDA_VERSION < 8000
return false;
@@ -2175,6 +2176,12 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
return false;
}
+  // Either both 'alpha' and 'beta' need to be pointers to device memory, or
+  // both need to be host scalars.
+ if (alpha.is_pointer() != beta.is_pointer()) {
+ return false;
+ }
+
std::unique_ptr<CUDATimer, TimerDeleter> timer;
if (output_profile_result != nullptr) {
timer.reset(new CUDATimer(parent_));
@@ -2187,10 +2194,15 @@ bool CUDABlas::DoBlasGemmWithAlgorithmImpl(
// Since we are converting 'algorithm' to cublasGemmAlgo_t by static_cast,
// we do the following compile-time check on the default value:
static_assert(blas::kDefaultGemmAlgo == CUBLAS_GEMM_DFALT, "");
+ // If 'alpha' and 'beta' are host scalars and CompT is Eigen::half, we
+  // essentially reinterpret_cast to __half, which is safe because Eigen::half
+ // inherits from __half.
bool result = DoBlasInternalFailureOK(
- wrap::cublasGemmEx, stream, /* pointer_mode_host = */ true,
- CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k, &alpha,
- CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb, &beta,
+ wrap::cublasGemmEx, stream, /* pointer_mode_host = */ !alpha.is_pointer(),
+ CUDABlasTranspose(transa), CUDABlasTranspose(transb), m, n, k,
+ alpha.is_pointer() ? CUDAMemory(alpha.pointer()) : &alpha.value(),
+ CUDAMemory(a), cuda_in_type, lda, CUDAMemory(b), cuda_in_type, ldb,
+ beta.is_pointer() ? CUDAMemory(beta.pointer()) : &beta.value(),
CUDAMemoryMutable(c), CUDADataType<OutT>::type, ldc,
CUDAComputationType(computation_type),
static_cast<cublasGemmAlgo_t>(algorithm));
@@ -2239,10 +2251,11 @@ bool CUDABlas::GetBlasGemmAlgorithms(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
- const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
- int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ uint64 n, uint64 k, const HostOrDeviceScalar<int> &alpha,
+ const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b, int ldb,
+ const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithAlgorithmImpl(
stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
computation_type, algorithm, output_profile_result);
@@ -2250,17 +2263,25 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, const Eigen::half &alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
const DeviceMemory<Eigen::half> &a, int lda,
- const DeviceMemory<Eigen::half> &b, int ldb, const Eigen::half &beta,
- DeviceMemory<Eigen::half> *c, int ldc,
- blas::ComputationType computation_type, blas::AlgorithmType algorithm,
- blas::ProfileResult *output_profile_result) {
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
if (computation_type == blas::ComputationType::kF32) {
+ if (alpha.is_pointer() || beta.is_pointer()) {
+ // We cannot easily convert a pointer to f16 memory to a pointer to f32
+ // memory from here, so we don't support this for now.
+ // TODO(akuegel): Investigate whether we can do the conversion before
+ // calling DoBlasGemmWithAlgorithm.
+ return false;
+ }
+ HostOrDeviceScalar<float> float_alpha(static_cast<float>(alpha.value()));
+ HostOrDeviceScalar<float> float_beta(static_cast<float>(beta.value()));
return DoBlasGemmWithAlgorithmImpl(
- stream, transa, transb, m, n, k, static_cast<float>(alpha), a, lda, b,
- ldb, static_cast<float>(beta), c, ldc, computation_type, algorithm,
- output_profile_result);
+ stream, transa, transb, m, n, k, float_alpha, a, lda, b, ldb,
+ float_beta, c, ldc, computation_type, algorithm, output_profile_result);
}
CHECK_EQ(computation_type, blas::ComputationType::kF16);
@@ -2271,8 +2292,9 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ uint64 n, uint64 k, const HostOrDeviceScalar<float> &alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithAlgorithmImpl(
@@ -2282,9 +2304,10 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ uint64 n, uint64 k, const HostOrDeviceScalar<double> &alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+ int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+ int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithAlgorithmImpl(
stream, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc,
@@ -2293,10 +2316,11 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<float> alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<float>> &beta,
+ DeviceMemory<std::complex<float>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithAlgorithmImpl(
@@ -2306,10 +2330,11 @@ bool CUDABlas::DoBlasGemmWithAlgorithm(
bool CUDABlas::DoBlasGemmWithAlgorithm(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, std::complex<double> alpha,
+ uint64 n, uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<double>> &beta,
+ DeviceMemory<std::complex<double>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
return DoBlasGemmWithAlgorithmImpl(
diff --git a/tensorflow/stream_executor/cuda/cuda_blas.h b/tensorflow/stream_executor/cuda/cuda_blas.h
index 55c414a1f9..12dc5e47fd 100644
--- a/tensorflow/stream_executor/cuda/cuda_blas.h
+++ b/tensorflow/stream_executor/cuda/cuda_blas.h
@@ -21,6 +21,7 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_CUDA_CUDA_BLAS_H_
#include "tensorflow/stream_executor/blas.h"
+#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/stringpiece.h"
#include "tensorflow/stream_executor/platform/mutex.h"
#include "tensorflow/stream_executor/platform/port.h"
@@ -116,18 +117,13 @@ class CUDABlas : public blas::BlasSupport {
int batch_count, ScratchAllocator *scratch_allocator);
// Helper function for implementing DoBlasGemmWithAlgorithm.
- //
- // We take alpha and beta by const reference because T might be Eigen::half,
- // and we want to avoid pulling in a dependency on Eigen. When we pass the
- // references to cublas, we essentially reinterpret_cast to __half, which is
- // safe because Eigen::half inherits from __half.
template <typename InT, typename OutT, typename CompT>
bool DoBlasGemmWithAlgorithmImpl(
Stream *stream, blas::Transpose transa, blas::Transpose transb, uint64 m,
- uint64 n, uint64 k, const CompT &alpha, const DeviceMemory<InT> &a,
- int lda, const DeviceMemory<InT> &b, int ldb, const CompT &beta,
- DeviceMemory<OutT> *c, int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm,
+ uint64 n, uint64 k, const HostOrDeviceScalar<CompT> &alpha,
+ const DeviceMemory<InT> &a, int lda, const DeviceMemory<InT> &b, int ldb,
+ const HostOrDeviceScalar<CompT> &beta, DeviceMemory<OutT> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
// Helper function for implementing DoBlasGemmWithProfiling.
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.cc b/tensorflow/stream_executor/cuda/cuda_dnn.cc
index 102419a264..42a77aa3f8 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.cc
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/stream_executor/cuda/cuda_activation.h"
#include "tensorflow/stream_executor/cuda/cuda_diagnostics.h"
@@ -312,7 +313,10 @@ CUDNN_DNN_ROUTINE_EACH_R5_WITH_STREAM(
// clang-format off
#if CUDNN_VERSION >= 6000
#define CUDNN_DNN_ROUTINE_EACH_R6(__macro) \
- __macro(cudnnSetRNNDescriptor_v6)
+ __macro(cudnnSetRNNDescriptor_v6) \
+ __macro(cudnnCreatePersistentRNNPlan) \
+ __macro(cudnnDestroyPersistentRNNPlan) \
+ __macro(cudnnSetPersistentRNNPlan)
// clang-format on
CUDNN_DNN_ROUTINE_EACH_R6(STREAM_EXECUTOR_CUDNN_WRAP)
@@ -1195,7 +1199,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
public:
CudnnRnnDescriptor(CUDAExecutor* parent, cudnnHandle_t cudnn_handle,
int num_layers, int hidden_size, int input_size,
- cudnnRNNInputMode_t input_mode,
+ int batch_size, cudnnRNNInputMode_t input_mode,
cudnnDirectionMode_t direction_mode,
cudnnRNNMode_t rnn_mode, cudnnDataType_t data_type,
cudnnDataType_t compute_type,
@@ -1207,6 +1211,10 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
num_layers_(num_layers),
hidden_size_(hidden_size),
input_size_(input_size),
+ batch_size_(batch_size),
+#if CUDNN_VERSION >= 6000
+ rnn_plan_(nullptr),
+#endif
input_mode_(input_mode),
direction_mode_(direction_mode),
rnn_mode_(rnn_mode),
@@ -1226,12 +1234,26 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
CUDNN_RETURN_IF_FAIL(status, "Unable to create RNN descriptor");
#if CUDNN_VERSION >= 6000
// TODO: allow the user to choose an algorithm.
- cudnnRNNAlgo_t rnn_algo = ToCudnnRNNAlgo(algorithm_config_.algorithm());
+ rnn_algo_ = ToCudnnRNNAlgo(algorithm_config_.algorithm());
status = wrap::cudnnSetRNNDescriptor_v6(
- parent, cudnn_handle, rnn_desc_ /*rnnDesc*/, hidden_size /*hiddenSize*/,
- num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
- input_mode /*inputMode*/, direction_mode /*direction*/,
- rnn_mode /*mode*/, rnn_algo /*algo*/, compute_type /*dataType*/);
+ parent, cudnn_handle, /*rnnDesc=*/rnn_desc_, /*hiddenSize=*/hidden_size,
+ /*numLayers=*/num_layers, /*dropoutDesc=*/dropout_handle(),
+ /*inputMode=*/input_mode, /*direction=*/direction_mode,
+ /*mode=*/rnn_mode, /*algo=*/rnn_algo_, /*dataType=*/compute_type);
+ CUDNN_RETURN_IF_FAIL(status, ::tensorflow::strings::Printf(
+ "Unable to update RNN descriptor with "
+ "algo_id: %d and compute_type: %d",
+ static_cast<int>(rnn_algo_),
+ static_cast<int>(compute_type)));
+
+ if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC) {
+ CHECK_GE(batch_size_, 0);
+ status = wrap::cudnnCreatePersistentRNNPlan(
+ parent, rnn_desc_, batch_size_, data_type_, &rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to create persistent RNN plan.");
+ status = wrap::cudnnSetPersistentRNNPlan(parent, rnn_desc_, rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to update persistent RNN plan.");
+ }
#else
CHECK(algorithm_config_.is_default())
<< "Non-default algorithm not supported for CUDA version < 6.0";
@@ -1240,8 +1262,8 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
num_layers /*numLayers*/, dropout_handle() /*dropoutDesc*/,
input_mode /*inputMode*/, direction_mode /*direction*/,
rnn_mode /*mode*/, compute_type /*dataType*/);
-#endif
CUDNN_RETURN_IF_FAIL(status, "Unable to update RNN descriptor");
+#endif
// Create the params handle.
cudnn_params_desc_.reset(
@@ -1254,8 +1276,14 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
}
~CudnnRnnDescriptor() override {
if (rnn_desc_) {
- cudnnStatus_t status =
- wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
+ cudnnStatus_t status;
+#if CUDNN_VERSION >= 6000
+ if (rnn_algo_ == CUDNN_RNN_ALGO_PERSIST_DYNAMIC && rnn_plan_) {
+ status = wrap::cudnnDestroyPersistentRNNPlan(parent_, rnn_plan_);
+ CUDNN_RETURN_IF_FAIL(status, "Unable to destroy persistent RNN plan.");
+ }
+#endif
+ status = wrap::cudnnDestroyRNNDescriptor(parent_, rnn_desc_);
CUDNN_RETURN_IF_FAIL(status, "Unable to destroy RNN descriptor");
}
}
@@ -1280,6 +1308,7 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
int num_layers() const { return num_layers_; }
int hidden_size() const { return hidden_size_; }
int input_size() const { return input_size_; }
+ int batch_size() const { return batch_size_; }
cudnnRNNInputMode_t input_mode() const { return input_mode_; }
cudnnDirectionMode_t direction_mode() const { return direction_mode_; }
cudnnRNNMode_t rnn_mode() const { return rnn_mode_; }
@@ -1314,6 +1343,13 @@ class CudnnRnnDescriptor : public CudnnDescriptorCommon<dnn::RnnDescriptor> {
int num_layers_;
int hidden_size_;
int input_size_;
+  // batch_size_ is set to -1 when the CUDNN_RNN_ALGO_PERSIST_DYNAMIC
+  // algorithm is not used.
+ int batch_size_;
+#if CUDNN_VERSION >= 6000
+ cudnnRNNAlgo_t rnn_algo_;
+ cudnnPersistentRNNPlan_t rnn_plan_;
+#endif
cudnnRNNInputMode_t input_mode_;
cudnnDirectionMode_t direction_mode_;
cudnnRNNMode_t rnn_mode_;
@@ -1970,22 +2006,20 @@ bool CudnnSupport::DoRnnBackwardImpl(
#endif // CUDNN_VERSION
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
-CudnnSupport::createRnnDescriptor(int num_layers, int hidden_size,
- int input_size, dnn::RnnInputMode input_mode,
- dnn::RnnDirectionMode direction_mode,
- dnn::RnnMode rnn_mode,
- dnn::DataType data_type,
- const dnn::AlgorithmConfig& algorithm_config,
- float dropout, uint64 seed,
- ScratchAllocator* state_allocator) {
+CudnnSupport::createRnnDescriptor(
+ int num_layers, int hidden_size, int input_size, int batch_size,
+ dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
+ dnn::RnnMode rnn_mode, dnn::DataType data_type,
+ const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
+ ScratchAllocator* state_allocator) {
#if CUDNN_VERSION >= 5000
mutex_lock lock{dnn_handle_mutex_};
std::unique_ptr<CudnnRnnDescriptor> rnn_desc(new CudnnRnnDescriptor(
parent_, ToHandle(dnn_handle_), num_layers, hidden_size, input_size,
- ToCudnnRnnInputMode(input_mode), ToCudnnRnnDirectionMode(direction_mode),
- ToCudnnRnnMode(rnn_mode), ToCudnnDataType(data_type),
- GetRnnComputeType(data_type), algorithm_config, dropout, seed,
- state_allocator));
+ batch_size, ToCudnnRnnInputMode(input_mode),
+ ToCudnnRnnDirectionMode(direction_mode), ToCudnnRnnMode(rnn_mode),
+ ToCudnnDataType(data_type), GetRnnComputeType(data_type),
+ algorithm_config, dropout, seed, state_allocator));
if (!rnn_desc->ok()) {
return rnn_desc->Status();
}
diff --git a/tensorflow/stream_executor/cuda/cuda_dnn.h b/tensorflow/stream_executor/cuda/cuda_dnn.h
index 5ded7cf154..7d53dbe4a5 100644
--- a/tensorflow/stream_executor/cuda/cuda_dnn.h
+++ b/tensorflow/stream_executor/cuda/cuda_dnn.h
@@ -48,7 +48,7 @@ class CudnnSupport : public dnn::DnnSupport {
port::StatusOr<perftools::gputools::dnn::VersionInfo> GetVersion() override;
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig& algorithm_config, float dropout, uint64 seed,
diff --git a/tensorflow/stream_executor/cuda/cuda_platform.cc b/tensorflow/stream_executor/cuda/cuda_platform.cc
index 7a6ef5a248..649224a20e 100644
--- a/tensorflow/stream_executor/cuda/cuda_platform.cc
+++ b/tensorflow/stream_executor/cuda/cuda_platform.cc
@@ -168,8 +168,8 @@ port::StatusOr<StreamExecutor*> CudaPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
CudaPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
- auto executor = port::MakeUnique<StreamExecutor>(
- this, port::MakeUnique<CUDAExecutor>(config.plugin_config));
+ auto executor = MakeUnique<StreamExecutor>(
+ this, MakeUnique<CUDAExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/stream_executor/dnn.cc b/tensorflow/stream_executor/dnn.cc
index 6edb572820..031c82d3f4 100644
--- a/tensorflow/stream_executor/dnn.cc
+++ b/tensorflow/stream_executor/dnn.cc
@@ -15,12 +15,17 @@ limitations under the License.
#include "tensorflow/stream_executor/dnn.h"
+#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/lib/stringprintf.h"
namespace stream_executor {
namespace dnn {
+uint64 AlgorithmDesc::hash() const {
+ return ::tensorflow::Hash64Combine(algo_, tensor_ops_enabled_);
+}
+
bool DnnSupport::GetConvolveAlgorithms(
bool with_winograd_nonfused, int cc_major, int cc_minor,
std::vector<AlgorithmDesc>* out_algorithms) {
diff --git a/tensorflow/stream_executor/dnn.h b/tensorflow/stream_executor/dnn.h
index 39f21d8b10..0c2e083b39 100644
--- a/tensorflow/stream_executor/dnn.h
+++ b/tensorflow/stream_executor/dnn.h
@@ -712,6 +712,7 @@ class AlgorithmDesc {
return this->algo_ == other.algo_ &&
this->tensor_ops_enabled_ == other.tensor_ops_enabled_;
}
+ uint64 hash() const;
private:
enum { kDefaultAlgorithm = -1 };
@@ -2023,7 +2024,7 @@ class DnnSupport {
// is no longer in use.
virtual port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
createRnnDescriptor(int num_layers, int hidden_size, int input_size,
- dnn::RnnInputMode input_mode,
+ int batch_size, dnn::RnnInputMode input_mode,
dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig& algorithm_config,
diff --git a/tensorflow/stream_executor/host/host_platform.cc b/tensorflow/stream_executor/host/host_platform.cc
index 00a17a05ed..a652b08b4f 100644
--- a/tensorflow/stream_executor/host/host_platform.cc
+++ b/tensorflow/stream_executor/host/host_platform.cc
@@ -66,8 +66,8 @@ port::StatusOr<StreamExecutor*> HostPlatform::GetExecutor(
port::StatusOr<std::unique_ptr<StreamExecutor>>
HostPlatform::GetUncachedExecutor(const StreamExecutorConfig& config) {
- auto executor = port::MakeUnique<StreamExecutor>(
- this, port::MakeUnique<HostExecutor>(config.plugin_config));
+ auto executor = MakeUnique<StreamExecutor>(
+ this, MakeUnique<HostExecutor>(config.plugin_config));
auto init_status = executor->Init(config.ordinal, config.device_options);
if (!init_status.ok()) {
return port::Status{
diff --git a/tensorflow/stream_executor/host_or_device_scalar.h b/tensorflow/stream_executor/host_or_device_scalar.h
new file mode 100644
index 0000000000..c9e3e14778
--- /dev/null
+++ b/tensorflow/stream_executor/host_or_device_scalar.h
@@ -0,0 +1,56 @@
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
+#define TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
+
+#include "tensorflow/core/platform/logging.h"
+#include "tensorflow/stream_executor/device_memory.h"
+
+namespace stream_executor {
+
+// Represents a value that is either a host scalar or a scalar stored on the
+// GPU device.
+template <typename ElemT>
+class HostOrDeviceScalar {
+ public:
+ // Not marked as explicit because when using this constructor, we usually want
+ // to set this to a compile-time constant.
+ HostOrDeviceScalar(ElemT value) : value_(value), is_pointer_(false) {}
+ explicit HostOrDeviceScalar(const DeviceMemory<ElemT>& pointer)
+ : pointer_(pointer), is_pointer_(true) {
+ CHECK_EQ(1, pointer.ElementCount());
+ }
+
+ bool is_pointer() const { return is_pointer_; }
+ const DeviceMemory<ElemT>& pointer() const {
+ CHECK(is_pointer());
+ return pointer_;
+ }
+ const ElemT& value() const {
+ CHECK(!is_pointer());
+ return value_;
+ }
+
+ private:
+ union {
+ ElemT value_;
+ DeviceMemory<ElemT> pointer_;
+ };
+ bool is_pointer_;
+};
+
+} // namespace stream_executor
+#endif // TENSORFLOW_STREAM_EXECUTOR_HOST_OR_DEVICE_SCALAR_H_
diff --git a/tensorflow/stream_executor/lib/ptr_util.h b/tensorflow/stream_executor/lib/ptr_util.h
index 3f89794688..8f9f420fec 100644
--- a/tensorflow/stream_executor/lib/ptr_util.h
+++ b/tensorflow/stream_executor/lib/ptr_util.h
@@ -17,47 +17,11 @@ limitations under the License.
#define TENSORFLOW_STREAM_EXECUTOR_LIB_PTR_UTIL_H_
#include <memory>
+#include "tensorflow/core/util/ptr_util.h"
namespace stream_executor {
-namespace port {
-
-// Trait to select overloads and return types for MakeUnique.
-template <typename T>
-struct MakeUniqueResult {
- using scalar = std::unique_ptr<T>;
-};
-template <typename T>
-struct MakeUniqueResult<T[]> {
- using array = std::unique_ptr<T[]>;
-};
-template <typename T, size_t N>
-struct MakeUniqueResult<T[N]> {
- using invalid = void;
-};
-
-// MakeUnique<T>(...) is an early implementation of C++14 std::make_unique.
-// It is designed to be 100% compatible with std::make_unique so that the
-// eventual switchover will be a simple renaming operation.
-template <typename T, typename... Args>
-typename MakeUniqueResult<T>::scalar MakeUnique(Args&&... args) { // NOLINT
- return std::unique_ptr<T>(
- new T(std::forward<Args>(args)...)); // NOLINT(build/c++11)
-}
-
-// Overload for array of unknown bound.
-// The allocation of arrays needs to use the array form of new,
-// and cannot take element constructor arguments.
-template <typename T>
-typename MakeUniqueResult<T>::array MakeUnique(size_t n) {
- return std::unique_ptr<T>(new typename std::remove_extent<T>::type[n]());
-}
-
-// Reject arrays of known bound.
-template <typename T, typename... Args>
-typename MakeUniqueResult<T>::invalid MakeUnique(Args&&... /* args */) =
- delete; // NOLINT
-
-} // namespace port
+using tensorflow::MakeUnique;
+using tensorflow::WrapUnique;
} // namespace stream_executor
namespace perftools {
diff --git a/tensorflow/stream_executor/multi_platform_manager.h b/tensorflow/stream_executor/multi_platform_manager.h
index 672855d5fb..7e316879ca 100644
--- a/tensorflow/stream_executor/multi_platform_manager.h
+++ b/tensorflow/stream_executor/multi_platform_manager.h
@@ -29,7 +29,7 @@ limitations under the License.
// interface. Sample API usage:
//
// port::StatusOr<Platform*> platform_status =
-// gpu::MultiPlatformManager::PlatformWithName("OpenCL");
+// se::MultiPlatformManager::PlatformWithName("OpenCL");
// if (!platform_status.ok()) { ... }
// Platform* platform = platform_status.ValueOrDie();
// LOG(INFO) << platform->VisibleDeviceCount() << " devices visible";
diff --git a/tensorflow/stream_executor/stream.cc b/tensorflow/stream_executor/stream.cc
index f59d9a13ac..093f0c9306 100644
--- a/tensorflow/stream_executor/stream.cc
+++ b/tensorflow/stream_executor/stream.cc
@@ -20,6 +20,7 @@ limitations under the License.
#include "third_party/eigen3/Eigen/Core"
#include "tensorflow/stream_executor/blas.h"
#include "tensorflow/stream_executor/host_buffer.h"
+#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/lib/stacktrace.h"
#include "tensorflow/stream_executor/lib/strcat.h"
#include "tensorflow/stream_executor/platform.h"
@@ -133,6 +134,14 @@ string ToVlogString(float f) { return port::StrCat(f); }
string ToVlogString(double d) { return port::StrCat(d); }
+template <typename T>
+string ToVlogString(const HostOrDeviceScalar<T> &memory_or_constant) {
+ if (memory_or_constant.is_pointer()) {
+ return ToVlogString(memory_or_constant.pointer());
+ }
+ return ToVlogString(memory_or_constant.value());
+}
+
template <class T>
string ToVlogString(port::ArraySlice<T> elements) {
string str = port::StrCat(
@@ -3882,22 +3891,23 @@ Stream &Stream::ThenBlasGemmWithProfiling(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
- int lda, const DeviceMemory<Eigen::half> &b, int ldb,
- const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
- blas::ComputationType computation_type, blas::AlgorithmType algorithm,
- blas::ProfileResult *output_profile_result) {
+ uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
- uint64, const Eigen::half &,
- const DeviceMemory<Eigen::half> &, int,
- const DeviceMemory<Eigen::half> &, int,
- const Eigen::half &, DeviceMemory<Eigen::half> *, int,
- blas::ComputationType, blas::AlgorithmType>
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<Eigen::half> &,
+ const DeviceMemory<Eigen::half> &, int, const DeviceMemory<Eigen::half> &,
+ int, const HostOrDeviceScalar<Eigen::half> &, DeviceMemory<Eigen::half> *,
+ int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
@@ -3906,18 +3916,20 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, int alpha, const DeviceMemory<int8> &a, int lda,
- const DeviceMemory<int8> &b, int ldb, int beta, DeviceMemory<int> *c,
- int ldc, blas::ComputationType computation_type,
- blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
+ uint64 k, const HostOrDeviceScalar<int> &alpha, const DeviceMemory<int8> &a,
+ int lda, const DeviceMemory<int8> &b, int ldb,
+ const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c, int ldc,
+ blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64, int,
- const DeviceMemory<int8> &, int, const DeviceMemory<int8> &, int, int,
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<int> &, const DeviceMemory<int8> &, int,
+ const DeviceMemory<int8> &, int, const HostOrDeviceScalar<int> &,
DeviceMemory<int> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
@@ -3927,8 +3939,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, float alpha, const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb, float beta, DeviceMemory<float> *c,
+ uint64 k, const HostOrDeviceScalar<float> &alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -3937,8 +3950,9 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(algorithm));
ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64, float,
- const DeviceMemory<float> &, int, const DeviceMemory<float> &, int, float,
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<float> &, const DeviceMemory<float> &, int,
+ const DeviceMemory<float> &, int, const HostOrDeviceScalar<float> &,
DeviceMemory<float> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
@@ -3948,32 +3962,35 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ uint64 k, const HostOrDeviceScalar<double> &alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+ int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+ int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm, blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
- uint64, double, const DeviceMemory<double> &, int,
- const DeviceMemory<double> &, int, double,
- DeviceMemory<double> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<
+ blas::Transpose, blas::Transpose, uint64, uint64, uint64,
+ const HostOrDeviceScalar<double> &, const DeviceMemory<double> &, int,
+ const DeviceMemory<double> &, int, const HostOrDeviceScalar<double> &,
+ DeviceMemory<double> *, int, blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
- m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
+ m, n, k, HostOrDeviceScalar<double>(alpha), a, lda, b, ldb,
+ HostOrDeviceScalar<double>(beta), c, ldc, computation_type,
algorithm, output_profile_result);
}
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<float> alpha,
+ uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<float>> &beta,
+ DeviceMemory<std::complex<float>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -3981,12 +3998,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64,
- std::complex<float>, const DeviceMemory<std::complex<float>> &, int,
- const DeviceMemory<std::complex<float>> &, int, std::complex<float>,
- DeviceMemory<std::complex<float>> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64,
+ const HostOrDeviceScalar<std::complex<float>> &,
+ const DeviceMemory<std::complex<float>> &, int,
+ const DeviceMemory<std::complex<float>> &, int,
+ const HostOrDeviceScalar<std::complex<float>> &,
+ DeviceMemory<std::complex<float>> *, int,
+ blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
@@ -3995,10 +4014,11 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
Stream &Stream::ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<double> alpha,
+ uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<double>> &beta,
+ DeviceMemory<std::complex<double>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
@@ -4006,12 +4026,14 @@ Stream &Stream::ThenBlasGemmWithAlgorithm(
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(computation_type),
PARAM(algorithm));
- ThenBlasWithProfileImpl<
- blas::Transpose, blas::Transpose, uint64, uint64, uint64,
- std::complex<double>, const DeviceMemory<std::complex<double>> &, int,
- const DeviceMemory<std::complex<double>> &, int, std::complex<double>,
- DeviceMemory<std::complex<double>> *, int, blas::ComputationType,
- blas::AlgorithmType>
+ ThenBlasWithProfileImpl<blas::Transpose, blas::Transpose, uint64, uint64,
+ uint64,
+ const HostOrDeviceScalar<std::complex<double>> &,
+ const DeviceMemory<std::complex<double>> &, int,
+ const DeviceMemory<std::complex<double>> &, int,
+ const HostOrDeviceScalar<std::complex<double>> &,
+ DeviceMemory<std::complex<double>> *, int,
+ blas::ComputationType, blas::AlgorithmType>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmWithAlgorithm, transa, transb,
m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, computation_type,
diff --git a/tensorflow/stream_executor/stream.h b/tensorflow/stream_executor/stream.h
index d4a81440e9..3d1b011c57 100644
--- a/tensorflow/stream_executor/stream.h
+++ b/tensorflow/stream_executor/stream.h
@@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/stream_executor/dnn.h"
#include "tensorflow/stream_executor/event.h"
#include "tensorflow/stream_executor/fft.h"
+#include "tensorflow/stream_executor/host_or_device_scalar.h"
#include "tensorflow/stream_executor/kernel.h"
#include "tensorflow/stream_executor/launch_dim.h"
#include "tensorflow/stream_executor/lib/array_slice.h"
@@ -1422,50 +1423,53 @@ class Stream {
// See BlasSupport::DoBlasGemmWithAlgorithm.
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, const Eigen::half &alpha, const DeviceMemory<Eigen::half> &a,
- int lda, const DeviceMemory<Eigen::half> &b, int ldb,
- const Eigen::half &beta, DeviceMemory<Eigen::half> *c, int ldc,
- blas::ComputationType computation_type, blas::AlgorithmType algorithm,
+ uint64 k, const HostOrDeviceScalar<Eigen::half> &alpha,
+ const DeviceMemory<Eigen::half> &a, int lda,
+ const DeviceMemory<Eigen::half> &b, int ldb,
+ const HostOrDeviceScalar<Eigen::half> &beta, DeviceMemory<Eigen::half> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
- Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, int alpha,
- const DeviceMemory<int8> &a, int lda,
- const DeviceMemory<int8> &b, int ldb,
- int beta, DeviceMemory<int> *c, int ldc,
- blas::ComputationType computation_type,
- blas::AlgorithmType algorithm,
- blas::ProfileResult *output_profile_result);
- Stream &ThenBlasGemmWithAlgorithm(blas::Transpose transa,
- blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, float alpha,
- const DeviceMemory<float> &a, int lda,
- const DeviceMemory<float> &b, int ldb,
- float beta, DeviceMemory<float> *c, int ldc,
- blas::ComputationType computation_type,
- blas::AlgorithmType algorithm,
- blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, double alpha, const DeviceMemory<double> &a, int lda,
- const DeviceMemory<double> &b, int ldb, double beta,
- DeviceMemory<double> *c, int ldc, blas::ComputationType computation_type,
+ uint64 k, const HostOrDeviceScalar<int> &alpha,
+ const DeviceMemory<int8> &a, int lda, const DeviceMemory<int8> &b,
+ int ldb, const HostOrDeviceScalar<int> &beta, DeviceMemory<int> *c,
+ int ldc, blas::ComputationType computation_type,
blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<float> alpha,
+ uint64 k, const HostOrDeviceScalar<float> &alpha,
+ const DeviceMemory<float> &a, int lda, const DeviceMemory<float> &b,
+ int ldb, const HostOrDeviceScalar<float> &beta, DeviceMemory<float> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const HostOrDeviceScalar<double> &alpha,
+ const DeviceMemory<double> &a, int lda, const DeviceMemory<double> &b,
+ int ldb, const HostOrDeviceScalar<double> &beta, DeviceMemory<double> *c,
+ int ldc, blas::ComputationType computation_type,
+ blas::AlgorithmType algorithm,
+ blas::ProfileResult *output_profile_result);
+ Stream &ThenBlasGemmWithAlgorithm(
+ blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
+ uint64 k, const HostOrDeviceScalar<std::complex<float>> &alpha,
const DeviceMemory<std::complex<float>> &a, int lda,
const DeviceMemory<std::complex<float>> &b, int ldb,
- std::complex<float> beta, DeviceMemory<std::complex<float>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<float>> &beta,
+ DeviceMemory<std::complex<float>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
Stream &ThenBlasGemmWithAlgorithm(
blas::Transpose transa, blas::Transpose transb, uint64 m, uint64 n,
- uint64 k, std::complex<double> alpha,
+ uint64 k, const HostOrDeviceScalar<std::complex<double>> &alpha,
const DeviceMemory<std::complex<double>> &a, int lda,
const DeviceMemory<std::complex<double>> &b, int ldb,
- std::complex<double> beta, DeviceMemory<std::complex<double>> *c, int ldc,
+ const HostOrDeviceScalar<std::complex<double>> &beta,
+ DeviceMemory<std::complex<double>> *c, int ldc,
blas::ComputationType computation_type, blas::AlgorithmType algorithm,
blas::ProfileResult *output_profile_result);
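The ThenBlasGemmWithAlgorithm overloads above now take alpha and beta as const HostOrDeviceScalar<T>& instead of raw scalars. Because the wrapper's value constructor is not explicit (see the header earlier in this diff), call sites that pass plain constants should continue to compile through implicit conversion; a toy sketch under that assumption (Scalar and Gemm are illustrative stand-ins, not the real Stream API):

#include <iostream>

template <typename T>
class Scalar {  // toy stand-in for HostOrDeviceScalar<T>
 public:
  Scalar(T v) : v_(v) {}          // non-explicit: literals convert implicitly
  T value() const { return v_; }

 private:
  T v_;
};

// New-style signature: alpha/beta taken as wrappers instead of raw floats.
void Gemm(const Scalar<float>& alpha, const Scalar<float>& beta) {
  std::cout << "alpha=" << alpha.value() << " beta=" << beta.value() << "\n";
}

int main() {
  Gemm(1.0f, 0.0f);  // old-style call site still compiles via implicit conversion
  return 0;
}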
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.cc b/tensorflow/stream_executor/stream_executor_pimpl.cc
index 2e1adeb31e..20579790ef 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.cc
+++ b/tensorflow/stream_executor/stream_executor_pimpl.cc
@@ -350,7 +350,7 @@ bool StreamExecutor::GetBlasGemmAlgorithms(
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>>
StreamExecutor::createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
@@ -361,8 +361,9 @@ StreamExecutor::createRnnDescriptor(
"Fail to find the dnn implementation.");
}
return dnn_support->createRnnDescriptor(
- num_layers, hidden_size, input_size, input_mode, direction_mode, rnn_mode,
- data_type, algorithm_config, dropout, seed, state_allocator);
+ num_layers, hidden_size, input_size, batch_size, input_mode,
+ direction_mode, rnn_mode, data_type, algorithm_config, dropout, seed,
+ state_allocator);
}
port::StatusOr<std::unique_ptr<dnn::RnnSequenceTensorDescriptor>>
diff --git a/tensorflow/stream_executor/stream_executor_pimpl.h b/tensorflow/stream_executor/stream_executor_pimpl.h
index 39af7115d8..ab6b00f660 100644
--- a/tensorflow/stream_executor/stream_executor_pimpl.h
+++ b/tensorflow/stream_executor/stream_executor_pimpl.h
@@ -373,7 +373,7 @@ class StreamExecutor {
// Create an RNN descriptor based on model shapes and configurations.
// The caller retains the ownership of the descriptor.
port::StatusOr<std::unique_ptr<dnn::RnnDescriptor>> createRnnDescriptor(
- int num_layers, int hidden_size, int input_size,
+ int num_layers, int hidden_size, int input_size, int batch_size,
dnn::RnnInputMode input_mode, dnn::RnnDirectionMode direction_mode,
dnn::RnnMode rnn_mode, dnn::DataType data_type,
const dnn::AlgorithmConfig &algorithm_config, float dropout, uint64 seed,
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 51e856bed0..a9ddd4fc60 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -37,20 +37,25 @@ def src_to_test_name(src):
def full_path(relative_paths):
return [native.package_name() + "/" + relative for relative in relative_paths]
+def _add_tfcore_prefix(src):
+ if src.startswith("//"):
+ return src
+ return "//tensorflow/core:" + src
+
# List of proto files for android builds
def tf_android_core_proto_sources(core_proto_sources_relative):
return [
- "//tensorflow/core:" + p for p in core_proto_sources_relative
+ _add_tfcore_prefix(p) for p in core_proto_sources_relative
]
# Returns the list of pb.h and proto.h headers that are generated for
# tf_android_core_proto_sources().
def tf_android_core_proto_headers(core_proto_sources_relative):
return ([
- "//tensorflow/core/" + p.replace(".proto", ".pb.h")
+ _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".pb.h")
for p in core_proto_sources_relative
] + [
- "//tensorflow/core/" + p.replace(".proto", ".proto.h")
+ _add_tfcore_prefix(p).replace(":", "/").replace(".proto", ".proto.h")
for p in core_proto_sources_relative
])
@@ -1672,22 +1677,36 @@ def cuda_py_tests(name,
#
# Return a struct with fields (hdrs, srcs) containing the names of the
# generated files.
-def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs):
+def tf_generate_proto_text_sources(name, srcs_relative_dir, srcs, protodeps=[], deps=[], visibility=None):
out_hdrs = (
[p.replace(".proto", ".pb_text.h")
for p in srcs] + [p.replace(".proto", ".pb_text-impl.h") for p in srcs])
out_srcs = [p.replace(".proto", ".pb_text.cc") for p in srcs]
native.genrule(
- name=name,
- srcs=srcs + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")],
+ name=name + "_srcs",
+ srcs=srcs + protodeps + [clean_dep("//tensorflow/tools/proto_text:placeholder.txt")],
outs=out_hdrs + out_srcs,
+ visibility=visibility,
cmd=
"$(location //tensorflow/tools/proto_text:gen_proto_text_functions) "
+ "$(@D) " + srcs_relative_dir + " $(SRCS)",
tools=[
clean_dep("//tensorflow/tools/proto_text:gen_proto_text_functions")
],)
- return struct(hdrs=out_hdrs, srcs=out_srcs)
+
+ native.filegroup(
+ name=name + "_hdrs",
+ srcs=out_hdrs,
+ visibility=visibility,
+ )
+
+ native.cc_library(
+ name=name,
+ srcs=out_srcs,
+ hdrs=out_hdrs,
+ visibility=visibility,
+ deps = deps,
+ )
def tf_genrule_cmd_append_to_srcs(to_append):
return ("cat $(SRCS) > $(@) && " + "echo >> $(@) && " + "echo " + to_append +
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
index cdf2da712f..cee76bdc1d 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-model.pbtxt
@@ -239,7 +239,7 @@ tf_class {
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
index 5c2c29e60f..02718cb5f9 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.-sequential.pbtxt
@@ -256,7 +256,7 @@ tf_class {
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
index b3f3f16922..dd78384005 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-model.pbtxt
@@ -239,7 +239,7 @@ tf_class {
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
index 4ac6811bac..9fcb03f47e 100644
--- a/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.keras.models.-sequential.pbtxt
@@ -256,7 +256,7 @@ tf_class {
}
member_method {
name: "save_weights"
- argspec: "args=[\'self\', \'filepath\', \'overwrite\'], varargs=None, keywords=None, defaults=[\'True\'], "
+ argspec: "args=[\'self\', \'filepath\', \'overwrite\', \'save_format\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], "
}
member_method {
name: "set_weights"
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt
new file mode 100644
index 0000000000..3b33f3da97
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.linalg.LinearOperatorCirculant.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt
new file mode 100644
index 0000000000..de917706d5
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant.pbtxt
@@ -0,0 +1,155 @@
+path: "tensorflow.linalg.LinearOperatorCirculant"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant.LinearOperatorCirculant\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant._BaseLinearOperatorCirculant\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_depth"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "spectrum"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "assert_hermitian_spectrum"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "block_shape_tensor"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "convolution_kernel"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt
new file mode 100644
index 0000000000..591bc9631a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.linalg.LinearOperatorCirculant2D.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
new file mode 100644
index 0000000000..c4e6a21c3a
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant2-d.pbtxt
@@ -0,0 +1,155 @@
+path: "tensorflow.linalg.LinearOperatorCirculant2D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant.LinearOperatorCirculant2D\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant._BaseLinearOperatorCirculant\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_depth"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "spectrum"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant2D\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "assert_hermitian_spectrum"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "block_shape_tensor"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "convolution_kernel"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt
new file mode 100644
index 0000000000..d643139a53
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.__metaclass__.pbtxt
@@ -0,0 +1,14 @@
+path: "tensorflow.linalg.LinearOperatorCirculant3D.__metaclass__"
+tf_class {
+ is_instance: "<class \'abc.ABCMeta\'>"
+ member_method {
+ name: "__init__"
+ }
+ member_method {
+ name: "mro"
+ }
+ member_method {
+ name: "register"
+ argspec: "args=[\'cls\', \'subclass\'], varargs=None, keywords=None, defaults=None"
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
new file mode 100644
index 0000000000..2e085a8e28
--- /dev/null
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.-linear-operator-circulant3-d.pbtxt
@@ -0,0 +1,155 @@
+path: "tensorflow.linalg.LinearOperatorCirculant3D"
+tf_class {
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant.LinearOperatorCirculant3D\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator_circulant._BaseLinearOperatorCirculant\'>"
+ is_instance: "<class \'tensorflow.python.ops.linalg.linear_operator.LinearOperator\'>"
+ is_instance: "<type \'object\'>"
+ member {
+ name: "batch_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_depth"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "block_shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "domain_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "dtype"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "graph_parents"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_non_singular"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_positive_definite"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_self_adjoint"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "is_square"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "name"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "range_dimension"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "shape"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "spectrum"
+ mtype: "<type \'property\'>"
+ }
+ member {
+ name: "tensor_rank"
+ mtype: "<type \'property\'>"
+ }
+ member_method {
+ name: "__init__"
+ argspec: "args=[\'self\', \'spectrum\', \'input_output_dtype\', \'is_non_singular\', \'is_self_adjoint\', \'is_positive_definite\', \'is_square\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'complex64\'>\", \'None\', \'None\', \'None\', \'True\', \'LinearOperatorCirculant3D\'], "
+ }
+ member_method {
+ name: "add_to_tensor"
+ argspec: "args=[\'self\', \'x\', \'name\'], varargs=None, keywords=None, defaults=[\'add_to_tensor\'], "
+ }
+ member_method {
+ name: "assert_hermitian_spectrum"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_hermitian_spectrum\'], "
+ }
+ member_method {
+ name: "assert_non_singular"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_non_singular\'], "
+ }
+ member_method {
+ name: "assert_positive_definite"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_positive_definite\'], "
+ }
+ member_method {
+ name: "assert_self_adjoint"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'assert_self_adjoint\'], "
+ }
+ member_method {
+ name: "batch_shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'batch_shape_tensor\'], "
+ }
+ member_method {
+ name: "block_shape_tensor"
+ argspec: "args=[\'self\'], varargs=None, keywords=None, defaults=None"
+ }
+ member_method {
+ name: "convolution_kernel"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'convolution_kernel\'], "
+ }
+ member_method {
+ name: "determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'det\'], "
+ }
+ member_method {
+ name: "diag_part"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'diag_part\'], "
+ }
+ member_method {
+ name: "domain_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'domain_dimension_tensor\'], "
+ }
+ member_method {
+ name: "log_abs_determinant"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'log_abs_det\'], "
+ }
+ member_method {
+ name: "matmul"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'matmul\'], "
+ }
+ member_method {
+ name: "matvec"
+ argspec: "args=[\'self\', \'x\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'matvec\'], "
+ }
+ member_method {
+ name: "range_dimension_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'range_dimension_tensor\'], "
+ }
+ member_method {
+ name: "shape_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'shape_tensor\'], "
+ }
+ member_method {
+ name: "solve"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'adjoint_arg\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'solve\'], "
+ }
+ member_method {
+ name: "solvevec"
+ argspec: "args=[\'self\', \'rhs\', \'adjoint\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'solve\'], "
+ }
+ member_method {
+ name: "tensor_rank_tensor"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'tensor_rank_tensor\'], "
+ }
+ member_method {
+ name: "to_dense"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'to_dense\'], "
+ }
+ member_method {
+ name: "trace"
+ argspec: "args=[\'self\', \'name\'], varargs=None, keywords=None, defaults=[\'trace\'], "
+ }
+}
diff --git a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
index 1d9c0c0f6d..7a5c533872 100644
--- a/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.linalg.pbtxt
@@ -5,6 +5,18 @@ tf_module {
mtype: "<class \'abc.ABCMeta\'>"
}
member {
+ name: "LinearOperatorCirculant"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "LinearOperatorCirculant2D"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
+ name: "LinearOperatorCirculant3D"
+ mtype: "<class \'abc.ABCMeta\'>"
+ }
+ member {
name: "LinearOperatorComposition"
mtype: "<class \'abc.ABCMeta\'>"
}
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 9627475d84..8e8b2191e5 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -101,6 +101,7 @@ do_pylint() {
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
+"^tensorflow/python/keras/_impl/keras/engine/base_layer.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
diff --git a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
index d654b433e7..582188fc00 100644
--- a/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
+++ b/tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh
@@ -140,6 +140,13 @@ function run_configure_for_gpu_build {
echo "" | ./configure
}
+function set_gcs_remote_cache_options {
+ echo "build --experimental_remote_spawn_cache" >> "${TMP_BAZELRC}"
+ echo "build --experimental_remote_platform_override='properties:{name:\"build\" value:\"windows-x64\"}'" >> "${TMP_BAZELRC}"
+ echo "build --remote_http_cache=https://storage.googleapis.com/$GCS_BUCKET_NAME" >> "${TMP_BAZELRC}"
+ echo "build --google_credentials=$GOOGLE_CLOUD_CREDENTIAL" >> "${TMP_BAZELRC}"
+}
+
function create_python_test_dir() {
rm -rf "$1"
mkdir -p "$1"
diff --git a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
index 5e9ae497e1..a2300811bb 100644
--- a/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
+++ b/tensorflow/tools/ci_build/windows/cpu/pip/build_tf_windows.sh
@@ -42,20 +42,36 @@ source "tensorflow/tools/ci_build/windows/bazel/common_env.sh" \
source "tensorflow/tools/ci_build/windows/bazel/bazel_test_lib.sh" \
|| { echo "Failed to source bazel_test_lib.sh" >&2; exit 1; }
+# Recreate an empty bazelrc file under source root
+export TMP_BAZELRC=.tmp.bazelrc
+rm -f "${TMP_BAZELRC}"
+touch "${TMP_BAZELRC}"
+
+function cleanup {
+ # Remove all options in .tmp.bazelrc
+ echo "" > "${TMP_BAZELRC}"
+}
+trap cleanup EXIT
+
skip_test=0
for ARG in "$@"; do
if [[ "$ARG" == --skip_test ]]; then
skip_test=1
+ elif [[ "$ARG" == --enable_gcs_remote_cache ]]; then
+ set_gcs_remote_cache_options
fi
done
-run_configure_for_cpu_build
-
# --define=override_eigen_strong_inline=true speeds up the compiling of conv_grad_ops_3d.cc and conv_ops_3d.cc
# by 20 minutes. See https://github.com/tensorflow/tensorflow/issues/10521
-BUILD_OPTS="--define=override_eigen_strong_inline=true"
-bazel build -c opt $BUILD_OPTS tensorflow/tools/pip_package:build_pip_package || exit $?
+echo "build --define=override_eigen_strong_inline=true" >> "${TMP_BAZELRC}"
+
+echo "import %workspace%/${TMP_BAZELRC}" >> .bazelrc
+
+run_configure_for_cpu_build
+
+bazel build --announce_rc -c opt tensorflow/tools/pip_package:build_pip_package || exit $?
if [[ "$skip_test" == 1 ]]; then
exit 0
@@ -71,12 +87,16 @@ create_python_test_dir "${PY_TEST_DIR}"
PIP_NAME=$(ls ${PY_TEST_DIR}/tensorflow-*.whl)
reinstall_tensorflow_pip ${PIP_NAME}
+# NUMBER_OF_PROCESSORS is predefined on Windows
+N_JOBS="${NUMBER_OF_PROCESSORS}"
+
# Define no_tensorflow_py_deps=true so that every py_test has no deps anymore,
# which will result in testing the system-installed tensorflow
-bazel test -c opt $BUILD_OPTS -k --test_output=errors \
+bazel test -c opt -k --test_output=errors \
--define=no_tensorflow_py_deps=true --test_lang_filters=py \
--test_tag_filters=-no_pip,-no_windows,-no_oss \
--build_tag_filters=-no_pip,-no_windows,-no_oss --build_tests_only \
+ --jobs="${N_JOBS}" --test_timeout="300,450,1200,3600" \
--flaky_test_attempts=3 \
//${PY_TEST_DIR}/tensorflow/python/... \
//${PY_TEST_DIR}/tensorflow/contrib/...
diff --git a/tensorflow/tools/git/gen_git_source.py b/tensorflow/tools/git/gen_git_source.py
index a8d878cf16..73dee98bae 100755
--- a/tensorflow/tools/git/gen_git_source.py
+++ b/tensorflow/tools/git/gen_git_source.py
@@ -164,17 +164,14 @@ def get_git_version(git_base_path, git_tag_override):
"git", str("--git-dir=%s/.git" % git_base_path),
str("--work-tree=" + git_base_path), "describe", "--long", "--tags"
]).strip())
- if git_tag_override and val:
+ if git_tag_override:
split_val = val.split("-")
- if len(split_val) < 3:
+ if len(split_val) != 3:
raise Exception(
("Expected git version in format 'TAG-COMMITS AFTER TAG-HASH' "
"but got '%s'") % val)
- # There might be "-" in the tag name. But we can be sure that the final
- # two "-" are those inserted by the git describe command.
- abbrev_commit = split_val[-1]
- val = bytes(
- "-".join([git_tag_override, "0", abbrev_commit]))
+ split_val[0] = git_tag_override
+ val = bytes("-".join(split_val))
return val if val else unknown_label
except (subprocess.CalledProcessError, OSError):
return unknown_label
@@ -336,4 +333,4 @@ elif args.raw_generate is not None:
raw_generate(args.raw_generate, source_path, args.git_tag_override)
else:
raise RuntimeError("--configure or --generate or --raw_generate "
- "must be used") \ No newline at end of file
+ "must be used")
diff --git a/tensorflow/tools/pip_package/BUILD b/tensorflow/tools/pip_package/BUILD
index 7b508f87ab..677ea65edd 100644
--- a/tensorflow/tools/pip_package/BUILD
+++ b/tensorflow/tools/pip_package/BUILD
@@ -63,6 +63,7 @@ COMMON_PIP_DEPS = [
"//tensorflow/contrib/autograph/pyct/static_analysis:static_analysis",
"//tensorflow/contrib/boosted_trees:boosted_trees_pip",
"//tensorflow/contrib/cluster_resolver:cluster_resolver_pip",
+ "//tensorflow/contrib/constrained_optimization:constrained_optimization_pip",
"//tensorflow/contrib/data/python/kernel_tests:dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:contrib_op_loader",
"//tensorflow/contrib/eager/python/examples:examples_pip",
diff --git a/tensorflow/tools/proto_text/BUILD b/tensorflow/tools/proto_text/BUILD
index ef7bfdd3c9..31e8fb9120 100644
--- a/tensorflow/tools/proto_text/BUILD
+++ b/tensorflow/tools/proto_text/BUILD
@@ -75,9 +75,14 @@ tf_proto_library_cc(
)
tf_generate_proto_text_sources(
- name = "test_proto_text_srcs",
+ name = "test_proto_text",
srcs = ["test.proto"],
srcs_relative_dir = "tensorflow/tools/proto_text/",
+ deps = [
+ ":test_proto_cc",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
)
tf_cc_test(
diff --git a/tensorflow/tools/proto_text/gen_proto_text_functions.cc b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
index f0bb59acf8..234afe879b 100644
--- a/tensorflow/tools/proto_text/gen_proto_text_functions.cc
+++ b/tensorflow/tools/proto_text/gen_proto_text_functions.cc
@@ -130,7 +130,11 @@ int MainImpl(int argc, char** argv) {
const string path = output_root + "/" + proto_path_no_suffix + suffix;
FILE* f = fopen(path.c_str(), "w");
- if (f == nullptr) return -1;
+ if (f == nullptr) {
+ // We don't expect this output to be generated. It was specified in the
+ // list of sources solely to satisfy a proto import dependency.
+ continue;
+ }
if (fwrite(data.c_str(), 1, data.size(), f) != data.size()) {
fclose(f);
return -1;