author    Sourabh Bajaj <sourabhbajaj@google.com>    2017-11-30 21:15:53 -0800
committer Sourabh Bajaj <sourabhbajaj@google.com>    2017-11-30 21:15:53 -0800
commit    2c4e8fcf05d3e22b0758a6f63a423b9319f9c19d (patch)
tree      5d58a4760b28af0f7ca22b2620a9fb6fc940d335
parent    c57796f366a0545a04424caeff1b27bbd629f8f0 (diff)
parent    1ec61fafe13e5edce6e45d5a67e960efb9df618a (diff)
Fix merge conflicts
-rw-r--r--  tensorflow/c/c_api.cc | 47
-rw-r--r--  tensorflow/c/c_api_internal.h | 21
-rw-r--r--  tensorflow/c/eager/tape.h | 2
-rw-r--r--  tensorflow/c/python_api.cc | 10
-rw-r--r--  tensorflow/compiler/aot/codegen.cc | 8
-rw-r--r--  tensorflow/compiler/aot/codegen_test_h.golden | 8
-rw-r--r--  tensorflow/compiler/aot/tests/tfcompile_test.cc | 4
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass.cc | 7
-rw-r--r--  tensorflow/compiler/jit/mark_for_compilation_pass_test.cc | 27
-rw-r--r--  tensorflow/compiler/tests/BUILD | 14
-rw-r--r--  tensorflow/compiler/tests/categorical_op_test.py | 38
-rw-r--r--  tensorflow/compiler/tests/scan_ops_test.py | 229
-rw-r--r--  tensorflow/compiler/tf2xla/const_analysis.cc | 2
-rw-r--r--  tensorflow/compiler/tf2xla/dump_graph.cc | 15
-rw-r--r--  tensorflow/compiler/tf2xla/functionalize_control_flow.cc | 4
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/BUILD | 1
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc | 74
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/conv_ops.cc | 252
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/matmul_op.cc | 9
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/scan_ops.cc | 141
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/shape_op.cc | 76
-rw-r--r--  tensorflow/compiler/tf2xla/kernels/variable_ops.cc | 27
-rw-r--r--  tensorflow/compiler/tf2xla/lib/util.cc | 3
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc | 15
-rw-r--r--  tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h | 53
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.cc | 14
-rw-r--r--  tensorflow/compiler/tf2xla/xla_context.h | 8
-rw-r--r--  tensorflow/compiler/tf2xla/xla_helpers.cc | 7
-rw-r--r--  tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc | 25
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.cc | 5
-rw-r--r--  tensorflow/compiler/tf2xla/xla_op_kernel.h | 5
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.cc | 36
-rw-r--r--  tensorflow/compiler/xla/client/computation_builder.h | 5
-rw-r--r--  tensorflow/compiler/xla/literal_util.cc | 6
-rw-r--r--  tensorflow/compiler/xla/literal_util.h | 6
-rw-r--r--  tensorflow/compiler/xla/reference_util.cc | 18
-rw-r--r--  tensorflow/compiler/xla/reference_util.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/algebraic_simplifier_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/batchnorm_rewriter.cc | 328
-rw-r--r--  tensorflow/compiler/xla/service/buffer_assignment_test.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/buffer_liveness_test.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/cpu/BUILD | 34
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_compiler.cc | 17
-rw-r--r--  tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc | 36
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc | 115
-rw-r--r--  tensorflow/compiler/xla/service/cpu/dot_op_emitter.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.cc | 277
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_emitter.h | 31
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.cc | 195
-rw-r--r--  tensorflow/compiler/xla/service/cpu/ir_function.h | 109
-rw-r--r--  tensorflow/compiler/xla/service/cpu/layout_assignment.cc | 6
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc | 76
-rw-r--r--  tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h | 75
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_folding.cc | 55
-rw-r--r--  tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc | 42
-rw-r--r--  tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/gpu/pad_insertion.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/graphviz_example.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/heap_simulator_test.cc | 24
-rw-r--r--  tensorflow/compiler/xla/service/hlo.proto | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_computation.h | 16
-rw-r--r--  tensorflow/compiler/xla/service/hlo_cost_analysis.cc | 5
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_dce_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/service/hlo_evaluator_test.cc | 21
-rw-r--r--  tensorflow/compiler/xla/service/hlo_execution_profile.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.cc | 56
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction.h | 18
-rw-r--r--  tensorflow/compiler/xla/service/hlo_instruction_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/service/hlo_module.cc | 7
-rw-r--r--  tensorflow/compiler/xla/service/hlo_profile_printer.h | 8
-rw-r--r--  tensorflow/compiler/xla/service/hlo_rematerialization.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/hlo_sharding.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/hlo_verifier.cc | 16
-rw-r--r--  tensorflow/compiler/xla/service/liveness_util_test.cc | 10
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc | 44
-rw-r--r--  tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h | 32
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.cc | 32
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer.h | 13
-rw-r--r--  tensorflow/compiler/xla/service/name_uniquer_test.cc | 26
-rw-r--r--  tensorflow/compiler/xla/service/service.cc | 3
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.cc | 156
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference.h | 6
-rw-r--r--  tensorflow/compiler/xla/service/shape_inference_test.cc | 177
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding.cc | 9
-rw-r--r--  tensorflow/compiler/xla/service/transpose_folding_test.cc | 91
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.cc | 57
-rw-r--r--  tensorflow/compiler/xla/service/user_computation.h | 4
-rw-r--r--  tensorflow/compiler/xla/service/user_computation_test.cc | 45
-rw-r--r--  tensorflow/compiler/xla/service/while_loop_simplifier.cc | 4
-rw-r--r--  tensorflow/compiler/xla/shape_layout.cc | 8
-rw-r--r--  tensorflow/compiler/xla/shape_layout.h | 13
-rw-r--r--  tensorflow/compiler/xla/shape_util.cc | 4
-rw-r--r--  tensorflow/compiler/xla/shape_util.h | 2
-rw-r--r--  tensorflow/compiler/xla/statusor_test.cc | 14
-rw-r--r--  tensorflow/compiler/xla/tests/bfloat16_test.cc | 9
-rw-r--r--  tensorflow/compiler/xla/tests/client_library_test_base.h | 2
-rw-r--r--  tensorflow/compiler/xla/tests/dot_operation_test.cc | 20
-rw-r--r--  tensorflow/compiler/xla/tests/multioutput_fusion_test.cc | 12
-rw-r--r--  tensorflow/compiler/xla/tests/reduce_window_test.cc | 112
-rw-r--r--  tensorflow/compiler/xla/tests/slice_test.cc | 90
-rw-r--r--  tensorflow/compiler/xla/tests/while_test.cc | 6
-rw-r--r--  tensorflow/compiler/xla/util.cc | 6
-rw-r--r--  tensorflow/compiler/xla/xla_data.proto | 23
-rw-r--r--  tensorflow/contrib/android/README.md | 6
-rw-r--r--  tensorflow/contrib/android/cmake/CMakeLists.txt | 2
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h | 2
-rw-r--r--  tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc | 1
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler.h | 4
-rw-r--r--  tensorflow/contrib/batching/basic_batch_scheduler_test.cc | 1
-rw-r--r--  tensorflow/contrib/batching/batch_scheduler.h | 4
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler.h | 5
-rw-r--r--  tensorflow/contrib/batching/shared_batch_scheduler_test.cc | 1
-rw-r--r--  tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc | 2
-rw-r--r--  tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py | 3
-rw-r--r--  tensorflow/contrib/cmake/external/nsync.cmake | 2
-rw-r--r--  tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt | 9
-rw-r--r--  tensorflow/contrib/cmake/tf_core_cpu.cmake | 2
-rw-r--r--  tensorflow/contrib/cmake/tf_core_framework.cmake | 2
-rwxr-xr-x  tensorflow/contrib/cmake/tf_python.cmake | 2
-rw-r--r--  tensorflow/contrib/cmake/tf_tests.cmake | 2
-rw-r--r--  tensorflow/contrib/copy_graph/python/util/copy_elements.py | 1
-rw-r--r--  tensorflow/contrib/data/BUILD | 1
-rw-r--r--  tensorflow/contrib/data/__init__.py | 1
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/BUILD | 6
-rw-r--r--  tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py | 80
-rw-r--r--  tensorflow/contrib/data/python/ops/BUILD | 32
-rw-r--r--  tensorflow/contrib/data/python/ops/random_ops.py | 67
-rw-r--r--  tensorflow/contrib/data/python/ops/shuffle_ops.py | 69
-rw-r--r--  tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py | 342
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py | 28
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py | 4
-rw-r--r--  tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py | 279
-rw-r--r--  tensorflow/contrib/distributions/python/ops/cauchy.py | 22
-rw-r--r--  tensorflow/contrib/distributions/python/ops/deterministic.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/gumbel.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/ops/independent.py | 10
-rw-r--r--  tensorflow/contrib/distributions/python/ops/inverse_gamma.py | 5
-rw-r--r--  tensorflow/contrib/distributions/python/ops/logistic.py | 13
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mixture.py | 10
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mixture_same_family.py | 16
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_diag.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py | 6
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py | 11
-rw-r--r--  tensorflow/contrib/distributions/python/ops/mvn_tril.py | 13
-rw-r--r--  tensorflow/contrib/distributions/python/ops/poisson_lognormal.py | 5
-rw-r--r--  tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py | 11
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py | 7
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py | 11
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py | 8
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py | 11
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py | 2
-rw-r--r--  tensorflow/contrib/distributions/python/ops/vector_student_t.py | 6
-rw-r--r--  tensorflow/contrib/eager/python/BUILD | 1
-rw-r--r--  tensorflow/contrib/eager/python/network.py | 24
-rw-r--r--  tensorflow/contrib/eager/python/network_test.py | 29
-rw-r--r--  tensorflow/contrib/estimator/BUILD | 3
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py | 98
-rw-r--r--  tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py | 130
-rw-r--r--  tensorflow/contrib/ffmpeg/BUILD | 5
-rw-r--r--  tensorflow/contrib/ffmpeg/__init__.py | 1
-rw-r--r--  tensorflow/contrib/ffmpeg/decode_video_op_test.py | 13
-rw-r--r--  tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc | 5
-rw-r--r--  tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc | 1
-rw-r--r--  tensorflow/contrib/ffmpeg/ffmpeg_ops.py | 1
-rw-r--r--  tensorflow/contrib/framework/python/framework/graph_util.py | 2
-rw-r--r--  tensorflow/contrib/framework/python/framework/graph_util_test.py | 5
-rw-r--r--  tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc | 2
-rw-r--r--  tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h | 9
-rw-r--r--  tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc | 6
-rw-r--r--  tensorflow/contrib/graph_editor/transform.py | 3
-rw-r--r--  tensorflow/contrib/kfac/python/ops/fisher_factors.py | 6
-rw-r--r--  tensorflow/contrib/layers/python/layers/feature_column.py | 4
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers.py | 5
-rw-r--r--  tensorflow/contrib/layers/python/layers/layers_test.py | 17
-rw-r--r--  tensorflow/contrib/learn/BUILD | 1
-rwxr-xr-x  tensorflow/contrib/lite/build_ios_universal_lib.sh | 15
-rwxr-xr-x  tensorflow/contrib/lite/download_dependencies.sh | 2
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h | 4
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm | 18
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h | 4
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm | 40
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h | 6
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm | 35
-rw-r--r--  tensorflow/contrib/lite/examples/ios/simple/main.mm | 2
-rw-r--r--  tensorflow/contrib/lite/ios_makefile.inc | 78
-rw-r--r--  tensorflow/contrib/lite/java/demo/README.md | 8
-rw-r--r--  tensorflow/contrib/lite/models/speech_asr_am_model_test.cc (renamed from tensorflow/contrib/lite/models/speech_terse_am_model_test.cc) | 10
-rw-r--r--  tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc (renamed from tensorflow/contrib/lite/models/speech_terse_lm_model_test.cc) | 8
-rw-r--r--  tensorflow/contrib/lite/models/testdata/g3doc/README.md | 27
-rw-r--r--  tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h | 6
-rw-r--r--  tensorflow/contrib/lite/python/lite.py | 2
-rw-r--r--  tensorflow/contrib/lite/tools/benchmark_model.cc | 7
-rw-r--r--  tensorflow/contrib/lite/tools/mutable_op_resolver.h | 17
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager.cc | 32
-rw-r--r--  tensorflow/contrib/nccl/kernels/nccl_manager_test.cc | 2
-rw-r--r--  tensorflow/contrib/summary/summary.py | 1
-rw-r--r--  tensorflow/contrib/summary/summary_ops.py | 21
-rw-r--r--  tensorflow/contrib/summary/summary_ops_test.py | 27
-rw-r--r--  tensorflow/contrib/summary/summary_test_util.py | 2
-rw-r--r--  tensorflow/contrib/tpu/BUILD | 15
-rw-r--r--  tensorflow/contrib/tpu/ops/cross_replica_ops.cc | 2
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/test_util.py | 296
-rw-r--r--  tensorflow/contrib/tpu/python/tpu/tpu_estimator.py | 1
-rw-r--r--  tensorflow/core/BUILD | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt | 12
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt | 43
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt | 10
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt | 18
-rw-r--r--  tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt | 69
-rw-r--r--  tensorflow/core/framework/bfloat16_test.cc | 12
-rw-r--r--  tensorflow/core/framework/common_shape_fns.cc | 30
-rw-r--r--  tensorflow/core/framework/common_shape_fns_test.cc | 106
-rw-r--r--  tensorflow/core/framework/numeric_types.h | 49
-rw-r--r--  tensorflow/core/framework/op_def_builder_test.cc | 15
-rw-r--r--  tensorflow/core/framework/types.cc | 24
-rw-r--r--  tensorflow/core/grappler/costs/BUILD | 1
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.cc | 157
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties.h | 37
-rw-r--r--  tensorflow/core/grappler/costs/graph_properties_test.cc | 91
-rw-r--r--  tensorflow/core/grappler/costs/virtual_scheduler.cc | 2
-rw-r--r--  tensorflow/core/grappler/grappler_item_builder.cc | 4
-rw-r--r--  tensorflow/core/grappler/optimizers/BUILD | 6
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc | 107
-rw-r--r--  tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc | 106
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.cc | 367
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding.h | 21
-rw-r--r--  tensorflow/core/grappler/optimizers/constant_folding_test.cc | 159
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer.cc | 321
-rw-r--r--  tensorflow/core/grappler/optimizers/layout_optimizer_test.cc | 203
-rw-r--r--  tensorflow/core/grappler/optimizers/memory_optimizer.cc | 2
-rw-r--r--  tensorflow/core/grappler/optimizers/static_schedule.cc | 4
-rw-r--r--  tensorflow/core/kernels/BUILD | 21
-rw-r--r--  tensorflow/core/kernels/conv_grad_filter_ops.cc | 93
-rw-r--r--  tensorflow/core/kernels/conv_grad_input_ops.cc | 97
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops.h | 16
-rw-r--r--  tensorflow/core/kernels/conv_grad_ops_3d.cc | 4
-rw-r--r--  tensorflow/core/kernels/conv_ops.cc | 113
-rw-r--r--  tensorflow/core/kernels/conv_ops.h | 10
-rw-r--r--  tensorflow/core/kernels/conv_ops_3d.cc | 3
-rw-r--r--  tensorflow/core/kernels/conv_ops_gpu.h | 12
-rw-r--r--  tensorflow/core/kernels/conv_ops_test.cc | 4
-rw-r--r--  tensorflow/core/kernels/cwise_op_asinh.cc | 2
-rw-r--r--  tensorflow/core/kernels/dataset.cc | 140
-rw-r--r--  tensorflow/core/kernels/dataset.h | 155
-rw-r--r--  tensorflow/core/kernels/dataset_utils.cc | 2
-rw-r--r--  tensorflow/core/kernels/depthwise_conv_op.cc | 5
-rw-r--r--  tensorflow/core/kernels/filter_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/flat_map_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/group_by_window_dataset_op.cc | 6
-rw-r--r--  tensorflow/core/kernels/interleave_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/map_and_batch_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/map_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/mkl_batch_matmul_op.cc | 13
-rw-r--r--  tensorflow/core/kernels/multinomial_op.cc | 51
-rw-r--r--  tensorflow/core/kernels/multinomial_op.h | 2
-rw-r--r--  tensorflow/core/kernels/multinomial_op_gpu.cu.cc | 30
-rw-r--r--  tensorflow/core/kernels/nn_ops_test.cc | 2
-rw-r--r--  tensorflow/core/kernels/padded_batch_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/parallel_map_dataset_op.cc | 4
-rw-r--r--  tensorflow/core/kernels/quantized_conv_ops.cc | 13
-rw-r--r--  tensorflow/core/kernels/random_dataset_op.cc | 154
-rw-r--r--  tensorflow/core/kernels/reduction_ops_min.cc | 1
-rw-r--r--  tensorflow/core/kernels/reduction_ops_test.cc | 5
-rw-r--r--  tensorflow/core/kernels/scan_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/scatter_nd_op.cc | 63
-rw-r--r--  tensorflow/core/kernels/serialize_sparse_op.cc | 177
-rw-r--r--  tensorflow/core/kernels/softmax_op_functor.h | 33
-rw-r--r--  tensorflow/core/kernels/strided_slice_op.cc | 1
-rw-r--r--  tensorflow/core/kernels/strided_slice_op_gpu.cu.cc | 1
-rw-r--r--  tensorflow/core/kernels/tensor_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/tensor_slice_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/kernels/zip_dataset_op.cc | 2
-rw-r--r--  tensorflow/core/lib/core/arena.cc | 18
-rw-r--r--  tensorflow/core/lib/math/math_util.h | 17
-rw-r--r--  tensorflow/core/lib/math/math_util_test.cc | 29
-rw-r--r--  tensorflow/core/lib/monitoring/collected_metrics.h | 1
-rw-r--r--  tensorflow/core/lib/monitoring/collection_registry.h | 6
-rw-r--r--  tensorflow/core/lib/monitoring/gauge.h | 33
-rw-r--r--  tensorflow/core/lib/monitoring/gauge_test.cc | 22
-rw-r--r--  tensorflow/core/lib/monitoring/metric_def.h | 13
-rw-r--r--  tensorflow/core/ops/array_ops.cc | 32
-rw-r--r--  tensorflow/core/ops/dataset_ops.cc | 18
-rw-r--r--  tensorflow/core/ops/math_ops.cc | 125
-rw-r--r--  tensorflow/core/ops/nn_ops.cc | 146
-rw-r--r--  tensorflow/core/ops/random_ops.cc | 11
-rw-r--r--  tensorflow/core/ops/resource_variable_ops.cc | 5
-rw-r--r--  tensorflow/core/ops/sparse_ops.cc | 42
-rw-r--r--  tensorflow/core/ops/state_ops.cc | 56
-rw-r--r--  tensorflow/core/platform/cloud/curl_http_request_test.cc | 2
-rw-r--r--  tensorflow/core/platform/cloud/file_block_cache.cc | 4
-rw-r--r--  tensorflow/core/platform/cloud/gcs_file_system.cc | 10
-rw-r--r--  tensorflow/core/profiler/g3doc/options.md | 11
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_node.cc | 28
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_node.h | 10
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_show_test.cc | 37
-rw-r--r--  tensorflow/core/profiler/internal/tfprof_stats_test.cc | 105
-rw-r--r--  tensorflow/core/profiler/tfprof_log.proto | 5
-rw-r--r--  tensorflow/docs_src/api_guides/python/reading_data.md | 32
-rw-r--r--  tensorflow/docs_src/get_started/custom_estimators.md | 576
-rw-r--r--  tensorflow/docs_src/get_started/feature_columns.md | 570
-rw-r--r--  tensorflow/docs_src/performance/xla/operation_semantics.md | 81
-rw-r--r--  tensorflow/docs_src/programmers_guide/datasets.md | 14
-rw-r--r--  tensorflow/examples/android/README.md | 8
-rw-r--r--  tensorflow/examples/how_tos/reading_data/convert_to_records.py | 15
-rw-r--r--  tensorflow/examples/speech_commands/train.py | 3
-rw-r--r--  tensorflow/go/graph.go | 64
-rw-r--r--  tensorflow/go/op/op_test.go | 73
-rw-r--r--  tensorflow/go/tensor.go | 2
-rw-r--r--  tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java | 3
-rw-r--r--  tensorflow/java/src/main/native/operation_builder_jni.cc | 6
-rw-r--r--  tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java | 31
-rw-r--r--  tensorflow/python/BUILD | 13
-rw-r--r--  tensorflow/python/client/session_test.py | 133
-rw-r--r--  tensorflow/python/client/tf_session.i | 43
-rw-r--r--  tensorflow/python/client/tf_session_helper.cc | 19
-rw-r--r--  tensorflow/python/client/tf_session_helper.h | 14
-rw-r--r--  tensorflow/python/data/ops/BUILD | 1
-rw-r--r--  tensorflow/python/data/ops/dataset_ops.py | 35
-rw-r--r--  tensorflow/python/data/util/nest.py | 6
-rw-r--r--  tensorflow/python/data/util/nest_test.py | 5
-rw-r--r--  tensorflow/python/eager/backprop.py | 2
-rw-r--r--  tensorflow/python/eager/context.py | 15
-rw-r--r--  tensorflow/python/eager/function.py | 197
-rw-r--r--  tensorflow/python/eager/graph_callable.py | 19
-rw-r--r--  tensorflow/python/eager/graph_callable_test.py | 1
-rw-r--r--  tensorflow/python/eager/pywrap_tfe_src.cc | 7
-rw-r--r--  tensorflow/python/estimator/BUILD | 2
-rw-r--r--  tensorflow/python/estimator/estimator.py | 49
-rw-r--r--  tensorflow/python/estimator/estimator_test.py | 74
-rw-r--r--  tensorflow/python/framework/function.py | 5
-rw-r--r--  tensorflow/python/framework/function_test.py | 32
-rw-r--r--  tensorflow/python/framework/importer.py | 151
-rw-r--r--  tensorflow/python/framework/importer_test.py | 137
-rw-r--r--  tensorflow/python/framework/ops.py | 109
-rw-r--r--  tensorflow/python/framework/ops_test.py | 30
-rw-r--r--  tensorflow/python/framework/test_ops.cc | 18
-rw-r--r--  tensorflow/python/grappler/item.i | 2
-rw-r--r--  tensorflow/python/grappler/layout_optimizer_test.py | 2
-rw-r--r--  tensorflow/python/grappler/model_analyzer.cc | 2
-rw-r--r--  tensorflow/python/keras/_impl/keras/callbacks_test.py | 19
-rw-r--r--  tensorflow/python/keras/_impl/keras/utils/io_utils.py | 2
-rw-r--r--  tensorflow/python/kernel_tests/BUILD | 1
-rw-r--r--  tensorflow/python/kernel_tests/constant_op_eager_test.py | 33
-rw-r--r--  tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py | 54
-rw-r--r--  tensorflow/python/kernel_tests/conv_ops_test.py | 407
-rw-r--r--  tensorflow/python/kernel_tests/decode_bmp_op_test.py | 124
-rw-r--r--  tensorflow/python/kernel_tests/prefetch_dataset_op_test.py | 5
-rw-r--r--  tensorflow/python/kernel_tests/random/multinomial_op_test.py | 14
-rw-r--r--  tensorflow/python/kernel_tests/scatter_nd_ops_test.py | 15
-rw-r--r--  tensorflow/python/kernel_tests/segment_reduction_ops_test.py | 21
-rw-r--r--  tensorflow/python/kernel_tests/sparse_serialization_ops_test.py | 4
-rw-r--r--  tensorflow/python/kernel_tests/template_test.py | 37
-rw-r--r--  tensorflow/python/layers/base.py | 85
-rw-r--r--  tensorflow/python/layers/base_test.py | 5
-rw-r--r--  tensorflow/python/layers/convolutional.py | 3
-rw-r--r--  tensorflow/python/lib/core/py_func.cc | 54
-rw-r--r--  tensorflow/python/lib/core/py_seq_tensor.cc | 18
-rw-r--r--  tensorflow/python/lib/core/py_util.cc | 70
-rw-r--r--  tensorflow/python/lib/core/py_util.h | 27
-rw-r--r--  tensorflow/python/lib/core/safe_ptr.cc | 16
-rw-r--r--  tensorflow/python/lib/core/safe_ptr.h | 42
-rw-r--r--  tensorflow/python/ops/nn_grad.py | 90
-rw-r--r--  tensorflow/python/ops/nn_ops.py | 28
-rw-r--r--  tensorflow/python/ops/random_ops.py | 5
-rw-r--r--  tensorflow/python/ops/resource_variable_ops.py | 21
-rw-r--r--  tensorflow/python/ops/sparse_ops.py | 45
-rw-r--r--  tensorflow/python/ops/state_ops.py | 70
-rw-r--r--  tensorflow/python/ops/template.py | 37
-rw-r--r--  tensorflow/python/platform/flags.py | 48
-rw-r--r--  tensorflow/python/platform/flags_test.py | 41
-rw-r--r--  tensorflow/python/profiler/model_analyzer_test.py | 40
-rw-r--r--  tensorflow/python/pywrap_tfe.i | 3
-rw-r--r--  tensorflow/python/training/momentum_test.py | 41
-rw-r--r--  tensorflow/python/training/monitored_session.py | 1
-rw-r--r--  tensorflow/python/training/saver_test.py | 20
-rw-r--r--  tensorflow/python/training/supervisor.py | 6
-rw-r--r--  tensorflow/python/util/nest.py | 6
-rw-r--r--  tensorflow/python/util/nest_test.py | 8
-rw-r--r--  tensorflow/tensorflow.bzl | 22
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.nn.pbtxt | 18
-rw-r--r--  tensorflow/tools/api/golden/tensorflow.pbtxt | 2
-rwxr-xr-x  tensorflow/tools/ci_build/ci_sanity.sh | 3
-rw-r--r--  tensorflow/tools/dist_test/python/census_widendeep.py | 3
-rw-r--r--  tensorflow/tools/pip_package/setup.py | 13
-rw-r--r--  tensorflow/workspace.bzl | 67
-rw-r--r--  third_party/aws.BUILD | 16
-rw-r--r--  third_party/curl.BUILD | 46
-rw-r--r--  third_party/gif.BUILD | 2
-rw-r--r--  third_party/jemalloc.BUILD | 10
-rw-r--r--  third_party/jpeg/jpeg.BUILD | 2
-rw-r--r--  third_party/mkl/build_defs.bzl | 1
-rw-r--r--  third_party/nccl.BUILD | 8
-rw-r--r--  third_party/snappy.BUILD | 4
407 files changed, 10880 insertions(+), 3826 deletions(-)
diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc
index bb41f92306..c8b4bfffd4 100644
--- a/tensorflow/c/c_api.cc
+++ b/tensorflow/c/c_api.cc
@@ -383,12 +383,11 @@ void TF_Reset_Helper(const TF_SessionOptions* opt, const char** containers,
// be less than the total node count.
Status ValidateNoCycles(const Graph& g) {
// TODO(nolivia): check this on a subset of the graph instead of all of it.
- int total_num_nodes = g.num_node_ids();
// A node is ready when all of its inputs have been visited.
std::vector<const Node*> ready;
- std::vector<int> pending_count(total_num_nodes, 0);
+ std::vector<int> pending_count(g.num_node_ids(), 0);
- for (int i = 0; i < total_num_nodes; ++i) {
+ for (int i = 0; i < g.num_node_ids(); ++i) {
const Node* n = g.FindNodeId(i);
if (n == nullptr) continue;
pending_count[i] = n->in_edges().size();
@@ -421,7 +420,7 @@ Status ValidateNoCycles(const Graph& g) {
}
}
- if (processed < total_num_nodes) {
+ if (processed < g.num_nodes()) {
std::vector<string> nodes_in_cycle;
for (int i = 0; i < pending_count.size() && nodes_in_cycle.size() < 3;
++i) {
@@ -430,7 +429,7 @@ Status ValidateNoCycles(const Graph& g) {
}
}
return errors::InvalidArgument(
- "Graph is invalid, contains a cycle with ", total_num_nodes - processed,
+ "Graph is invalid, contains a cycle with ", g.num_nodes() - processed,
" nodes, including: ", str_util::Join(nodes_in_cycle, ", "));
}
return Status::OK();
@@ -625,6 +624,23 @@ Status MessageToBuffer(const tensorflow::protobuf::Message& in,
return Status::OK();
}
+void RecordMutation(TF_Graph* graph, const TF_Operation& op,
+ const char* mutation_type)
+ EXCLUSIVE_LOCKS_REQUIRED(graph->mu) {
+ // If any session has already run this node_id, mark this session as
+ // unrunnable.
+ for (auto it : graph->sessions) {
+ if (it.first->last_num_graph_nodes > op.node.id()) {
+ it.second = FailedPrecondition(
+ "Operation '", op.node.DebugString(), "' was changed by ",
+ mutation_type,
+ " after it was run by a session. Nodes can be mutated "
+ "only before they are executed by a session. Either don't modify "
+ "nodes after running them or create a new session.");
+ }
+ }
+}
+
// Helpers for loading a TensorFlow plugin (a .so file).
Status LoadLibrary(const char* library_filename, void** result,
const void** buf, size_t* len);
@@ -1745,7 +1761,6 @@ void TF_OperationToNodeDef(TF_Operation* oper, TF_Buffer* output_node_def,
TF_Graph::TF_Graph()
: graph(tensorflow::OpRegistry::Global()),
refiner(graph.versions().producer(), graph.op_registry()),
- num_sessions(0),
delete_requested(false),
parent(nullptr),
parent_inputs(nullptr) {}
@@ -1755,7 +1770,7 @@ TF_Graph* TF_NewGraph() { return new TF_Graph; }
void TF_DeleteGraph(TF_Graph* g) {
g->mu.lock();
g->delete_requested = true;
- const bool del = g->num_sessions == 0;
+ const bool del = g->sessions.empty();
g->mu.unlock();
if (del) delete g;
}
@@ -2325,11 +2340,12 @@ TF_Session* TF_NewSession(TF_Graph* graph, const TF_SessionOptions* opt,
Session* session;
status->status = NewSession(opt->options, &session);
if (status->status.ok()) {
+ TF_Session* new_session = new TF_Session(session, graph);
if (graph != nullptr) {
mutex_lock l(graph->mu);
- graph->num_sessions += 1;
+ graph->sessions[new_session] = Status::OK();
}
- return new TF_Session(session, graph);
+ return new_session;
} else {
DCHECK_EQ(nullptr, session);
return nullptr;
@@ -2393,7 +2409,7 @@ TF_Session* TF_LoadSessionFromSavedModel(
TF_Session* session = new TF_Session(bundle.session.release(), graph);
- graph->num_sessions += 1;
+ graph->sessions[session] = Status::OK();
session->last_num_graph_nodes = graph->graph.num_node_ids();
return session;
#endif // __ANDROID__
@@ -2408,8 +2424,8 @@ void TF_DeleteSession(TF_Session* s, TF_Status* status) {
TF_Graph* const graph = s->graph;
if (graph != nullptr) {
graph->mu.lock();
- graph->num_sessions -= 1;
- const bool del = graph->delete_requested && graph->num_sessions == 0;
+ graph->sessions.erase(s);
+ const bool del = graph->delete_requested && graph->sessions.empty();
graph->mu.unlock();
if (del) delete graph;
}
@@ -2425,6 +2441,13 @@ static bool ExtendSessionGraphHelper(TF_Session* session, TF_Status* status) {
mutex_lock session_lock(session->mu);
session->graph->mu.lock();
const Graph& graph = session->graph->graph;
+
+ status->status = session->graph->sessions[session];
+ if (!status->status.ok()) {
+ session->graph->mu.unlock();
+ return false;
+ }
+
const auto num_nodes = graph.num_node_ids();
if (session->last_num_graph_nodes < num_nodes) {
status->status = tensorflow::ValidateNoCycles(session->graph->graph);
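
The c_api.cc changes above replace the old num_sessions counter with a per-session status map: mutating a node that a session has already run poisons that session, and ExtendSessionGraphHelper surfaces the stored error on the next run. A minimal Python sketch of that bookkeeping follows; all names are hypothetical and this models only the mechanism, not the actual C API.

    class Graph(object):
      def __init__(self):
        self.sessions = {}     # session -> None if runnable, else an error string
        self.num_node_ids = 0  # grows as nodes are added, like Graph::num_node_ids()

    class Session(object):
      def __init__(self, graph):
        self.graph = graph
        self.last_num_graph_nodes = 0
        graph.sessions[self] = None          # analogous to Status::OK()

      def run(self):
        error = self.graph.sessions[self]
        if error is not None:
          raise RuntimeError(error)          # analogous to FailedPrecondition
        self.last_num_graph_nodes = self.graph.num_node_ids

    def record_mutation(graph, node_id, mutation_type):
      # Any session that has already run this node becomes unrunnable.
      for session in graph.sessions:
        if session.last_num_graph_nodes > node_id:
          graph.sessions[session] = ("node %d was changed by %s after it was run"
                                     % (node_id, mutation_type))
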
diff --git a/tensorflow/c/c_api_internal.h b/tensorflow/c/c_api_internal.h
index bb04e01bee..aac333d9e2 100644
--- a/tensorflow/c/c_api_internal.h
+++ b/tensorflow/c/c_api_internal.h
@@ -81,12 +81,20 @@ struct TF_Graph {
std::unordered_map<tensorflow::string, tensorflow::Node*> name_map
GUARDED_BY(mu);
- // TF_Graph may only / must be deleted when
- // num_sessions == 0 && delete_requested == true
-
- // num_sessions incremented by TF_NewSession, and decremented by
+ // The keys of this map are all the active sessions using this graph.
+ // Each value is the current "runnability" status of the corresponding
+ // session. Under normal conditions all statuses are Status::OK(), but
+ // if some operation is mutated after it was run by a session (this
+ // is detected in RecordMutation function), that session is no longer
+ // safe to run. Its status will contain the error that will be returned
+ // to the user, should she try running this session.
+ //
+ // Sessions are added to this map in TF_NewSession, and removed in
// TF_DeleteSession.
- int num_sessions GUARDED_BY(mu);
+ // TF_Graph may only / must be deleted when
+ // sessions.size() == 0 && delete_requested == true
+ tensorflow::gtl::FlatMap<TF_Session*, tensorflow::Status> sessions
+ GUARDED_BY(mu);
bool delete_requested GUARDED_BY(mu); // set true by TF_DeleteGraph
// Used to link graphs contained in TF_WhileParams to the parent graph that
@@ -167,6 +175,9 @@ TF_Tensor* TF_TensorFromTensor(const Tensor& src, TF_Status* status);
Status MessageToBuffer(const tensorflow::protobuf::Message& in, TF_Buffer* out);
+void RecordMutation(TF_Graph* graph, const TF_Operation& op,
+ const char* mutation_type);
+
} // end namespace tensorflow
#endif // TENSORFLOW_C_C_API_INTERNAL_H_
diff --git a/tensorflow/c/eager/tape.h b/tensorflow/c/eager/tape.h
index f52248e7d5..191e9c3413 100644
--- a/tensorflow/c/eager/tape.h
+++ b/tensorflow/c/eager/tape.h
@@ -161,7 +161,7 @@ class GradientTape {
// the tape refer to it); to aid in tape garbage collection.
std::unordered_map<int64, int64> tensor_usage_;
- // If true, all activations are deleted in the first call to ComputeGradient.
+ // If false, all activations are deleted in the first call to ComputeGradient.
// Else, only when this is destructed.
bool persistent_;
};
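
The corrected comment now matches the field it documents: when persistent_ is false, the tape frees its saved activations on the first call to ComputeGradient, so gradients can be computed only once. A toy Python sketch of those semantics, purely illustrative:

    class Tape(object):
      def __init__(self, persistent=False):
        self.persistent = persistent
        self.activations = {"x": 3.0}        # stand-in for recorded tensors

      def compute_gradient(self):
        if self.activations is None:
          raise RuntimeError("non-persistent tape was already used")
        grad = 2.0 * self.activations["x"]   # stand-in for d(x**2)/dx
        if not self.persistent:
          self.activations = None            # deleted on the first call
        return grad

    tape = Tape(persistent=False)
    tape.compute_gradient()    # ok
    # tape.compute_gradient()  # would raise; persistent=True would allow reuse
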
diff --git a/tensorflow/c/python_api.cc b/tensorflow/c/python_api.cc
index ba5a9268b4..37629a74ba 100644
--- a/tensorflow/c/python_api.cc
+++ b/tensorflow/c/python_api.cc
@@ -22,6 +22,7 @@ namespace tensorflow {
void AddControlInput(TF_Graph* graph, TF_Operation* op, TF_Operation* input) {
mutex_lock l(graph->mu);
graph->graph.AddControlEdge(&input->node, &op->node);
+ RecordMutation(graph, *op, "adding control input");
}
void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
@@ -36,11 +37,13 @@ void SetAttr(TF_Graph* graph, TF_Operation* op, const char* attr_name,
mutex_lock l(graph->mu);
op->node.AddAttr(attr_name, attr_val);
+ RecordMutation(graph, *op, "setting attribute");
}
void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
mutex_lock l(graph->mu);
op->node.set_requested_device(device);
+ RecordMutation(graph, *op, "setting device");
}
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
@@ -75,6 +78,13 @@ void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
+
+ if (status->status.ok()) {
+ // This modification only updates the destination node for
+ // the purposes of running this graph in a session. Thus, we don't
+ // record the source node as being modified.
+ RecordMutation(graph, *dst.oper, "updating input tensor");
+ }
}
} // namespace tensorflow
diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc
index ae22f7edc4..28ac40df18 100644
--- a/tensorflow/compiler/aot/codegen.cc
+++ b/tensorflow/compiler/aot/codegen.cc
@@ -418,7 +418,7 @@ namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void {{ENTRY}}(
void* result, const xla::ExecutableRunOptions* run_options,
- const void** args, void** temps);
+ const void** args, void** temps, tensorflow::int64* profile_counters);
{{NS_START}}
// {{CLASS}} represents a computation previously specified in a
@@ -483,7 +483,7 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}
- {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
+ {{CLASS}}(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
{{CLASS}}(const {{CLASS}}&) = delete;
@@ -496,8 +496,8 @@ class {{CLASS}} : public tensorflow::XlaCompiledCpuFunction {
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an effect. Must be
- // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
- // to set the argument buffers.
+ // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
+ // argument, to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
diff --git a/tensorflow/compiler/aot/codegen_test_h.golden b/tensorflow/compiler/aot/codegen_test_h.golden
index 65f342ce27..cf01bee325 100644
--- a/tensorflow/compiler/aot/codegen_test_h.golden
+++ b/tensorflow/compiler/aot/codegen_test_h.golden
@@ -19,7 +19,7 @@ namespace xla { class ExecutableRunOptions; }
// (Implementation detail) Entry point to the function in the object file.
extern "C" void entry_point(
void* result, const xla::ExecutableRunOptions* run_options,
- const void** args, void** temps);
+ const void** args, void** temps, tensorflow::int64* profile_counters);
namespace foo {
namespace bar {
@@ -86,7 +86,7 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
return *kStaticData;
}
- MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS)
+ MyClass(AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS)
: XlaCompiledCpuFunction(StaticData(), alloc_mode) {}
MyClass(const MyClass&) = delete;
@@ -99,8 +99,8 @@ class MyClass : public tensorflow::XlaCompiledCpuFunction {
// void set_argN_data(void* data)
// Sets the buffer of type T for positional argument N. May be called in
// any AllocMode. Must be called before Run to have an effect. Must be
- // called in AllocMode::RESULTS_AND_TEMPS_ONLY for each positional argument,
- // to set the argument buffers.
+ // called in AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY for each positional
+ // argument, to set the argument buffers.
//
// T* argN_data()
// Returns the buffer of type T for positional argument N.
diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc
index 6b037f276a..413efd9cea 100644
--- a/tensorflow/compiler/aot/tests/tfcompile_test.cc
+++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc
@@ -70,7 +70,7 @@ TEST(TFCompileTest, Add) {
// Run tests that use set_argN_data separately, to avoid accidentally re-using
// non-existent buffers.
TEST(TFCompileTest, Add_SetArg) {
- AddComp add(AddComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
+ AddComp add(AddComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
int32 arg_x = 10;
int32 arg_y = 32;
@@ -258,7 +258,7 @@ TEST(TFCompileTest, MatMul2_SetArg) {
Eigen::ThreadPoolDevice device(&tp, tp.NumThreads());
foo::bar::MatMulComp matmul(
- foo::bar::MatMulComp::AllocMode::RESULTS_AND_TEMPS_ONLY);
+ foo::bar::MatMulComp::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY);
matmul.set_thread_pool(&device);
// Test using the set_argN_data() methods.
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
index 74c9791f5e..aceedeb823 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc
@@ -210,6 +210,13 @@ Status FindCompilationCandidates(
!IsCompilableWhile(*node, jit_device_type, 0, lib_runtime)) {
continue;
}
+ // _Retval nodes in a top-level function represent fetches.
+ // Do not compile them.
+ if (node->type_string() == "_Retval") {
+ VLOG(2) << "Compilation rejected node: return value " << node->name()
+ << ": " << node->type_string();
+ continue;
+ }
candidates->insert(node);
}
return Status::OK();
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
index b3d258aea1..454f0aeae9 100644
--- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
+++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc
@@ -525,5 +525,32 @@ TEST(XlaCompilationTest, IllegalCycle_UsefulErrorMessage) {
"+-- c\n"));
}
+TEST(XlaCompilationTest, Retval) {
+ std::unique_ptr<Graph> graph(new Graph(OpRegistry::Global()));
+ GraphDef graphdef;
+ {
+ GraphDefBuilder builder(GraphDefBuilder::kFailImmediately);
+ Node* a = ops::SourceOp("Const", builder.opts()
+ .WithName("A")
+ .WithAttr("dtype", DT_FLOAT)
+ .WithAttr("value", Tensor()));
+ Node* b = ops::UnaryOp("Relu", a, builder.opts().WithName("B"));
+ ops::UnaryOp("_Retval", b,
+ builder.opts()
+ .WithName("R")
+ .WithAttr("T", DT_FLOAT)
+ .WithAttr("index", 0));
+
+ TF_EXPECT_OK(builder.ToGraph(graph.get()));
+ }
+
+ TF_ASSERT_OK(MarkForCompilation(&graph));
+ auto clusters = GetClusters(*graph);
+
+ EXPECT_EQ(2, clusters.size());
+ EXPECT_TRUE(clusters.find("R") == clusters.cend());
+ EXPECT_EQ(clusters["A"], clusters["B"]);
+}
+
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index 6cad2b0824..fff1a7f57b 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -417,6 +417,20 @@ tf_xla_py_test(
)
tf_xla_py_test(
+ name = "scan_ops_test",
+ size = "small",
+ srcs = ["scan_ops_test.py"],
+ tags = ["optonly"],
+ deps = [
+ ":xla_test",
+ "//tensorflow/python:array_ops",
+ "//tensorflow/python:framework_for_generated_wrappers",
+ "//tensorflow/python:math_ops",
+ "//tensorflow/python:platform_test",
+ ],
+)
+
+tf_xla_py_test(
name = "segment_reduction_ops_test",
size = "medium",
srcs = ["segment_reduction_ops_test.py"],
diff --git a/tensorflow/compiler/tests/categorical_op_test.py b/tensorflow/compiler/tests/categorical_op_test.py
index 5e06f9a724..035cdea178 100644
--- a/tensorflow/compiler/tests/categorical_op_test.py
+++ b/tensorflow/compiler/tests/categorical_op_test.py
@@ -35,6 +35,9 @@ from tensorflow.python.platform import googletest
class CategoricalTest(XLATestCase):
"""Test cases for random-number generating operators."""
+ def output_dtypes(self):
+ return set(self.int_types).intersection([np.int32, np.int64])
+
def _chi2(self, expected, actual):
"""Returns Chi2 GOF statistic."""
actual = np.asarray(actual)
@@ -55,7 +58,8 @@ class CategoricalTest(XLATestCase):
"""
with self.test_session() as sess, self.test_scope():
random_seed.set_random_seed(1618)
- op = random_ops.multinomial(logits, num_samples)
+ op = random_ops.multinomial(logits, num_samples,
+ output_dtype=dtypes.int32)
d = sess.run(op)
batch_size, num_classes = logits.shape
@@ -73,11 +77,11 @@ class CategoricalTest(XLATestCase):
return freqs_mat
- def _testRngIsNotConstant(self, rng, dtype):
+ def _testRngIsNotConstant(self, rng, dtype, output_dtype):
# Tests that 'rng' does not always return the same value.
with self.test_session() as sess:
with self.test_scope():
- x = rng(dtype)
+ x = rng(dtype, output_dtype)
# The random-number generator, if working correctly, should produce the
# same output multiple times with low probability.
@@ -92,21 +96,25 @@ class CategoricalTest(XLATestCase):
(not np.array_equal(y, w)))
def testCategoricalIsNotConstant(self):
- def rng(unused_dtype):
- return random_ops.multinomial([[1., 1., 1.]], 10)
+ def rng(dtype, output_dtype):
+ return random_ops.multinomial(np.array([[1., 1., 1.]], dtype=dtype), 10,
+ output_dtype=output_dtype)
- dtype = dtypes.float32
- self._testRngIsNotConstant(rng, dtype)
+ dtype = np.float32
+ for output_dtype in self.output_dtypes():
+ self._testRngIsNotConstant(rng, dtype, output_dtype)
def testCategoricalIsInRange(self):
- for dtype in [dtypes.float32, dtypes.float64]:
- with self.test_session() as sess:
- with self.test_scope():
- x = random_ops.multinomial(
- array_ops.ones(shape=[1, 20], dtype=dtype), 1000)
- y = sess.run(x)
- self.assertTrue((y >= 0).sum() == 1000)
- self.assertTrue((y < 20).sum() == 1000)
+ for dtype in self.float_types:
+ for output_dtype in self.output_dtypes():
+ with self.test_session() as sess:
+ with self.test_scope():
+ x = random_ops.multinomial(
+ array_ops.ones(shape=[1, 20], dtype=dtype), 1000,
+ output_dtype=output_dtype)
+ y = sess.run(x)
+ self.assertTrue((y >= 0).sum() == 1000)
+ self.assertTrue((y < 20).sum() == 1000)
def testSamplingCorrectness(self):
np.random.seed(1618) # Make it reproducible.
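
The testSamplingCorrectness case relies on the _chi2 helper defined earlier in this file. A rough numpy sketch of that goodness-of-fit check; the sampler and the 1e-2 threshold below are illustrative stand-ins, not the test's actual op or tolerance:

    import numpy as np

    def chi2(expected, actual):
      """Chi-square goodness-of-fit statistic over frequency vectors."""
      expected = np.asarray(expected, dtype=np.float64)
      actual = np.asarray(actual, dtype=np.float64)
      diff = actual - expected
      return np.sum(diff * diff / expected)

    np.random.seed(1618)
    probs = np.array([0.25, 0.25, 0.5])
    samples = np.random.choice(3, size=10000, p=probs)  # stand-in for multinomial
    freqs = np.bincount(samples, minlength=3) / 10000.0
    assert chi2(probs, freqs) < 1e-2  # loose, illustrative threshold
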
diff --git a/tensorflow/compiler/tests/scan_ops_test.py b/tensorflow/compiler/tests/scan_ops_test.py
new file mode 100644
index 0000000000..3260e63b23
--- /dev/null
+++ b/tensorflow/compiler/tests/scan_ops_test.py
@@ -0,0 +1,229 @@
+# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Functional tests for scan ops."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import errors_impl
+from tensorflow.python.framework import ops
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import math_ops
+from tensorflow.python.platform import test
+
+
+def numpy_reverse(x, axis):
+ length = len(x.shape)
+ if axis < 0:
+ axis = length + axis
+
+ ix = [
+ slice(None, None, -1) if i == axis else slice(None) for i in range(length)
+ ]
+ return x[ix]
+
+
+def handle_options(func, x, axis, exclusive, reverse):
+ """Adds tf options to numpy scan ops."""
+ length = len(x.shape)
+ if axis < 0:
+ axis = length + axis
+
+ if reverse:
+ x = numpy_reverse(x, axis)
+
+ if exclusive:
+ ix_head = [slice(0, 1) if i == axis else slice(None) for i in range(length)]
+ ix_init = [
+ slice(0, -1) if i == axis else slice(None) for i in range(length)
+ ]
+ if func == np.cumsum:
+ init = np.zeros_like(x[ix_head])
+ elif func == np.cumprod:
+ init = np.ones_like(x[ix_head])
+ else:
+ raise ValueError("Unknown scan function.")
+ x = np.concatenate([init, func(x[ix_init], axis)], axis=axis)
+ else:
+ x = func(x, axis=axis)
+
+ if reverse:
+ x = numpy_reverse(x, axis)
+ return x
+
+
+class CumsumTest(XLATestCase):
+
+ valid_dtypes = [np.float32]
+
+ def axis_dtypes(self):
+ return set(self.int_types).intersection([np.int32, np.int64])
+
+ def _compare(self, x, axis, exclusive, reverse):
+ np_out = handle_options(np.cumsum, x, axis, exclusive, reverse)
+ with self.test_session(), self.test_scope():
+ p = array_ops.placeholder(x.dtype)
+ tf_out = math_ops.cumsum(p, axis, exclusive, reverse).eval(
+ feed_dict={p: x})
+
+ self.assertAllClose(np_out, tf_out)
+
+ def _compareAll(self, x, axis):
+ for exclusive in [True, False]:
+ for reverse in [True, False]:
+ self._compare(x, axis, exclusive, reverse)
+
+ def testEmpty(self):
+ for dtype in self.valid_dtypes:
+ x = np.zeros([0]).astype(dtype)
+ for axis in (-1, 0):
+ self._compareAll(x, axis)
+
+ def testAxisType(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis_dtype in self.axis_dtypes():
+ with self.test_session(), self.test_scope():
+ p = array_ops.placeholder(x.dtype)
+ axis = constant_op.constant(0, axis_dtype)
+ math_ops.cumsum(p, axis).eval(feed_dict={p: x})
+
+ def test1D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis in (-1, 0):
+ self._compareAll(x, axis)
+
+ def test2D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(0, 10).reshape([2, 5]).astype(dtype)
+ for axis in (-2, -1, 0, 1):
+ self._compareAll(x, axis)
+
+ def test3D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(0, 20).reshape([2, 2, 5]).astype(dtype)
+ for axis in (-3, -2, -1, 0, 1, 2):
+ self._compareAll(x, axis)
+
+ def test6D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
+ for axis in range(-6, 6, 3):
+ self._compareAll(x, axis)
+
+ def testInvalidAxis(self):
+ x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
+ with self.test_session(), self.test_scope():
+ input_tensor = ops.convert_to_tensor(x)
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
+ math_ops.cumsum(input_tensor, -3).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
+ math_ops.cumsum(input_tensor, 2).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "axis must be a scalar" in str(e)):
+ math_ops.cumsum(input_tensor, [0]).eval()
+
+
+class CumprodTest(XLATestCase):
+
+ valid_dtypes = [np.float32]
+
+ def axis_dtypes(self):
+ return set(self.int_types).intersection([np.int32, np.int64])
+
+ def _compare(self, x, axis, exclusive, reverse):
+ np_out = handle_options(np.cumprod, x, axis, exclusive, reverse)
+ with self.test_session(), self.test_scope():
+ p = array_ops.placeholder(x.dtype)
+ prod = math_ops.cumprod(p, axis, exclusive, reverse)
+ tf_out = prod.eval(feed_dict={p: x})
+
+ self.assertAllClose(np_out, tf_out)
+
+ def _compareAll(self, x, axis):
+ for exclusive in [True, False]:
+ for reverse in [True, False]:
+ self._compare(x, axis, exclusive, reverse)
+
+ def testEmpty(self):
+ for dtype in self.valid_dtypes:
+ x = np.zeros([0]).astype(dtype)
+ for axis in (-1, 0):
+ self._compareAll(x, axis)
+
+ def testAxisType(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis_dtype in self.axis_dtypes():
+ with self.test_session(), self.test_scope():
+ p = array_ops.placeholder(x.dtype)
+ axis = constant_op.constant(0, axis_dtype)
+ math_ops.cumprod(p, axis).eval(feed_dict={p: x})
+
+ def test1D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 6).reshape([5]).astype(dtype)
+ for axis in (-1, 0):
+ self._compareAll(x, axis)
+
+ def test2D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 11).reshape([2, 5]).astype(dtype)
+ for axis in (-2, -1, 0, 1):
+ self._compareAll(x, axis)
+
+ def test3D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 21).reshape([2, 2, 5]).astype(dtype)
+ for axis in (-3, -2, -1, 0, 1, 2):
+ self._compareAll(x, axis)
+
+ def test6D(self):
+ for dtype in self.valid_dtypes:
+ x = np.arange(1, 145).reshape([2, 2, 3, 3, 2, 2]).astype(dtype)
+ for axis in range(-6, 6, 3):
+ self._compareAll(x, axis)
+
+ def testInvalidAxis(self):
+ x = np.arange(0, 10).reshape([2, 5]).astype(np.float32)
+ with self.test_session(), self.test_scope():
+ input_tensor = ops.convert_to_tensor(x)
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
+ math_ops.cumprod(input_tensor, -3).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "Expected scan axis in the range [-2, 2)" in str(e)):
+ math_ops.cumprod(input_tensor, 2).eval()
+ with self.assertRaisesWithPredicateMatch(
+ errors_impl.InvalidArgumentError,
+ lambda e: "axis must be a scalar" in str(e)):
+ math_ops.cumprod(input_tensor, [0]).eval()
+
+
+if __name__ == "__main__":
+ test.main()
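
The handle_options helper in the new test above reduces TensorFlow's exclusive and reverse scan options to plain numpy. For reference, the three cumsum variants on a small vector, runnable without TensorFlow:

    import numpy as np

    x = np.array([1., 2., 3., 4.])
    print(np.cumsum(x))                               # inclusive: [ 1.  3.  6. 10.]
    print(np.concatenate([[0.], np.cumsum(x[:-1])]))  # exclusive: [0. 1. 3. 6.]
    print(np.cumsum(x[::-1])[::-1])                   # reverse:   [10.  9.  7.  4.]
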
diff --git a/tensorflow/compiler/tf2xla/const_analysis.cc b/tensorflow/compiler/tf2xla/const_analysis.cc
index d57273d844..6a1a5467e0 100644
--- a/tensorflow/compiler/tf2xla/const_analysis.cc
+++ b/tensorflow/compiler/tf2xla/const_analysis.cc
@@ -52,6 +52,8 @@ Status BackwardsConstAnalysis(const Graph& g,
{"Conv2DBackpropInput", "input_sizes"},
{"Conv3DBackpropFilterV2", "filter_sizes"},
{"Conv3DBackpropInputV2", "input_sizes"},
+ {"Cumprod", "axis"},
+ {"Cumsum", "axis"},
{"DepthwiseConv2dNativeBackpropFilter", "filter_sizes"},
{"DepthwiseConv2dNativeBackpropInput", "input_sizes"},
{"DynamicStitch", "indices"},
diff --git a/tensorflow/compiler/tf2xla/dump_graph.cc b/tensorflow/compiler/tf2xla/dump_graph.cc
index ddd912b873..03603ee9ba 100644
--- a/tensorflow/compiler/tf2xla/dump_graph.cc
+++ b/tensorflow/compiler/tf2xla/dump_graph.cc
@@ -63,7 +63,12 @@ string MakeUniquePath(string name) {
string DumpGraphDefToFile(const string& name, GraphDef const& graph_def) {
string path = MakeUniquePath(name);
- TF_CHECK_OK(WriteTextProto(Env::Default(), path, graph_def));
+ Status status = WriteTextProto(Env::Default(), path, graph_def);
+ if (!status.ok()) {
+ VLOG(1) << "Failed to dump GraphDef to file: " << path << " : " << status;
+ path.clear();
+ path = "(unavailable)";
+ }
return path;
}
@@ -79,7 +84,13 @@ string DumpGraphToFile(const string& name, Graph const& graph,
string DumpFunctionDefToFile(const string& name, FunctionDef const& fdef) {
string path = MakeUniquePath(name);
- TF_CHECK_OK(WriteTextProto(Env::Default(), path, fdef));
+ Status status = WriteTextProto(Env::Default(), path, fdef);
+ if (!status.ok()) {
+ VLOG(1) << "Failed to dump FunctionDef to file: " << path << " : "
+ << status;
+ path.clear();
+ path = "(unavailable)";
+ }
return path;
}
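
The dump_graph change swaps a fatal TF_CHECK_OK for a logged warning plus an "(unavailable)" sentinel, so a failed debug dump no longer kills the process. The same degrade-gracefully pattern in a hedged Python sketch (names illustrative, not the TensorFlow API):

    def dump_to_file(path, payload):
      """Returns the path on success, or a sentinel instead of crashing."""
      try:
        with open(path, "w") as f:
          f.write(payload)
      except OSError as err:
        print("Failed to dump to file: %s : %s" % (path, err))
        return "(unavailable)"
      return path
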
diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
index 5726d8294a..267268298c 100644
--- a/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
+++ b/tensorflow/compiler/tf2xla/functionalize_control_flow.cc
@@ -1067,6 +1067,10 @@ FunctionalizeCond::CreateCorrespondingMergeCluster(Cluster* switch_cluster) {
enqueue_or_update_merge(out);
}
}
+ // Return if there are no merge nodes.
+ if (merges.empty()) {
+ return gtl::nullopt;
+ }
auto it = merges.begin();
Cluster* merge_cluster = *it;
for (++it; it != merges.end(); ++it) {
diff --git a/tensorflow/compiler/tf2xla/kernels/BUILD b/tensorflow/compiler/tf2xla/kernels/BUILD
index 6302fece1f..a1720ff919 100644
--- a/tensorflow/compiler/tf2xla/kernels/BUILD
+++ b/tensorflow/compiler/tf2xla/kernels/BUILD
@@ -54,6 +54,7 @@ tf_kernel_library(
"reshape_op.cc",
"retval_op.cc",
"reverse_op.cc",
+ "scan_ops.cc",
"segment_reduction_ops.cc",
"select_op.cc",
"sendrecv_ops.cc",
diff --git a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
index 248e9d111e..468af34aab 100644
--- a/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/batch_norm_op.cc
@@ -14,7 +14,7 @@ limitations under the License.
==============================================================================*/
// XLA implementation of BatchNorm operations.
-#include "tensorflow/compiler/tf2xla/literal_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
@@ -42,27 +42,44 @@ class FusedBatchNormOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
+ xla::PrimitiveType input_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
+ xla::PrimitiveType stats_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(ctx->input_type(1), &stats_type));
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ xla::ComputationDataHandle input = ctx->Input(0);
+
+ // TODO(b/69928690): support mixed precision in the XLA batch normalization
+ // operators. As a workaround, cast everything to the statistics type (which
+ // may be more precise than the input type).
+ input = builder->ConvertElementType(input, stats_type);
+
if (is_training_) {
- xla::ComputationDataHandle output = ctx->builder()->BatchNormTraining(
- ctx->Input(0), ctx->Input(1), ctx->Input(2), epsilon_,
- feature_index_);
+ xla::ComputationDataHandle output = builder->BatchNormTraining(
+ input, ctx->Input(1), ctx->Input(2), epsilon_, feature_index_);
// In training mode, outputs the normalized value as well as the
// calculated mean and variance.
- for (int i = 0; i < 3; i++) {
- ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
- }
+ ctx->SetOutput(0, builder->ConvertElementType(
+ builder->GetTupleElement(output, 0), input_type));
+ ctx->SetOutput(1, builder->GetTupleElement(output, 1));
+ ctx->SetOutput(2, builder->GetTupleElement(output, 2));
+
// Output 3 and 4 for "FusedBatchNorm" are currently marked as "reserved
// space 1 & 2". They are used to pass the per-batch mean and
// variance to the gradient. Here we maintain the same behavior by setting
// them to the mean and variance calculated by BatchNormTraining.
- ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
- ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+ ctx->SetOutput(3, builder->GetTupleElement(output, 1));
+ ctx->SetOutput(4, builder->GetTupleElement(output, 2));
} else {
- xla::ComputationDataHandle output = ctx->builder()->BatchNormInference(
- ctx->Input(0), ctx->Input(1), ctx->Input(2), ctx->Input(3),
- ctx->Input(4), epsilon_, feature_index_);
- ctx->SetOutput(0, output);
+ xla::ComputationDataHandle output = builder->BatchNormInference(
+ input, ctx->Input(1), ctx->Input(2), ctx->Input(3), ctx->Input(4),
+ epsilon_, feature_index_);
+ ctx->SetOutput(0, builder->ConvertElementType(output, input_type));
// Directly send input to output as mean and variance in inference mode.
ctx->SetOutput(1, ctx->Input(3));
ctx->SetOutput(2, ctx->Input(4));
@@ -78,6 +95,7 @@ class FusedBatchNormOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("FusedBatchNorm"), FusedBatchNormOp);
+REGISTER_XLA_OP(Name("FusedBatchNormV2"), FusedBatchNormOp);
class FusedBatchNormGradOp : public XlaOpKernel {
public:
@@ -101,19 +119,36 @@ class FusedBatchNormGradOp : public XlaOpKernel {
}
void Compile(XlaOpKernelContext* ctx) override {
+ xla::ComputationBuilder* builder = ctx->builder();
+
auto grad_output = ctx->Input(0);
auto activation = ctx->Input(1);
auto scale = ctx->Input(2);
auto mean = ctx->Input(3);
auto var = ctx->Input(4);
- xla::ComputationDataHandle output = ctx->builder()->BatchNormGrad(
+
+ xla::PrimitiveType input_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(ctx->input_type(0), &input_type));
+ xla::PrimitiveType stats_type;
+ OP_REQUIRES_OK(ctx,
+ DataTypeToPrimitiveType(ctx->input_type(3), &stats_type));
+
+ // TODO(b/69928690): support mixed precision in the XLA batch normalization
+ // operators. As a workaround, cast everything to the statistics type (which
+ // may be more precise than the input type).
+ grad_output = builder->ConvertElementType(grad_output, stats_type);
+ activation = builder->ConvertElementType(activation, stats_type);
+
+ xla::ComputationDataHandle output = builder->BatchNormGrad(
activation, scale, mean, var, grad_output, epsilon_, feature_index_);
- for (int i = 0; i < 3; i++) {
- ctx->SetOutput(i, ctx->builder()->GetTupleElement(output, i));
- }
- ctx->SetOutput(3, ctx->builder()->GetTupleElement(output, 1));
- ctx->SetOutput(4, ctx->builder()->GetTupleElement(output, 2));
+ ctx->SetOutput(0, builder->ConvertElementType(
+ builder->GetTupleElement(output, 0), input_type));
+ ctx->SetOutput(1, builder->GetTupleElement(output, 1));
+ ctx->SetOutput(2, builder->GetTupleElement(output, 2));
+ ctx->SetOutput(3, builder->GetTupleElement(output, 1));
+ ctx->SetOutput(4, builder->GetTupleElement(output, 2));
}
private:
@@ -122,6 +157,7 @@ class FusedBatchNormGradOp : public XlaOpKernel {
};
REGISTER_XLA_OP(Name("FusedBatchNormGrad"), FusedBatchNormGradOp);
+REGISTER_XLA_OP(Name("FusedBatchNormGradV2"), FusedBatchNormGradOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
index c5017704e2..aaddbe811c 100644
--- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc
@@ -46,72 +46,130 @@ TensorShape ExpandedFilterShapeForDepthwiseConvolution(
return expanded_shape;
}
+// Broadcast zeros to ExpandedFilterShapeForDepthwiseConvolution.
+xla::ComputationDataHandle CreateExpandedZero(
+ const TensorShape& filter_shape, DataType dtype,
+ xla::ComputationBuilder* builder) {
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ return builder->Broadcast(XlaHelpers::Zero(builder, dtype),
+ expanded_filter_shape.dim_sizes());
+}
+
+// Create a mask for depthwise convolution that will make a normal convolution
+// produce the same results as a depthwise convolution. For a [2, 2, 3, 2]
+// depthwise filter this returns a [2, 2, 3, 6] tensor
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// 1 1 0 0 0 0 1 1 0 0 0 0
+// 0 0 1 1 0 0 0 0 1 1 0 0
+// 0 0 0 0 1 1 0 0 0 0 1 1
+//
+// The first step is to create an iota tensor, A, of shape [3]
+// 0 1 2
+//
+// and another iota tensor, B, of shape [3 * 2]
+// 0 1 2 3 4 5
+//
+// and divide B by 2 to get
+// 0 0 1 1 2 2
+//
+// then we broadcast B to [2, 2, 3, 3 * 2]
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+// 0 0 1 1 2 2 0 0 1 1 2 2
+//
+// Finally we compare A with the broadcasted B in dimension 2 and return the
+// mask shown at the beginning of the comment.
+xla::ComputationDataHandle CreateExpandedFilterMask(
+ const TensorShape& filter_shape, xla::ComputationBuilder* builder) {
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
+ int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
+
+ // Create an M-sized linspace and an M*N-sized linspace that will be
+ // broadcast into perpendicular dimensions and compared.
+ xla::ComputationDataHandle input_feature_iota;
+ // DT_INT32 Iota will always return status::OK().
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32, input_feature,
+ &input_feature_iota));
+ xla::ComputationDataHandle expanded_feature_iota;
+ TF_CHECK_OK(XlaHelpers::Iota(builder, DataType::DT_INT32,
+ input_feature * depthwise_multiplier,
+ &expanded_feature_iota));
+
+ // Divide the M*N sized linspace by the depthwise_multiplier to create
+ // [0 0 1 1 2 2] in the example in the function comment.
+ expanded_feature_iota =
+ builder->Div(expanded_feature_iota,
+ XlaHelpers::IntegerLiteral(builder, DataType::DT_INT32,
+ depthwise_multiplier));
+
+ // Broadcast the M*N-sized linspace to [H, W, ..., M, M*N].
+ auto expanded_feature_broadcast_dims = expanded_filter_shape.dim_sizes();
+ expanded_feature_broadcast_dims.pop_back();
+ auto broadcasted_expanded_feature_iota = builder->Broadcast(
+ expanded_feature_iota, expanded_feature_broadcast_dims);
+
+ // Compare the broadcasted linspace to the input feature linspace in the
+ // input feature dimension to create a diagonal predicate.
+ return builder->Eq(broadcasted_expanded_feature_iota, input_feature_iota,
+ {expanded_filter_shape.dims() - 2});
+}
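// For illustration: a standalone sketch (plain C++, not the XLA API) of the
// iota/divide/compare construction above, with M = 3 input features and
// N = 2 depthwise multiplier. It prints the 3 x 6 diagonal mask from the
// function comment; shapes and broadcasting are reduced to two loops.
#include <iostream>

int main() {
  const int input_feature = 3;         // M
  const int depthwise_multiplier = 2;  // N
  for (int a = 0; a < input_feature; ++a) {
    for (int b = 0; b < input_feature * depthwise_multiplier; ++b) {
      // Eq(B / depthwise_multiplier, A) along the input feature dimension.
      std::cout << ((b / depthwise_multiplier == a) ? 1 : 0) << " ";
    }
    std::cout << "\n";
  }
}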
+
// Expands a filter of shape [H, W, ..., M, N] to [H, W, ..., M, M*N] by adding
// zeros for the cross-depth filters. Used to build a depthwise convolution.
xla::ComputationDataHandle ExpandFilterForDepthwiseConvolution(
const TensorShape& filter_shape, DataType dtype,
const xla::ComputationDataHandle& filter,
xla::ComputationBuilder* builder) {
- // Filter has shape [H, W, ..., M, N]
- // Dilate to [H, W, ..., M*M, N] using M inter-element padding, and then
- // reshape to [H, W, ..., M, M*N].
- int num_spatial_dims = filter_shape.dims() - 2;
- const int64 in_depth = filter_shape.dim_size(num_spatial_dims);
- xla::PaddingConfig padding = xla::MakeNoPaddingConfig(filter_shape.dims());
- padding.mutable_dimensions(num_spatial_dims)->set_interior_padding(in_depth);
- auto dilated_filter =
- builder->Pad(filter, XlaHelpers::Zero(builder, dtype), padding);
-
+ int64 depthwise_multiplier = filter_shape.dim_size(filter_shape.dims() - 1);
+ int64 input_feature = filter_shape.dim_size(filter_shape.dims() - 2);
TensorShape expanded_filter_shape =
ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
- return builder->Reshape(dilated_filter, expanded_filter_shape.dim_sizes());
+
+ // Create a [H, W, ..., 1, M*N] reshape of the filter.
+ TensorShape implicit_broadcast_filter_shape = expanded_filter_shape;
+ implicit_broadcast_filter_shape.set_dim(
+ implicit_broadcast_filter_shape.dims() - 2, 1);
+ implicit_broadcast_filter_shape.set_dim(
+ implicit_broadcast_filter_shape.dims() - 1,
+ depthwise_multiplier * input_feature);
+ auto implicit_broadcast_filter =
+ builder->Reshape(filter, implicit_broadcast_filter_shape.dim_sizes());
+
+ // Broadcast the filter to [H, W, ..., M, M*N].
+ auto expanded_zero = CreateExpandedZero(filter_shape, dtype, builder);
+ auto expanded_filter = builder->Add(implicit_broadcast_filter, expanded_zero);
+
+ // Where the filter mask is true, choose the broadcasted filter; otherwise
+ // choose zero.
+ return builder->Select(CreateExpandedFilterMask(filter_shape, builder),
+ expanded_filter, expanded_zero);
}
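// For illustration: a standalone sketch (plain C++, not the XLA API) of the
// expansion above for a trailing [M, N] filter block with M = 3, N = 2. The
// block is reshaped to [1, M*N], implicitly broadcast along the M dimension,
// and zeroed outside the diagonal blocks, which is exactly the [M, M*N]
// filter a normal convolution needs to emulate a depthwise one.
#include <iostream>

int main() {
  const int M = 3, N = 2;
  const double filter[3][2] = {{1, 2}, {3, 4}, {5, 6}};
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < M * N; ++j) {
      // Row-major reshape to [1, M*N], broadcast over i, then masked.
      const double expanded = (j / N == i) ? filter[j / N][j % N] : 0.0;
      std::cout << expanded << " ";
    }
    std::cout << "\n";
  }
}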
// Inverse of ExpandFilterForDepthwiseConvolution.
xla::ComputationDataHandle ContractFilterForDepthwiseBackprop(
- const TensorShape& filter_shape, DataType dtype,
+ XlaOpKernelContext* ctx, const TensorShape& filter_shape, DataType dtype,
const xla::ComputationDataHandle& filter_backprop,
xla::ComputationBuilder* builder) {
- int num_spatial_dims = filter_shape.dims() - 2;
-
- // Reshape to [H, W, ..., M*M, N]
- TensorShape shape = filter_shape;
- int64 in_depth = filter_shape.dim_size(num_spatial_dims);
- shape.set_dim(num_spatial_dims, in_depth * in_depth);
- auto reshaped = builder->Reshape(filter_backprop, shape.dim_sizes());
-
- std::vector<int64> zeros(filter_shape.dims());
- std::vector<int64> strides(filter_shape.dims(), 1LL);
- strides[num_spatial_dims] = in_depth + 1;
- return builder->Slice(reshaped, zeros, shape.dim_sizes(), strides);
-
- // Alternate implementation for backends without strided Slice() support.
- // TODO(phawkins): Remove when all backends support strided slice.
- // // Pad [..., M * (M + 1), N]
- // xla::PaddingConfig config =
- // xla::MakeNoPaddingConfig(filter_shape.dims());
- // config.mutable_dimensions(num_spatial_dims)
- // ->set_edge_padding_high(in_depth);
- // auto zero = XlaHelpers::Zero(builder, dtype);
- // auto padded = builder->Pad(reshaped, zero, config);
- //
- // // Reshape to [..., M, M + 1, N]
- // shape = filter_shape;
- // shape.set_dim(num_spatial_dims, in_depth);
- // shape.set_dim(num_spatial_dims + 1, in_depth + 1);
- // int64 out_depth = filter_shape.dim_size(num_spatial_dims + 1);
- // shape.AddDim(out_depth);
- // reshaped = builder->Reshape(padded, shape.dim_sizes());
- //
- // // Slice to [..., M, 1, N]
- // std::vector<int64> zeros(shape.dims());
- // std::vector<int64> strides(shape.dims(), 1LL);
- // shape.set_dim(num_spatial_dims + 1, 1);
- // auto sliced = builder->Slice(reshaped, zeros, shape.dim_sizes(),
- // strides);
- //
- // // Reshape to [..., M, N]
- // return builder->Reshape(sliced, filter_shape.dim_sizes());
+ TensorShape expanded_filter_shape =
+ ExpandedFilterShapeForDepthwiseConvolution(filter_shape);
+ auto masked_expanded_filter = builder->Select(
+ CreateExpandedFilterMask(filter_shape, builder), filter_backprop,
+ CreateExpandedZero(filter_shape, dtype, builder));
+ return builder->Reshape(
+ builder->Reduce(masked_expanded_filter, XlaHelpers::Zero(builder, dtype),
+ *ctx->GetOrCreateAdd(dtype),
+ {expanded_filter_shape.dims() - 2}),
+ filter_shape.dim_sizes());
}
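// For illustration: a standalone sketch (plain C++, not the XLA API) of the
// contraction above. The expanded backprop has trailing shape [M, M*N];
// masking keeps entry (i, j) only when j / N == i, and summing over the M
// dimension leaves M*N values that reshape back to the [M, N] filter.
#include <iostream>
#include <vector>

int main() {
  const int M = 3, N = 2;
  // A toy expanded filter backprop with trailing shape [M, M*N].
  std::vector<std::vector<double>> expanded(M, std::vector<double>(M * N));
  for (int i = 0; i < M; ++i)
    for (int j = 0; j < M * N; ++j) expanded[i][j] = 10 * i + j;
  for (int j = 0; j < M * N; ++j) {
    double sum = 0;
    for (int i = 0; i < M; ++i) {
      if (j / N == i) sum += expanded[i][j];  // select(mask, backprop, zero)
    }
    // Entry (j / N, j % N) of the contracted [M, N] filter backprop.
    std::cout << "filter[" << j / N << "][" << j % N << "] = " << sum << "\n";
  }
}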
class ConvOp : public XlaOpKernel {
@@ -121,6 +179,7 @@ class ConvOp : public XlaOpKernel {
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
@@ -144,6 +203,23 @@ class ConvOp : public XlaOpKernel {
errors::Unimplemented("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(ctx, dilations_.size() == num_dims(),
+ errors::InvalidArgument("Dilations field must "
+ "specify ",
+ num_dims(), " dimensions"));
+ OP_REQUIRES(
+ ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ for (int i = 0; i < num_spatial_dims_; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ OP_REQUIRES(
+ ctx, dilations_[input_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the ",
+ i, "th spatial dimension."));
+ }
+
const TensorShape input_shape = ctx->InputShape(0);
// Input filter is of the following dimensions:
// [ filter_rows, filter_cols, ..., in_depth, out_depth]
@@ -184,7 +260,7 @@ class ConvOp : public XlaOpKernel {
dims.set_input_feature_dimension(feature_dim);
dims.set_output_feature_dimension(feature_dim);
for (int i = 0; i < num_spatial_dims_; ++i) {
- int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ const int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
dims.add_input_spatial_dimensions(dim);
dims.add_kernel_spatial_dimensions(i);
dims.add_output_spatial_dimensions(dim);
@@ -204,6 +280,7 @@ class ConvOp : public XlaOpKernel {
protected:
const int num_spatial_dims_;
const bool depthwise_;
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
@@ -241,6 +318,7 @@ class ConvBackpropInputOp : public XlaOpKernel {
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
string data_format;
@@ -263,6 +341,23 @@ class ConvBackpropInputOp : public XlaOpKernel {
errors::Unimplemented("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(ctx, dilations_.size() == num_dims(),
+ errors::InvalidArgument("Dilations field must "
+ "specify ",
+ num_dims(), " dimensions"));
+ OP_REQUIRES(
+ ctx, dilations_[batch_dim] == 1 && dilations_[feature_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ for (int i = 0; i < num_spatial_dims_; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ OP_REQUIRES(
+ ctx, dilations_[input_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the ",
+ i, "th spatial dimension."));
+ }
+
TensorShape input_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &input_shape));
@@ -336,6 +431,7 @@ class ConvBackpropInputOp : public XlaOpKernel {
protected:
const int num_spatial_dims_;
const bool depthwise_;
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
@@ -373,6 +469,7 @@ class ConvBackpropFilterOp : public XlaOpKernel {
: XlaOpKernel(ctx),
num_spatial_dims_(num_spatial_dims),
depthwise_(depthwise) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("strides", &strides_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("padding", &padding_));
string data_format;
@@ -392,6 +489,23 @@ class ConvBackpropFilterOp : public XlaOpKernel {
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(ctx, dilations_.size() == num_dims(),
+ errors::InvalidArgument("Dilations field must "
+ "specify ",
+ num_dims(), " dimensions"));
+ OP_REQUIRES(
+ ctx, dilations_[n_dim] == 1 && dilations_[c_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ for (int i = 0; i < num_spatial_dims_; ++i) {
+ int input_dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
+ OP_REQUIRES(
+ ctx, dilations_[input_dim] == 1,
+ errors::Unimplemented("Current implementation does not yet support "
+ "dilations in the ",
+ i, "th spatial dimension."));
+ }
+
const TensorShape activations_shape = ctx->InputShape(0);
TensorShape filter_shape;
OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(1, &filter_shape));
@@ -426,9 +540,7 @@ class ConvBackpropFilterOp : public XlaOpKernel {
// Swap n_dim and c_dim in the activations.
dnums.set_input_batch_dimension(c_dim);
- dnums.set_output_batch_dimension(c_dim);
dnums.set_input_feature_dimension(n_dim);
- dnums.set_output_feature_dimension(n_dim);
// The gradients become the RHS of the convolution.
// The gradients have shape [batch, out_rows, out_cols, ..., out_depth]
@@ -440,11 +552,17 @@ class ConvBackpropFilterOp : public XlaOpKernel {
std::vector<int64> rhs_dilation(num_spatial_dims_);
std::vector<int64> ones(num_spatial_dims_, 1);
+ // Tensorflow filter shape is [ H, W, ..., inC, outC ].
+ for (int i = 0; i < num_spatial_dims_; ++i) {
+ dnums.add_output_spatial_dimensions(i);
+ }
+ dnums.set_output_batch_dimension(num_spatial_dims_);
+ dnums.set_output_feature_dimension(num_spatial_dims_ + 1);
+
for (int i = 0; i < num_spatial_dims_; ++i) {
int64 dim = GetTensorSpatialDimIndex(num_dims(), data_format_, i);
dnums.add_input_spatial_dimensions(dim);
dnums.add_kernel_spatial_dimensions(dim);
- dnums.add_output_spatial_dimensions(dim);
// We will also need to pad the input with zeros such that after the
// convolution, we get the right size for the filter.
@@ -501,31 +619,17 @@ class ConvBackpropFilterOp : public XlaOpKernel {
/*window_strides=*/ones, padding,
/*lhs_dilation=*/ones, rhs_dilation, dnums);
- // The layout of filter_backprop will match the layout of
- // padded_activations
- // and so will have layout: [out_feature, h, w, ..., in_feature]
- // Tensorflow filter shape is [ H, W, ..., inC, outC ], so we transpose the
- // output.
- std::vector<int64> transpose_dims;
- transpose_dims.reserve(num_dims());
- for (int i = 0; i < num_spatial_dims_; ++i) {
- transpose_dims.push_back(dnums.output_spatial_dimensions(i));
- }
- transpose_dims.push_back(c_dim);
- transpose_dims.push_back(n_dim);
- xla::ComputationDataHandle filter_backprop_reshaped =
- b->Transpose(filter_backprop, transpose_dims);
-
if (depthwise_) {
- filter_backprop_reshaped = ContractFilterForDepthwiseBackprop(
- filter_shape, ctx->input_type(0), filter_backprop_reshaped, b);
+ filter_backprop = ContractFilterForDepthwiseBackprop(
+ ctx, filter_shape, ctx->input_type(0), filter_backprop, b);
}
- ctx->SetOutput(0, filter_backprop_reshaped);
+ ctx->SetOutput(0, filter_backprop);
}
protected:
const int num_spatial_dims_;
const bool depthwise_;
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_ = FORMAT_NHWC;
diff --git a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
index fcef497e58..644abd5905 100644
--- a/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/matmul_op.cc
@@ -23,8 +23,8 @@ limitations under the License.
namespace tensorflow {
namespace {
-constexpr std::array<DataType, 4> kMatmulTypes = {
- {DT_HALF, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
+constexpr std::array<DataType, 5> kMatmulTypes = {
+ {DT_HALF, DT_BFLOAT16, DT_FLOAT, DT_DOUBLE, DT_COMPLEX64}};
class MatMulOp : public XlaOpKernel {
public:
@@ -85,10 +85,7 @@ class SparseMatMulOp : public MatMulOp {
~SparseMatMulOp() override = default;
};
-REGISTER_XLA_OP(Name("SparseMatMul")
- .TypeConstraint("Ta", kFloatTypes)
- .TypeConstraint("Tb", kFloatTypes),
- SparseMatMulOp);
+REGISTER_XLA_OP(Name("SparseMatMul"), SparseMatMulOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/scan_ops.cc b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
new file mode 100644
index 0000000000..650f8c7dc8
--- /dev/null
+++ b/tensorflow/compiler/tf2xla/kernels/scan_ops.cc
@@ -0,0 +1,141 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <vector>
+
+#include "tensorflow/compiler/tf2xla/shape_util.h"
+#include "tensorflow/compiler/tf2xla/type_util.h"
+#include "tensorflow/compiler/tf2xla/xla_helpers.h"
+#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
+#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
+#include "tensorflow/compiler/xla/literal_util.h"
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/register_types.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/framework/tensor_types.h"
+#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/concat_lib.h"
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+namespace {
+
+// TODO(phawkins): implement double-sized windowed reductions in XLA and remove
+// the type constraint.
+constexpr std::array<DataType, 3> kScanOpTypes = {
+ {DT_HALF, DT_BFLOAT16, DT_FLOAT}};
+
+class ScanOp : public XlaOpKernel {
+ public:
+ ScanOp(OpKernelConstruction* ctx, bool sum) : XlaOpKernel(ctx), sum_(sum) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("reverse", &reverse_));
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("exclusive", &exclusive_));
+ }
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ const TensorShape input_shape = ctx->InputShape(0);
+ const TensorShape tensor_axis_shape = ctx->InputShape(1);
+
+ OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(tensor_axis_shape),
+ errors::InvalidArgument("ScanOp: axis must be a scalar, not ",
+ tensor_axis_shape.DebugString()));
+
+ int64 axis;
+ OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(1, &axis));
+ if (axis < 0) {
+ axis += input_shape.dims();
+ }
+ OP_REQUIRES(
+ ctx, FastBoundsCheck(axis, input_shape.dims()),
+ errors::InvalidArgument("ScanOp: Expected scan axis in the range [",
+ -input_shape.dims(), ", ", input_shape.dims(),
+ "), but got ", axis));
+
+ DataType dtype = ctx->input_type(0);
+
+ if (input_shape.num_elements() == 0) {
+ // Exit early if there is nothing to compute.
+ ctx->SetOutput(0, ctx->Input(0));
+ return;
+ }
+
+ xla::ComputationBuilder* builder = ctx->builder();
+
+ std::vector<int64> window_strides(input_shape.dims(), 1);
+ std::vector<int64> window_dims(input_shape.dims(), 1);
+ window_dims[axis] = input_shape.dim_size(axis);
+
+ std::vector<std::pair<int64, int64>> padding(input_shape.dims(), {0, 0});
+ padding[axis].first = input_shape.dim_size(axis) - 1;
+ // In exclusive mode, add an extra padding element so there is a complete
+ // window of padding before the data starts.
+ if (exclusive_) {
+ ++padding[axis].first;
+ }
+ if (reverse_) {
+ std::swap(padding[axis].first, padding[axis].second);
+ }
+
+ xla::ComputationDataHandle input = ctx->Input(0);
+ xla::ComputationDataHandle init;
+ const xla::Computation* reducer;
+ if (sum_) {
+ init = XlaHelpers::Zero(builder, dtype);
+ reducer = ctx->GetOrCreateAdd(dtype);
+ } else {
+ init = XlaHelpers::One(builder, dtype);
+ reducer = ctx->GetOrCreateMul(dtype);
+ }
+ auto output = builder->ReduceWindowWithGeneralPadding(
+ ctx->Input(0), init, *reducer, window_dims, window_strides, padding);
+
+ // In exclusive mode, we have computed an extra element containing the
+ // reduction of all the input elements. Slice off this extra "last" element.
+ if (exclusive_) {
+ if (reverse_) {
+ output = builder->SliceInDim(output, 1, input_shape.dim_size(axis) + 1,
+ 1, axis);
+
+ } else {
+ output =
+ builder->SliceInDim(output, 0, input_shape.dim_size(axis), 1, axis);
+ }
+ }
+ ctx->SetOutput(0, output);
+ }
+
+ private:
+ const bool sum_; // True=cumulative sum. False=cumulative product.
+ bool reverse_;
+ bool exclusive_;
+};
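// For illustration: a standalone sketch (plain C++, not XLA) of the padding
// arithmetic above. A cumulative sum over an axis of length K is a reduce-
// window of size K with K-1 (or K, in exclusive mode) elements of low
// padding; reverse mode swaps the padding to the high side. The final
// window in exclusive mode holds the grand total that ScanOp slices off.
#include <iostream>
#include <vector>

std::vector<int> Cumsum(const std::vector<int>& x, bool exclusive) {
  const int k = static_cast<int>(x.size());
  std::vector<int> padded(exclusive ? k : k - 1, 0);  // reducer identity
  padded.insert(padded.end(), x.begin(), x.end());
  std::vector<int> out;
  for (int start = 0; static_cast<int>(out.size()) < k; ++start) {
    int sum = 0;
    for (int i = 0; i < k; ++i) sum += padded[start + i];  // one window
    out.push_back(sum);
  }
  return out;
}

int main() {
  for (int v : Cumsum({1, 2, 3, 4}, /*exclusive=*/false)) std::cout << v << " ";
  std::cout << "\n";  // prints: 1 3 6 10
  for (int v : Cumsum({1, 2, 3, 4}, /*exclusive=*/true)) std::cout << v << " ";
  std::cout << "\n";  // prints: 0 1 3 6
}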
+
+class CumsumOp : public ScanOp {
+ public:
+ explicit CumsumOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/true) {}
+};
+REGISTER_XLA_OP(Name("Cumsum").TypeConstraint("T", kScanOpTypes), CumsumOp);
+
+class CumprodOp : public ScanOp {
+ public:
+ explicit CumprodOp(OpKernelConstruction* ctx) : ScanOp(ctx, /*sum=*/false) {}
+};
+REGISTER_XLA_OP(Name("Cumprod").TypeConstraint("T", kScanOpTypes), CumprodOp);
+
+} // anonymous namespace
+} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
index 24a99f253d..06838d1625 100644
--- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc
+++ b/tensorflow/compiler/tf2xla/kernels/shape_op.cc
@@ -25,58 +25,72 @@ limitations under the License.
namespace tensorflow {
namespace {
+// Converts a TensorShape to a constant Tensor.
+//
+// The input TensorShape input_shape is used to populate the elements of
+// shape_constant, which is modified in place.
+Status TensorShapeToConstant(const TensorShape& input_shape,
+ Tensor* shape_constant) {
+ const int dims = input_shape.dims();
+ if (shape_constant->dtype() == DT_INT32) {
+ auto vec = shape_constant->vec<int32>();
+ for (int i = 0; i < dims; ++i) {
+ int64 dim_size = input_shape.dim_size(i);
+ if (!FastBoundsCheck(dim_size, std::numeric_limits<int32>::max())) {
+ return errors::InvalidArgument(
+ "Shape with out_type=int32 does not support tensors > int32max",
+ " but dim ", i, " is ", dim_size);
+ }
+ vec(i) = static_cast<int32>(dim_size);
+ }
+ } else {
+ auto vec = shape_constant->vec<int64>();
+ for (int i = 0; i < dims; ++i) {
+ int64 dim_size = input_shape.dim_size(i);
+ vec(i) = dim_size;
+ }
+ }
+ return Status::OK();
+}
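// For illustration: a standalone sketch (plain C++, not TensorFlow) of the
// narrowing guard above. Dimension sizes are int64, so emitting an int32
// shape constant must fail once any dimension exceeds int32max.
#include <cstdint>
#include <iostream>
#include <limits>
#include <vector>

bool ShapeToInt32(const std::vector<int64_t>& dims,
                  std::vector<int32_t>* out) {
  for (int64_t d : dims) {
    if (d > std::numeric_limits<int32_t>::max()) return false;  // reject
    out->push_back(static_cast<int32_t>(d));
  }
  return true;
}

int main() {
  std::vector<int32_t> v;
  std::cout << ShapeToInt32({2, 3, 4}, &v) << "\n";           // prints: 1
  std::cout << ShapeToInt32({int64_t{1} << 40}, &v) << "\n";  // prints: 0
}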
+
class ShapeOp : public XlaOpKernel {
public:
- explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
const TensorShape input_shape = ctx->InputShape(0);
- const int rank = input_shape.dims();
- Tensor shape_constant(DT_INT32, TensorShape({rank}));
- auto vec = shape_constant.vec<int32>();
- // TODO(dga): support int64. b/28119922.
- for (int i = 0; i < rank; ++i) {
- int64 dim_size = input_shape.dim_size(i);
- OP_REQUIRES(
- ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
- errors::InvalidArgument("Shape does not support tensors > int32max",
- " but dim ", i, " is ", dim_size));
- vec(i) = static_cast<int32>(dim_size);
- }
-
+ Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
+ OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
ctx->SetConstantOutput(0, shape_constant);
}
+
+ private:
+ DataType out_dtype_;
};
REGISTER_XLA_OP(Name("Shape"), ShapeOp);
class ShapeNOp : public XlaOpKernel {
public:
- explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+ explicit ShapeNOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
+ OP_REQUIRES_OK(ctx, ctx->GetAttr("out_type", &out_dtype_));
+ }
void Compile(XlaOpKernelContext* ctx) override {
for (int i = 0; i < ctx->num_inputs(); ++i) {
- const TensorShape shape = ctx->InputShape(i);
- const int dims = shape.dims();
- Tensor shape_constant(DT_INT32, TensorShape({dims}));
- auto vec = shape_constant.vec<int32>();
-
- // TODO(dga): support int64. b/28119922.
- for (int j = 0; j < dims; ++j) {
- int64 dim_size = shape.dim_size(j);
- OP_REQUIRES(
- ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
- errors::InvalidArgument("Shape does not support tensors > int32max",
- " but shape ", i, " dim ", j, " is ",
- dim_size));
- vec(j) = static_cast<int32>(dim_size);
- }
-
+ const TensorShape input_shape = ctx->InputShape(i);
+ Tensor shape_constant(out_dtype_, TensorShape({input_shape.dims()}));
+ OP_REQUIRES_OK(ctx, TensorShapeToConstant(input_shape, &shape_constant));
ctx->SetConstantOutput(i, shape_constant);
}
}
bool IsExpensive() override { return false; }
+
+ private:
+ DataType out_dtype_;
};
REGISTER_XLA_OP(Name("ShapeN"), ShapeNOp);
diff --git a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
index b19ea22f50..2346c62ad1 100644
--- a/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
+++ b/tensorflow/compiler/tf2xla/kernels/variable_ops.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
+#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/no_op.h"
namespace tensorflow {
@@ -121,5 +122,31 @@ class ResourceGatherOp : public XlaOpKernel {
REGISTER_XLA_OP(Name("ResourceGather").TypeConstraint("dtype", kNumericTypes),
ResourceGatherOp);
+class VariableShapeOp : public XlaOpKernel {
+ public:
+ explicit VariableShapeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
+
+ void Compile(XlaOpKernelContext* ctx) override {
+ DataType dtype;
+ TensorShape shape;
+ OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &dtype, &shape));
+ const int rank = shape.dims();
+ Tensor shape_constant(DT_INT32, TensorShape({rank}));
+ auto vec = shape_constant.vec<int32>();
+ // TODO(dga): support int64. b/28119922.
+ for (int i = 0; i < rank; ++i) {
+ int64 dim_size = shape.dim_size(i);
+ OP_REQUIRES(
+ ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
+ errors::InvalidArgument("Shape does not support tensors > int32max",
+ " but dim ", i, " is ", dim_size));
+ vec(i) = static_cast<int32>(dim_size);
+ }
+
+ ctx->SetConstantOutput(0, shape_constant);
+ }
+};
+
+REGISTER_XLA_OP(Name("VariableShape"), VariableShapeOp);
} // namespace
} // namespace tensorflow
diff --git a/tensorflow/compiler/tf2xla/lib/util.cc b/tensorflow/compiler/tf2xla/lib/util.cc
index 7ffe0aa6df..943248aedb 100644
--- a/tensorflow/compiler/tf2xla/lib/util.cc
+++ b/tensorflow/compiler/tf2xla/lib/util.cc
@@ -40,6 +40,9 @@ xla::ComputationDataHandle FloatLiteral(xla::ComputationBuilder* builder,
case xla::F16:
return builder->ConstantR0<xla::half>(static_cast<xla::half>(value));
break;
+ case xla::BF16:
+ return builder->ConstantR0<bfloat16>(static_cast<bfloat16>(value));
+ break;
case xla::F32:
return builder->ConstantR0<float>(static_cast<float>(value));
break;
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
index b5c17c5273..43d0e17c2c 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc
@@ -28,9 +28,10 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
temps_(new void*[static_data.num_temps]),
arg_names_(static_data.arg_names),
result_names_(static_data.result_names),
- program_shape_(static_data.program_shape) {
+ program_shape_(static_data.program_shape),
+ hlo_profile_printer_(static_data.hlo_profile_printer) {
// Allocate arg and temp buffers.
- if (alloc_mode == AllocMode::ARGS_RESULTS_AND_TEMPS) {
+ if (alloc_mode == AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS) {
alloc_args_ = tensorflow::tfcompile::runtime::MallocContiguousBuffers(
static_data.arg_sizes, static_data.num_args, args_,
/*annotate_initialized=*/false);
@@ -43,6 +44,15 @@ XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
if (static_data.requires_runtime_context) {
args_[static_data.num_args - 1] = &context_;
}
+
+ // If Hlo profiling is enabled the generated code expects an appropriately
+ // sized buffer to be passed in as the last argument. If Hlo profiling is
+ // disabled the last function argument is still present in the function
+ // signature, but it is ignored by the generated code and we pass in null for
+ // it.
+ if (hlo_profiling_enabled()) {
+ profile_counters_ = new int64[static_data.profile_counters_size]();
+ }
}
XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
@@ -50,6 +60,7 @@ XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
tensorflow::tfcompile::runtime::FreeContiguous(alloc_temps_);
delete[] args_;
delete[] temps_;
+ delete[] profile_counters_;
}
namespace {
diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
index f49a788922..3c4314d498 100644
--- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
+++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
-#include <functional>
+#include <cassert>
#include <string>
#include "tensorflow/compiler/tf2xla/xla_local_runtime_context.h"
@@ -27,6 +27,7 @@ limitations under the License.
// never use this functionality.
namespace xla {
class ProgramShape;
+class HloProfilePrinter;
}
namespace tensorflow {
@@ -48,12 +49,10 @@ namespace tensorflow {
class XlaCompiledCpuFunction {
public:
// Type of the raw function, produced by either JIT or AOT.
- //
- // TODO(toddw): Add support for hlo profiling, and replace std::function with
- // a raw function pointer, for some codesize savings.
- using RawFunction = std::function<void(
- void* result, const xla::ExecutableRunOptions* run_options,
- const void** args, void** temps)>;
+ using RawFunction = void (*)(void* result,
+ const xla::ExecutableRunOptions* run_options,
+ const void** args, void** temps,
+ int64* profile_counters);
// StaticData represents the state necessary to run an XLA-compiled
// function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
@@ -81,21 +80,29 @@ class XlaCompiledCpuFunction {
// [Optional] Arg and result shapes.
const xla::ProgramShape* program_shape = nullptr;
+
+ // [Optional] Profile printer. Null if profiling is disabled.
+ const xla::HloProfilePrinter* hlo_profile_printer = nullptr;
+
+ // [Optional] The number of profile counters expected in the profile counter
+ // buffer by the generated code and hlo_profile_printer. 0 if profiling is
+ // disabled.
+ int64 profile_counters_size = 0;
};
// AllocMode controls the buffer allocation mode.
enum class AllocMode {
- // Allocate all buffers - args, results and temps.
- ARGS_RESULTS_AND_TEMPS,
+ // Allocate all buffers - args, results, profile and temps.
+ ARGS_RESULTS_PROFILES_AND_TEMPS,
- // Only allocate result and temp buffers.
+ // Only allocate result, profile and temp buffers.
// Use set_arg_data to set argument buffers before Run is called.
- RESULTS_AND_TEMPS_ONLY,
+ RESULTS_PROFILES_AND_TEMPS_ONLY,
};
XlaCompiledCpuFunction(
const StaticData& static_data,
- AllocMode alloc_mode = AllocMode::ARGS_RESULTS_AND_TEMPS);
+ AllocMode alloc_mode = AllocMode::ARGS_RESULTS_PROFILES_AND_TEMPS);
virtual ~XlaCompiledCpuFunction();
XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
@@ -113,7 +120,7 @@ class XlaCompiledCpuFunction {
context_.error = false;
context_.error_msg.clear();
raw_function_(temps_[result_index_], &run_options_,
- const_cast<const void**>(args_), temps_);
+ const_cast<const void**>(args_), temps_, profile_counters_);
return !context_.error;
}
@@ -162,6 +169,16 @@ class XlaCompiledCpuFunction {
return static_cast<const void* const*>(temps_[result_index_]);
}
+ // Profile counters for this XLA computation.
+ //
+ // When Hlo profiling is enabled (`hlo_profiling_enabled()` returns true in
+ // that case), these counters are non-null and are automatically populated by
+ // `Run`. The counters can then be pretty-printed using
+ // `hlo_profile_printer()`.
+ //
+ // When Hlo profiling is disabled, this accessor returns null.
+ const int64* profile_counters() const { return profile_counters_; }
+
// Returns the buffer for the positional result at the given `index`.
void* result_data(size_t index) { return results()[index]; }
const void* result_data(size_t index) const { return results()[index]; }
@@ -195,6 +212,12 @@ class XlaCompiledCpuFunction {
// program shape isn't available.
const xla::ProgramShape* ProgramShape() const { return program_shape_; }
+ bool hlo_profiling_enabled() const { return hlo_profile_printer_ != nullptr; }
+ const xla::HloProfilePrinter& hlo_profile_printer() const {
+ assert(hlo_profiling_enabled());
+ return *hlo_profile_printer_;
+ }
+
private:
const RawFunction raw_function_;
const size_t result_index_;
@@ -208,6 +231,9 @@ class XlaCompiledCpuFunction {
void* alloc_args_ = nullptr;
void* alloc_temps_ = nullptr;
+ // Backing memory for profiling counters.
+ int64* profile_counters_ = nullptr;
+
// Options and context passed to the compiled function.
xla::ExecutableRunOptions run_options_;
tensorflow::XlaLocalRuntimeContext context_;
@@ -216,6 +242,7 @@ class XlaCompiledCpuFunction {
const char** arg_names_ = nullptr;
const char** result_names_ = nullptr;
const xla::ProgramShape* program_shape_ = nullptr;
+ const xla::HloProfilePrinter* hlo_profile_printer_ = nullptr;
};
} // namespace tensorflow
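// A hedged usage sketch of the profiling surface added above, assuming `fn`
// is an XlaCompiledCpuFunction on which Run() has succeeded and that
// `num_counters` matches the profile_counters_size supplied in StaticData;
// those two names are assumptions, only the accessors in this diff are real.
//
//   if (fn.hlo_profiling_enabled()) {
//     const tensorflow::int64* counters = fn.profile_counters();
//     for (tensorflow::int64 i = 0; i < num_counters; ++i) {
//       // Raw per-HLO counters; hlo_profile_printer() can render them.
//       std::cout << "counter[" << i << "] = " << counters[i] << "\n";
//     }
//   }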
diff --git a/tensorflow/compiler/tf2xla/xla_context.cc b/tensorflow/compiler/tf2xla/xla_context.cc
index 651bafd6c5..78e770c62b 100644
--- a/tensorflow/compiler/tf2xla/xla_context.cc
+++ b/tensorflow/compiler/tf2xla/xla_context.cc
@@ -178,6 +178,20 @@ const xla::Computation* XlaContext::GetOrCreateAdd(const DataType type) {
});
}
+const xla::Computation* XlaContext::GetOrCreateMul(const DataType type) {
+ return LookupOrCreate(type, &mul_func_, [this, type] {
+ const string type_string = DataTypeString(type);
+ VLOG(1) << "Building Mul() for " << type_string;
+ xla::ComputationBuilder b(builder()->client(), "mul<" + type_string + ">");
+ xla::PrimitiveType xla_type;
+ TF_CHECK_OK(DataTypeToPrimitiveType(type, &xla_type));
+ auto x = b.Parameter(0, xla::ShapeUtil::MakeShape(xla_type, {}), "x");
+ auto y = b.Parameter(1, xla::ShapeUtil::MakeShape(xla_type, {}), "y");
+ b.Mul(x, y);
+ return b.Build().ConsumeValueOrDie();
+ });
+}
+
const xla::Computation* XlaContext::LookupOrCreate(
DataType type, ComputationMap* out,
const std::function<xla::Computation()>& create) {
diff --git a/tensorflow/compiler/tf2xla/xla_context.h b/tensorflow/compiler/tf2xla/xla_context.h
index de8aafa362..55d2995987 100644
--- a/tensorflow/compiler/tf2xla/xla_context.h
+++ b/tensorflow/compiler/tf2xla/xla_context.h
@@ -102,6 +102,11 @@ class XlaContext : public ResourceBase {
// separate specialization of the computation for each DataType.
const xla::Computation* GetOrCreateAdd(const DataType type);
+ // Get an XLA lambda to compute Mul. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateMul(const DataType type);
+
// The name of the XlaContext resource during symbolic graph execution.
static const char kXlaContextResourceName[];
@@ -155,6 +160,9 @@ class XlaContext : public ResourceBase {
// Cached computation to compute Sum of two elements, specialized by type.
ComputationMap add_func_;
+ // Cached computation to compute Mul of two elements, specialized by type.
+ ComputationMap mul_func_;
+
// Cached computation to compute Sigmoid of an element, specialized by type.
ComputationMap sigmoid_func_;
diff --git a/tensorflow/compiler/tf2xla/xla_helpers.cc b/tensorflow/compiler/tf2xla/xla_helpers.cc
index 9c3e15d2fa..ec9e535b70 100644
--- a/tensorflow/compiler/tf2xla/xla_helpers.cc
+++ b/tensorflow/compiler/tf2xla/xla_helpers.cc
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// This file defines helper routines for Tla JIT compilation.
+// This file defines helper routines for XLA compilation.
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"
@@ -121,6 +121,8 @@ xla::ComputationDataHandle XlaHelpers::One(xla::ComputationBuilder* b,
xla::ComputationDataHandle XlaHelpers::Epsilon(xla::ComputationBuilder* b,
DataType data_type) {
switch (data_type) {
+ case DT_BFLOAT16:
+ return b->ConstantR0<bfloat16>(bfloat16::epsilon());
case DT_FLOAT:
return b->ConstantR0<float>(std::numeric_limits<float>::epsilon());
case DT_DOUBLE:
@@ -169,6 +171,9 @@ xla::ComputationDataHandle XlaHelpers::IntegerLiteral(
case xla::S16:
case xla::U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
+ case xla::BF16:
+ literal = *xla::Literal::CreateR0<bfloat16>(static_cast<bfloat16>(value));
+ break;
case xla::F16:
literal =
*xla::Literal::CreateR0<xla::half>(static_cast<xla::half>(value));
diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
index 1dd454ea8d..f727f20464 100644
--- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
+++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc
@@ -90,21 +90,6 @@ xla::StatusOr<size_t> ComputeResultIndex(
return result_slice.index();
}
-// Adapt ComputeFunctionType, which includes a final profile_counters arg, to
-// RawFunction, which doesn't include that final arg.
-//
-// TODO(toddw): Change RawFunction and AOT to also pass the final
-// profile_counters arg, and remove this adapter.
-XlaCompiledCpuFunction::RawFunction RawFunctionAdapter(
- xla::cpu::CpuExecutable::ComputeFunctionType compute_function) {
- return [compute_function](void* result,
- const xla::ExecutableRunOptions* run_options,
- const void** args, void** temps) {
- return compute_function(result, run_options, args, temps,
- /*profile_counters=*/nullptr);
- };
-}
-
// Collect names from `entries`, where T is one of tf2xla::{Feed,Fetch}. We hold
// the actual strings in nonempty_names, and hold arrays of pointers in
// name_ptrs, terminated by a nullptr entry.
@@ -177,7 +162,7 @@ XlaJitCompiledCpuFunction::Compile(
const xla::cpu::CpuExecutable* cpu_executable =
static_cast<xla::cpu::CpuExecutable*>(executable->executable());
XlaCompiledCpuFunction::RawFunction raw_function =
- RawFunctionAdapter(cpu_executable->compute_function());
+ cpu_executable->compute_function();
const xla::BufferAssignment& buffer_assignment =
cpu_executable->buffer_assignment();
@@ -211,6 +196,14 @@ XlaJitCompiledCpuFunction::Compile(
jit->static_data_.arg_names = jit->arg_names_.data();
jit->static_data_.result_names = jit->result_names_.data();
jit->static_data_.program_shape = jit->program_shape_.get();
+
+ if (cpu_executable->hlo_profiling_enabled()) {
+ jit->static_data_.hlo_profile_printer =
+ &cpu_executable->hlo_profile_printer();
+ jit->static_data_.profile_counters_size =
+ cpu_executable->hlo_profile_printer().profile_counters_size();
+ }
+
return std::move(jit_unique_ptr);
}
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
index 2b4cc9ba2d..79d501b511 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc
@@ -417,6 +417,11 @@ const xla::Computation* XlaOpKernelContext::GetOrCreateAdd(
return XlaContext::Get(context_).GetOrCreateAdd(type);
}
+const xla::Computation* XlaOpKernelContext::GetOrCreateMul(
+ const DataType type) {
+ return XlaContext::Get(context_).GetOrCreateMul(type);
+}
+
XlaOpKernel::XlaOpKernel(OpKernelConstruction* context) : OpKernel(context) {}
void XlaOpKernel::Compute(OpKernelContext* context) {
diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.h b/tensorflow/compiler/tf2xla/xla_op_kernel.h
index 76bcf594e6..06845a674e 100644
--- a/tensorflow/compiler/tf2xla/xla_op_kernel.h
+++ b/tensorflow/compiler/tf2xla/xla_op_kernel.h
@@ -210,6 +210,11 @@ class XlaOpKernelContext {
// separate specialization of the computation for each DataType.
const xla::Computation* GetOrCreateAdd(const DataType type);
+ // Gets an XLA lambda to compute Mul. This is cached in the
+ // XlaContext since it may be used by multiple Ops. There is a
+ // separate specialization of the computation for each DataType.
+ const xla::Computation* GetOrCreateMul(const DataType type);
+
private:
OpKernelContext* const context_;
};
diff --git a/tensorflow/compiler/xla/client/computation_builder.cc b/tensorflow/compiler/xla/client/computation_builder.cc
index cce9310003..9febea8dcf 100644
--- a/tensorflow/compiler/xla/client/computation_builder.cc
+++ b/tensorflow/compiler/xla/client/computation_builder.cc
@@ -625,7 +625,41 @@ ComputationDataHandle ComputationBuilder::Lt(
ComputationDataHandle ComputationBuilder::Dot(
const ComputationDataHandle& lhs, const ComputationDataHandle& rhs) {
- return BinaryOp(BINOP_DOT, lhs, rhs, /*broadcast_dimensions=*/{});
+ StatusOr<std::unique_ptr<Shape>> lhs_shape_or_status = GetShape(lhs);
+ if (!lhs_shape_or_status.ok()) {
+ NoteError(lhs_shape_or_status.status());
+ return ComputationDataHandle();
+ }
+ std::unique_ptr<Shape> lhs_shape = lhs_shape_or_status.ConsumeValueOrDie();
+
+ DotDimensionNumbers dimension_numbers;
+ dimension_numbers.add_lhs_contracting_dimensions(
+ lhs_shape->dimensions_size() == 1 ? 0 : 1);
+ dimension_numbers.add_rhs_contracting_dimensions(0);
+ return DotGeneral(lhs, rhs, dimension_numbers);
+}
+
+ComputationDataHandle ComputationBuilder::DotGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ if (!first_error_.ok() || !PrepareComputation().ok()) {
+ return ComputationDataHandle();
+ }
+
+ DotRequest request;
+ *request.mutable_lhs() = lhs;
+ *request.mutable_rhs() = rhs;
+ *request.mutable_dimension_numbers() = dimension_numbers;
+
+ OpRequest op_request;
+ *op_request.mutable_computation() = computation_.handle();
+ *op_request.mutable_dot_request() = request;
+ AddCommonFieldsToOpRequest(&op_request);
+ OpResponse response;
+
+ VLOG(2) << "making Dot request";
+ Status s = client_->stub()->Op(&op_request, &response);
+ return ParseOpResponse(s, &response);
}
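// A hedged sketch of the equivalence introduced above: for rank-2 operands,
// Dot(lhs, rhs) now lowers to DotGeneral with the canonical dimension
// numbers. `b`, `lhs` and `rhs` are assumed to be an in-scope
// ComputationBuilder and two rank-2 handles.
//
//   DotDimensionNumbers dnums;
//   dnums.add_lhs_contracting_dimensions(1);  // contract lhs dim 1 ...
//   dnums.add_rhs_contracting_dimensions(0);  // ... against rhs dim 0
//   auto product = b.DotGeneral(lhs, rhs, dnums);  // same as b.Dot(lhs, rhs)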
ComputationDataHandle ComputationBuilder::Conv(
diff --git a/tensorflow/compiler/xla/client/computation_builder.h b/tensorflow/compiler/xla/client/computation_builder.h
index d2dbbbbebb..531b98cfb9 100644
--- a/tensorflow/compiler/xla/client/computation_builder.h
+++ b/tensorflow/compiler/xla/client/computation_builder.h
@@ -393,6 +393,11 @@ class ComputationBuilder {
ComputationDataHandle Dot(const ComputationDataHandle& lhs,
const ComputationDataHandle& rhs);
+ // Enqueues a general dot instruction onto the computation.
+ ComputationDataHandle DotGeneral(
+ const ComputationDataHandle& lhs, const ComputationDataHandle& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
// Default dimension numbers used for a 2D convolution.
static constexpr int64 kConvBatchDimension = 0;
static constexpr int64 kConvFeatureDimension = 1;
diff --git a/tensorflow/compiler/xla/literal_util.cc b/tensorflow/compiler/xla/literal_util.cc
index 93d3cd425f..250df5f4d5 100644
--- a/tensorflow/compiler/xla/literal_util.cc
+++ b/tensorflow/compiler/xla/literal_util.cc
@@ -252,6 +252,10 @@ Status Literal::Copy(const Literal& src_literal,
return *Literal::CreateR0<int32>(1);
case S64:
return *Literal::CreateR0<int64>(1);
+ case F16:
+ return *Literal::CreateR0<half>(static_cast<half>(1.0f));
+ case BF16:
+ return *Literal::CreateR0<bfloat16>(static_cast<bfloat16>(1.0f));
case F32:
return *Literal::CreateR0<float>(1);
case F64:
@@ -263,8 +267,6 @@ Status Literal::Copy(const Literal& src_literal,
case S16:
case U16:
LOG(FATAL) << "u16/s16 literals not yet implemented";
- case F16:
- return *Literal::CreateR0<half>(static_cast<half>(1.0f));
case TUPLE:
LOG(FATAL) << "tuple element type cannot take on value of 1";
case OPAQUE:
diff --git a/tensorflow/compiler/xla/literal_util.h b/tensorflow/compiler/xla/literal_util.h
index f37e529caf..069d1b33ca 100644
--- a/tensorflow/compiler/xla/literal_util.h
+++ b/tensorflow/compiler/xla/literal_util.h
@@ -285,11 +285,11 @@ class Literal {
std::unique_ptr<Literal> Relayout(const Layout& new_layout,
const ShapeIndex& shape_index = {}) const;
- // Creates a new literal by reshaping this literal to have 'shape'. Both the
- // original shape and 'shape' must contain the same number of elements. The
+ // Creates a new literal by reshaping this literal to have the given
+ // dimensions. The total number of elements must not change; the
// implementation currently only supports monotonic dim0-major layouts.
StatusOr<std::unique_ptr<Literal>> Reshape(
- tensorflow::gtl::ArraySlice<int64> shape) const;
+ tensorflow::gtl::ArraySlice<int64> dimensions) const;
// Creates a new literal by reordering the dimensions of this literal.
// The given `permutation` must be a permutation of the dimension numbers
diff --git a/tensorflow/compiler/xla/reference_util.cc b/tensorflow/compiler/xla/reference_util.cc
index 5bb81b80dd..bdf92eaed1 100644
--- a/tensorflow/compiler/xla/reference_util.cc
+++ b/tensorflow/compiler/xla/reference_util.cc
@@ -195,14 +195,26 @@ ReferenceUtil::ReduceWindow1DGeneric(
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding) {
std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
- auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
+ return ReduceWindow1DGeneric(
+ operand, init, reduce_func, window, stride,
+ xla::MakePadding(dim_lengths, window, stride, padding));
+}
+/* static */ std::unique_ptr<std::vector<float>>
+ReferenceUtil::ReduceWindow1DGeneric(
+ const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride,
+ const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding) {
+ std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
std::vector<int64> window_counts(window.size(), 0);
std::vector<int64> pad_low(window.size(), 0);
for (int64 i = 0; i < window.size(); ++i) {
+ int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
window_counts[i] =
- WindowCount(dim_lengths[i], window[i], stride[i], padding);
- pad_low[i] = padding_both[i].first;
+ window_util::StridedBound(padded_width, window[i], stride[i]);
+ pad_low[i] = padding[i].first;
}
auto result = MakeUnique<std::vector<float>>(window_counts[0]);
diff --git a/tensorflow/compiler/xla/reference_util.h b/tensorflow/compiler/xla/reference_util.h
index 62d455d71a..58e1a84461 100644
--- a/tensorflow/compiler/xla/reference_util.h
+++ b/tensorflow/compiler/xla/reference_util.h
@@ -70,7 +70,7 @@ class ReferenceUtil {
// dilation factors.
static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
const Array4D<float>& lhs, const Array4D<float>& rhs,
- std::pair<int64, int64> stride, Padding padding,
+ std::pair<int64, int64> kernel_stride, Padding padding,
std::pair<int64, int64> lhs_dilation,
std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);
@@ -184,6 +184,12 @@ class ReferenceUtil {
const std::function<float(float, float)>& reduce_func,
const tensorflow::gtl::ArraySlice<int64>& window,
const tensorflow::gtl::ArraySlice<int64>& stride, Padding padding);
+ static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
+ const tensorflow::gtl::ArraySlice<float>& operand, float init,
+ const std::function<float(float, float)>& reduce_func,
+ const tensorflow::gtl::ArraySlice<int64>& window,
+ const tensorflow::gtl::ArraySlice<int64>& stride,
+ const tensorflow::gtl::ArraySlice<std::pair<int64, int64>>& padding);
static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
const Array4D<float>& operand, float init,
const std::function<float(float, float)>& reduce_func,
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier.cc b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
index 71491218aa..b1d0345e70 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -597,9 +597,13 @@ Status AlgebraicSimplifierVisitor::HandleDot(HloInstruction* dot) {
// Simplify dot(transpose(a), transpose(b)) to transpose(dot(b,a)).
if (lhs->IsRank2Transpose() && rhs->IsRank2Transpose()) {
- auto new_dot = computation_->AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::PermuteDimensions({1, 0}, dot->shape()), HloOpcode::kDot,
- rhs->mutable_operand(0), lhs->mutable_operand(0)));
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ auto new_dot = computation_->AddInstruction(HloInstruction::CreateDot(
+ ShapeUtil::PermuteDimensions({1, 0}, dot->shape()),
+ rhs->mutable_operand(0), lhs->mutable_operand(0),
+ dot_dimension_numbers));
return ReplaceWithNewInstruction(
dot, HloInstruction::CreateTranspose(dot->shape(), new_dot, {1, 0}));
}
@@ -1616,8 +1620,11 @@ Status AlgebraicSimplifierVisitor::HandleConvolution(
auto new_lhs = add_bitcast(new_input_shape, lhs);
auto new_rhs = add_bitcast(new_filter_shape, rhs);
- auto dot = computation_->AddInstruction(HloInstruction::CreateBinary(
- dot_output_shape, HloOpcode::kDot, new_lhs, new_rhs));
+ DotDimensionNumbers dot_dimension_numbers;
+ dot_dimension_numbers.add_lhs_contracting_dimensions(1);
+ dot_dimension_numbers.add_rhs_contracting_dimensions(0);
+ auto dot = computation_->AddInstruction(HloInstruction::CreateDot(
+ dot_output_shape, new_lhs, new_rhs, dot_dimension_numbers));
return ReplaceInstruction(convolution, add_bitcast(convolution_shape, dot));
}
diff --git a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
index 56dfb1cf0b..3d70505f6e 100644
--- a/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
+++ b/tensorflow/compiler/xla/service/algebraic_simplifier_test.cc
@@ -2138,8 +2138,10 @@ TEST_F(AlgebraicSimplifierTest, IteratorInvalidation) {
builder.AddInstruction(HloInstruction::CreateParameter(0, r1f32, "x"));
HloInstruction* y =
builder.AddInstruction(HloInstruction::CreateParameter(1, r1f32, "y"));
- builder.AddInstruction(
- HloInstruction::CreateBinary(r1f32, HloOpcode::kDot, x, y));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ builder.AddInstruction(HloInstruction::CreateDot(r1f32, x, y, dot_dnums));
std::unique_ptr<HloComputation> dot_computation(builder.Build());
HloComputation::Builder call_builder(TestName() + ".Call");
diff --git a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
index c6193b3fbb..2bbae25aee 100644
--- a/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
+++ b/tensorflow/compiler/xla/service/batchnorm_rewriter.cc
@@ -149,6 +149,15 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
if (!rewrite_training_op_) {
return Status::OK();
}
+
+ std::vector<HloInstruction*> added_instructions;
+ auto add = [&](std::unique_ptr<HloInstruction> inst) {
+ HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+ added_instructions.push_back(added_inst);
+ return added_inst;
+ };
+ int64 instruction_count_before = computation_->instruction_count();
+
// Expand batch norm training into smaller HLO ops.
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
@@ -160,7 +169,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
- auto elements_per_feature = computation_->AddInstruction(
+ auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
HloInstruction* scale = batch_norm->mutable_operand(1);
@@ -169,14 +178,12 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
- auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(zero_literal)));
+ auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
- auto epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(epsilon_literal)));
-
+ auto epsilon =
+ add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
std::vector<int64> dimensions_without_feature;
for (int64 i = 0; i < ShapeUtil::Rank(operand_shape); ++i) {
@@ -185,105 +192,110 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
}
}
- auto scale_broadcasted = computation_->AddInstruction(
+ auto scale_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
- auto offset_broadcasted = computation_->AddInstruction(
+ auto offset_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
HloComputation* add_reduce_computation =
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
// X^2.
- auto operand_squared =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, operand, operand));
+ auto operand_squared = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kMultiply, operand, operand));
// Sum[X].
- auto sum = computation_->AddInstruction(HloInstruction::CreateReduce(
- feature_shape, operand, zero, dimensions_without_feature,
- add_reduce_computation));
+ auto sum = add(HloInstruction::CreateReduce(feature_shape, operand, zero,
+ dimensions_without_feature,
+ add_reduce_computation));
// Sum[X^2].
- auto squared_sum = computation_->AddInstruction(HloInstruction::CreateReduce(
+ auto squared_sum = add(HloInstruction::CreateReduce(
feature_shape, operand_squared, zero, dimensions_without_feature,
add_reduce_computation));
// Fuse two parallel reduces together to improve performance.
- if (use_fusion_) {
- auto tuple = computation_->AddInstruction(
- HloInstruction::CreateTuple({sum, squared_sum}));
+ if (use_fusion_ && !batch_norm->has_sharding()) {
+ auto tuple = add(HloInstruction::CreateTuple({sum, squared_sum}));
auto fused = computation_->CreateFusionInstruction(
{tuple, sum, squared_sum, operand_squared},
HloInstruction::FusionKind::kInput);
- sum = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+ sum = add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
- squared_sum = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+ squared_sum =
+ add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
}
// E[X].
- auto mean = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto mean = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kDivide, sum, elements_per_feature));
- auto mean_broadcasted = computation_->AddInstruction(
+ auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
// E[X^2].
- auto square_mean = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto square_mean = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kDivide, squared_sum, elements_per_feature));
// E^2[X].
- auto mean_square = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto mean_square = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kMultiply, mean, mean));
// Var[X].
- auto var = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto var = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kSubtract, square_mean, mean_square));
- auto var_broadcasted = computation_->AddInstruction(
- HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+ auto var_broadcasted =
+ add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
- auto var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
+ auto var_add_epsilon = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(neg_half_literal)));
+ auto neg_half =
+ add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+ auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
// X - E[X].
- auto operand_minus_mean =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+ auto operand_minus_mean = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon].
- auto normalized = computation_->AddInstruction(
+ auto normalized = add(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
- auto scaled_normalized =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+ auto scaled_normalized = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
- auto shifted_normalized = computation_->AddInstruction(
- HloInstruction::CreateBinary(operand_shape, HloOpcode::kAdd,
- scaled_normalized, offset_broadcasted));
-
- TF_CHECK_OK(ReplaceWithNewInstruction(
- batch_norm,
- HloInstruction::CreateTuple({shifted_normalized, mean, var})));
+ auto shifted_normalized = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted));
+
+ auto tuple = HloInstruction::CreateTuple({shifted_normalized, mean, var});
+
+ if (batch_norm->has_sharding()) {
+ int64 instruction_count_after = computation_->instruction_count();
+ CHECK_EQ(instruction_count_after,
+ instruction_count_before + added_instructions.size());
+ for (HloInstruction* inst : added_instructions) {
+ if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
+ inst->set_sharding(batch_norm->sharding());
+ } else {
+ inst->set_sharding(HloSharding::Replicate());
+ }
+ }
+ tuple->set_sharding(batch_norm->sharding());
+ }
+ TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
}
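Editor's note: the hunk above threads every new instruction through a recording `add` lambda, then uses the recorded list to CHECK that nothing bypassed the lambda and to propagate the batch norm's sharding (operand-shaped ops inherit it, per-feature ops are replicated). Fusion is skipped when sharding is present, presumably because CreateFusionInstruction bypasses the lambda. A minimal self-contained sketch of that bookkeeping pattern, with Inst and Computation as hypothetical stand-ins for the HLO classes and strings standing in for shapes and shardings:

#include <cassert>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

struct Inst {
  std::string shape;
  std::string sharding;  // Empty means "no sharding set".
};

struct Computation {
  std::vector<std::unique_ptr<Inst>> insts;
  Inst* AddInstruction(std::unique_ptr<Inst> inst) {
    insts.push_back(std::move(inst));
    return insts.back().get();
  }
  int64_t instruction_count() const {
    return static_cast<int64_t>(insts.size());
  }
};

int main() {
  Computation computation;
  std::vector<Inst*> added_instructions;
  auto add = [&](std::unique_ptr<Inst> inst) {
    Inst* added_inst = computation.AddInstruction(std::move(inst));
    added_instructions.push_back(added_inst);
    return added_inst;
  };
  int64_t instruction_count_before = computation.instruction_count();

  const std::string operand_shape = "f32[8,16]";
  add(std::unique_ptr<Inst>(new Inst{operand_shape, ""}));  // element-wise op
  add(std::unique_ptr<Inst>(new Inst{"f32[16]", ""}));      // per-feature op

  // The CHECK_EQ from the pass: every new instruction went through `add`.
  assert(computation.instruction_count() ==
         instruction_count_before +
             static_cast<int64_t>(added_instructions.size()));

  // Sharding rule from the hunk: operand-shaped ops inherit the batch norm's
  // sharding; everything else (per-feature shapes) is replicated.
  const std::string batch_norm_sharding = "{devices=[2,1]0,1}";
  for (Inst* inst : added_instructions) {
    inst->sharding =
        (inst->shape == operand_shape) ? batch_norm_sharding : "replicated";
  }
  return 0;
}

Recording through the lambda keeps the sharding pass proportional to the number of new instructions and makes the CHECK_EQ a cheap completeness check.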
@@ -317,52 +329,69 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
}
}
- auto scale_broadcasted = computation_->AddInstruction(
+ std::vector<HloInstruction*> added_instructions;
+ auto add = [&](std::unique_ptr<HloInstruction> inst) {
+ HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+ added_instructions.push_back(added_inst);
+ return added_inst;
+ };
+ int64 instruction_count_before = computation_->instruction_count();
+
+ auto scale_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, scale, {feature_index}));
- auto offset_broadcasted = computation_->AddInstruction(
+ auto offset_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));
- auto mean_broadcasted = computation_->AddInstruction(
+ auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(operand_shape, mean, {feature_index}));
- auto var_broadcasted = computation_->AddInstruction(
- HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
+ auto var_broadcasted =
+ add(HloInstruction::CreateBroadcast(operand_shape, var, {feature_index}));
// Var[X] + epsilon.
- auto var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
+ auto var_add_epsilon = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(neg_half_literal)));
+ auto neg_half =
+ add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
// 1 / Sqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
+ auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kPower, var_add_epsilon, neg_half));
// X - E[X].
- auto operand_minus_mean =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
+ auto operand_minus_mean = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kSubtract, operand, mean_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon].
- auto normalized = computation_->AddInstruction(
+ auto normalized = add(
HloInstruction::CreateBinary(operand_shape, HloOpcode::kMultiply,
operand_minus_mean, rsqrt_var_add_epsilon));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale.
- auto scaled_normalized =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
+ auto scaled_normalized = add(HloInstruction::CreateBinary(
+ operand_shape, HloOpcode::kMultiply, normalized, scale_broadcasted));
// (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
auto shifted_normalized = HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, scaled_normalized, offset_broadcasted);
+ int64 instruction_count_after = computation_->instruction_count();
+ CHECK_EQ(instruction_count_after,
+ instruction_count_before + added_instructions.size());
+ if (batch_norm->has_sharding()) {
+ for (HloInstruction* inst : added_instructions) {
+ if (ShapeUtil::Equal(inst->shape(), operand_shape)) {
+ inst->set_sharding(batch_norm->sharding());
+ } else {
+ inst->set_sharding(HloSharding::Replicate());
+ }
+ }
+ shifted_normalized->set_sharding(batch_norm->sharding());
+ }
TF_CHECK_OK(
ReplaceWithNewInstruction(batch_norm, std::move(shifted_normalized)));
return Status::OK();
@@ -385,6 +414,13 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
if (!rewrite_grad_op_) {
return Status::OK();
}
+ std::vector<HloInstruction*> added_instructions;
+ auto add = [&](std::unique_ptr<HloInstruction> inst) {
+ HloInstruction* added_inst = computation_->AddInstruction(std::move(inst));
+ added_instructions.push_back(added_inst);
+ return added_inst;
+ };
+ int64 instruction_count_before = computation_->instruction_count();
HloInstruction* activation = batch_norm->mutable_operand(0);
const Shape activation_shape = activation->shape();
@@ -403,23 +439,22 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
- auto elements_per_feature = computation_->AddInstruction(
+ auto elements_per_feature = add(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));
auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
- auto zero = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(zero_literal)));
+ auto zero = add(HloInstruction::CreateConstant(std::move(zero_literal)));
auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
- auto neg_half = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(neg_half_literal)));
+ auto neg_half =
+ add(HloInstruction::CreateConstant(std::move(neg_half_literal)));
auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
- auto epsilon = computation_->AddInstruction(
- HloInstruction::CreateConstant(std::move(epsilon_literal)));
+ auto epsilon =
+ add(HloInstruction::CreateConstant(std::move(epsilon_literal)));
std::vector<int64> dimensions_without_feature;
@@ -429,126 +464,131 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
}
}
- auto scale_broadcasted =
- computation_->AddInstruction(HloInstruction::CreateBroadcast(
- activation_shape, scale, {feature_index}));
- auto variance_broadcasted =
- computation_->AddInstruction(HloInstruction::CreateBroadcast(
- activation_shape, variance, {feature_index}));
+ auto scale_broadcasted = add(HloInstruction::CreateBroadcast(
+ activation_shape, scale, {feature_index}));
+ auto variance_broadcasted = add(HloInstruction::CreateBroadcast(
+ activation_shape, variance, {feature_index}));
// E[X].
- auto mean_broadcasted = computation_->AddInstruction(
+ auto mean_broadcasted = add(
HloInstruction::CreateBroadcast(activation_shape, mean, {feature_index}));
// rsqrt[Var[X] + epsilon].
- auto rsqrt_var_add_epsilon_broadcasted =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kPower,
- computation_->AddInstruction(
- HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
- variance_broadcasted, epsilon)),
- neg_half));
-
- auto rsqrt_var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kPower,
- computation_->AddInstruction(HloInstruction::CreateBinary(
- feature_shape, HloOpcode::kAdd, variance, epsilon)),
- neg_half));
+ auto rsqrt_var_add_epsilon_broadcasted = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kPower,
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
+ variance_broadcasted, epsilon)),
+ neg_half));
+
+ auto rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
+ feature_shape, HloOpcode::kPower,
+ add(HloInstruction::CreateBinary(feature_shape, HloOpcode::kAdd, variance,
+ epsilon)),
+ neg_half));
// X - E[X].
- auto activation_minus_mean = computation_->AddInstruction(
- HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
- activation, mean_broadcasted));
+ auto activation_minus_mean = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kSubtract, activation, mean_broadcasted));
// Grad[Y] * (X - E[X]).
- auto grad_output_times_activiation_minus_mean = computation_->AddInstruction(
- HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- grad_output, activation_minus_mean));
+ auto grad_output_times_activiation_minus_mean =
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
+ grad_output, activation_minus_mean));
HloComputation* add_reduce_computation =
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);
// sum(Grad[Y] * (X - E[X])).
auto sum_grad_output_times_activiation_minus_mean =
- computation_->AddInstruction(HloInstruction::CreateReduce(
+ add(HloInstruction::CreateReduce(
feature_shape, grad_output_times_activiation_minus_mean, zero,
dimensions_without_feature, add_reduce_computation));
// Grad[beta] = Sum(Grad[Y]).
- auto grad_beta = computation_->AddInstruction(HloInstruction::CreateReduce(
+ auto grad_beta = add(HloInstruction::CreateReduce(
feature_shape, grad_output, zero, dimensions_without_feature,
add_reduce_computation));
- if (use_fusion_) {
- auto tuple = computation_->AddInstruction(HloInstruction::CreateTuple(
+ if (use_fusion_ && !batch_norm->has_sharding()) {
+ auto tuple = add(HloInstruction::CreateTuple(
{sum_grad_output_times_activiation_minus_mean, grad_beta}));
auto fused = computation_->CreateFusionInstruction(
{tuple, sum_grad_output_times_activiation_minus_mean, grad_beta},
HloInstruction::FusionKind::kInput);
- sum_grad_output_times_activiation_minus_mean = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
+ sum_grad_output_times_activiation_minus_mean =
+ add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 0));
- grad_beta = computation_->AddInstruction(
- HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
+ grad_beta =
+ add(HloInstruction::CreateGetTupleElement(feature_shape, fused, 1));
}
// Grad[scale] = Sum(Grad[Y] * (X - E[X]) * rsqrt[Var[X] + epsilon]).
- auto grad_scale = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto grad_scale = add(HloInstruction::CreateBinary(
feature_shape, HloOpcode::kMultiply,
sum_grad_output_times_activiation_minus_mean, rsqrt_var_add_epsilon));
// I2 = Sum(Grad[Y])
- auto I2 = computation_->AddInstruction(HloInstruction::CreateBroadcast(
- activation_shape, grad_beta, {feature_index}));
+ auto i2 = add(HloInstruction::CreateBroadcast(activation_shape, grad_beta,
+ {feature_index}));
// I3 = Sum(Grad[Y] * (X - E[X]))
- auto I3 = computation_->AddInstruction(HloInstruction::CreateBroadcast(
+ auto i3 = add(HloInstruction::CreateBroadcast(
activation_shape, sum_grad_output_times_activiation_minus_mean,
{feature_index}));
// I4 = (X - E[X]) * I3
- auto I4 = computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kMultiply, I3, activation_minus_mean));
+ auto i4 = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kMultiply, i3, activation_minus_mean));
// I5 = I4 / (Var[X] + epsilon)
- auto I5 = computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kDivide, I4,
- computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kAdd, variance_broadcasted, epsilon))));
+ auto i5 = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kDivide, i4,
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kAdd,
+ variance_broadcasted, epsilon))));
// scale * rsqrt[Var[X] + epsilon] * 1/N
- auto scale_times_rsqrt_var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kMultiply, scale_broadcasted,
- rsqrt_var_add_epsilon_broadcasted));
+ auto scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kMultiply, scale_broadcasted,
+ rsqrt_var_add_epsilon_broadcasted));
- scale_times_rsqrt_var_add_epsilon =
- computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kDivide,
- scale_times_rsqrt_var_add_epsilon, elements_per_feature));
+ scale_times_rsqrt_var_add_epsilon = add(HloInstruction::CreateBinary(
+ activation_shape, HloOpcode::kDivide, scale_times_rsqrt_var_add_epsilon,
+ elements_per_feature));
- auto I1 = computation_->AddInstruction(
- HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- grad_output, elements_per_feature));
+ auto i1 =
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
+ grad_output, elements_per_feature));
// I6 = I1 - I2 - I5
- auto I6 = computation_->AddInstruction(HloInstruction::CreateBinary(
+ auto i6 = add(HloInstruction::CreateBinary(
activation_shape, HloOpcode::kSubtract,
- computation_->AddInstruction(HloInstruction::CreateBinary(
- activation_shape, HloOpcode::kSubtract, I1, I2)),
- I5));
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kSubtract,
+ i1, i2)),
+ i5));
// Grad[X] = scale * rsqrt[Var[X] + epsilon] * 1/N * I6.
- auto grad_activation = computation_->AddInstruction(
- HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
- scale_times_rsqrt_var_add_epsilon, I6));
+ auto grad_activation =
+ add(HloInstruction::CreateBinary(activation_shape, HloOpcode::kMultiply,
+ scale_times_rsqrt_var_add_epsilon, i6));
+ auto tuple =
+ HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta});
+ if (batch_norm->has_sharding()) {
+ int64 instruction_count_after = computation_->instruction_count();
+ CHECK_EQ(instruction_count_after,
+ instruction_count_before + added_instructions.size());
+ for (HloInstruction* inst : added_instructions) {
+ if (ShapeUtil::Equal(inst->shape(), activation_shape)) {
+ inst->set_sharding(batch_norm->sharding());
+ } else {
+ inst->set_sharding(HloSharding::Replicate());
+ }
+ }
+ tuple->set_sharding(batch_norm->sharding());
+ }
- TF_CHECK_OK(ReplaceWithNewInstruction(
- batch_norm,
- HloInstruction::CreateTuple({grad_activation, grad_scale, grad_beta})));
+ TF_CHECK_OK(ReplaceWithNewInstruction(batch_norm, std::move(tuple)));
return Status::OK();
}
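Editor's note: taken together, HandleBatchNormTraining expands a single batch-norm op into the textbook moment computation spelled out in the comments above. A scalar, single-feature model of the emitted arithmetic (plain C++, no HLO), including rsqrt realized as pow(var + epsilon, -0.5) exactly as the kPower/neg_half instruction pair does:

#include <cmath>
#include <cstdio>
#include <vector>

int main() {
  const std::vector<float> x = {1.0f, 2.0f, 3.0f, 4.0f};
  const float scale = 2.0f, offset = 0.5f, epsilon = 1e-3f;
  const float n = static_cast<float>(x.size());  // elements_per_feature

  float sum = 0.0f, squared_sum = 0.0f;
  for (float v : x) {
    sum += v;              // Sum[X]
    squared_sum += v * v;  // Sum[X^2]
  }
  const float mean = sum / n;                   // E[X]
  const float square_mean = squared_sum / n;    // E[X^2]
  const float var = square_mean - mean * mean;  // Var[X] = E[X^2] - E^2[X]
  const float rsqrt = std::pow(var + epsilon, -0.5f);

  for (float v : x) {
    // (X - E[X]) / Sqrt[Var[X] + epsilon] * scale + offset.
    std::printf("%f\n", (v - mean) * rsqrt * scale + offset);
  }
  return 0;
}

The two reductions (Sum[X] and Sum[X^2]) are the pair the rewriter optionally fuses into one input fusion when no sharding is attached.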
diff --git a/tensorflow/compiler/xla/service/buffer_assignment_test.cc b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
index 8fba8ef5e5..09681b34e7 100644
--- a/tensorflow/compiler/xla/service/buffer_assignment_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_assignment_test.cc
@@ -1360,10 +1360,13 @@ TEST_F(BufferAssignmentTest, OneTempAllocation) {
HloInstruction::CreateParameter(1, shape_3x4, "param_b"));
auto param_c = builder.AddInstruction(
HloInstruction::CreateParameter(2, shape_4x4, "param_c"));
- auto dot_ab = builder.AddInstruction(HloInstruction::CreateBinary(
- shape_2x4, HloOpcode::kDot, param_a, param_b));
- auto dot_bc = builder.AddInstruction(HloInstruction::CreateBinary(
- shape_3x4, HloOpcode::kDot, param_b, param_c));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto dot_ab = builder.AddInstruction(
+ HloInstruction::CreateDot(shape_2x4, param_a, param_b, dot_dnums));
+ auto dot_bc = builder.AddInstruction(
+ HloInstruction::CreateDot(shape_3x4, param_b, param_c, dot_dnums));
builder.AddInstruction(
HloInstruction::CreateConcatenate(shape_5x4, {dot_ab, dot_bc}, 1));
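Editor's note: this test now builds dots through HloInstruction::CreateDot with explicit DotDimensionNumbers instead of a kDot binary op. Every dot in this patch uses lhs-contracting dimension 1 and rhs-contracting dimension 0, i.e. an ordinary [m,k] x [k,n] matmul. A small standalone model (hypothetical Shape2D/DotDims types, not the XLA API) of what those two fields mean for the result shape:

#include <cassert>
#include <cstdint>
#include <utility>

// (lhs_contracting_dim, rhs_contracting_dim) for a rank-2 dot.
using DotDims = std::pair<int, int>;

struct Shape2D {
  int64_t dims[2];
};

Shape2D DotResultShape(const Shape2D& lhs, const Shape2D& rhs,
                       const DotDims& dnums) {
  // The contracted dimensions must agree in size.
  assert(lhs.dims[dnums.first] == rhs.dims[dnums.second]);
  // The result keeps the non-contracted dimension of each operand.
  return Shape2D{{lhs.dims[1 - dnums.first], rhs.dims[1 - dnums.second]}};
}

int main() {
  const DotDims dnums{1, 0};  // add_lhs_contracting_dimensions(1) / rhs (0)
  Shape2D a{{2, 3}}, b{{3, 4}};
  Shape2D ab = DotResultShape(a, b, dnums);  // shape_2x4 in the test above
  assert(ab.dims[0] == 2 && ab.dims[1] == 4);
  return 0;
}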
@@ -1708,9 +1711,8 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
BufferAssigner::Run(
module.get(),
xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
- ByteSizeOf,
- [](LogicalBuffer::Color) { return 1; })
- .ConsumeValueOrDie();
+ ByteSizeOf, [](LogicalBuffer::Color) { return 1; })
+ .ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
diff --git a/tensorflow/compiler/xla/service/buffer_liveness_test.cc b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
index bbb42d494b..13825fe05b 100644
--- a/tensorflow/compiler/xla/service/buffer_liveness_test.cc
+++ b/tensorflow/compiler/xla/service/buffer_liveness_test.cc
@@ -167,11 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
SequentialHloOrdering::HloModuleSequence sequence;
sequence.insert({entry, {param0, negate, param1, exp, add}});
- auto liveness = BufferLiveness::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(
- module.get(), sequence))
- .ConsumeValueOrDie();
+ auto liveness =
+ BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
+ module.get(), sequence))
+ .ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@@ -296,7 +295,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
module_sequence.emplace(computation, order);
auto liveness =
BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
- module.get(), module_sequence))
+ module.get(), module_sequence))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@@ -625,9 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
// Run BufferLiveness on 'module'.
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(
- module.get()))
+ BufferLiveness::Run(
+ module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
@@ -738,9 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness =
- BufferLiveness::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(
- module.get()))
+ BufferLiveness::Run(
+ module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD
index e1eed498f6..bf41d5ce07 100644
--- a/tensorflow/compiler/xla/service/cpu/BUILD
+++ b/tensorflow/compiler/xla/service/cpu/BUILD
@@ -250,6 +250,8 @@ cc_library(
":dot_op_emitter",
":external_constant_pool",
":ir_emission_utils",
+ ":ir_function",
+ ":parallel_loop_emitter",
":shape_partition",
":simple_orc_jit",
"//tensorflow/compiler/xla:shape_util",
@@ -281,6 +283,38 @@ cc_library(
)
cc_library(
+ name = "ir_function",
+ srcs = ["ir_function.cc"],
+ hdrs = ["ir_function.h"],
+ deps = [
+ "//tensorflow/compiler/xla:status_macros",
+ "//tensorflow/compiler/xla:statusor",
+ "//tensorflow/compiler/xla:types",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:vector_support_library",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
+ name = "parallel_loop_emitter",
+ srcs = ["parallel_loop_emitter.cc"],
+ hdrs = ["parallel_loop_emitter.h"],
+ deps = [
+ "//tensorflow/compiler/xla:shape_util",
+ "//tensorflow/compiler/xla:xla_data_proto",
+ "//tensorflow/compiler/xla/service/llvm_ir:ir_array",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_loop",
+ "//tensorflow/compiler/xla/service/llvm_ir:llvm_util",
+ "//tensorflow/compiler/xla/service/llvm_ir:loop_emitter",
+ "//tensorflow/core:lib",
+ "@llvm//:core",
+ ],
+)
+
+cc_library(
name = "dot_op_emitter",
srcs = ["dot_op_emitter.cc"],
hdrs = ["dot_op_emitter.h"],
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
index addd7284c5..988f632748 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
@@ -528,9 +528,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// uses data dependencies for determining order.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(module.get(),
- xla::MakeUnique<DependencyHloOrdering>(module.get()),
- BufferSizeBytesFunction(), memory_alignment));
+ BufferAssigner::Run(
+ module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()),
+ BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -642,10 +642,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
- BufferAssigner::Run(
- module.get(),
- xla::MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
- BufferSizeBytesFunction(), memory_alignment));
+ BufferAssigner::Run(module.get(),
+ xla::MakeUnique<SequentialHloOrdering>(
+ module.get(), module_sequence),
+ BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@@ -824,7 +824,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
- module, xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
+ module,
+ xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
diff --git a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
index b9e4d006d7..1c04c9835e 100644
--- a/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
+++ b/tensorflow/compiler/xla/service/cpu/cpu_instruction_fusion_test.cc
@@ -31,6 +31,14 @@ namespace {
using InstructionFusionTest = HloTestBase;
+std::unique_ptr<HloInstruction> MakeDot(const Shape& shape, HloInstruction* lhs,
+ HloInstruction* rhs) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ return HloInstruction::CreateDot(shape, lhs, rhs, dot_dnums);
+}
+
TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
HloComputation::Builder builder(TestName());
HloInstruction* arg0 = builder.AddInstruction(HloInstruction::CreateParameter(
@@ -40,8 +48,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_0) {
HloInstruction* exp0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kExp, arg0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, exp0, arg1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), exp0, arg1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -59,8 +67,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Basic_1) {
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, exp1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, exp1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -80,8 +88,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Bitcast) {
ShapeUtil::MakeShape(S32, {2, 512, 2, 128}), HloOpcode::kExp, arg0));
HloInstruction* bitcast0 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {1024, 256}), HloOpcode::kBitcast, exp0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, bitcast0, arg1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), bitcast0, arg1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -102,8 +110,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_Reshape) {
HloInstruction* reshape0 =
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(S32, {1024, 256}), exp0));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1024, 1}), HloOpcode::kDot, reshape0, arg1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1024, 1}), reshape0, arg1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -121,8 +129,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TooLarge) {
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {256, 32 * 1024}), HloOpcode::kExp, arg1));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1, 32 * 1024}), HloOpcode::kDot, arg0, exp1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1, 32 * 1024}), arg0, exp1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -140,8 +148,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_ElementReuse) {
HloInstruction* exp1 = builder.AddInstruction(HloInstruction::CreateUnary(
ShapeUtil::MakeShape(S32, {256, 1024}), HloOpcode::kExp, arg1));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {2, 1024}), HloOpcode::kDot, arg0, exp1));
+ HloInstruction* dot = builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {2, 1024}), arg0, exp1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
@@ -162,8 +170,8 @@ TEST_F(InstructionFusionTest, DotOperationFusion_TransposeFusion) {
HloInstruction* transpose1 =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(S32, {256, 1024}), exp1, {1, 0}));
- builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1, 1024}), HloOpcode::kDot, arg0, transpose1));
+ builder.AddInstruction(
+ MakeDot(ShapeUtil::MakeShape(F32, {1, 1024}), arg0, transpose1));
auto module = CreateNewModule();
auto computation = module->AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
index 4c40dae512..4ccff756a3 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc
@@ -518,14 +518,14 @@ DotOpEmitter::DotOpEmitter(const HloInstruction& dot, bool transpose_lhs,
bool DotOpEmitter::ShapesAreLegalForRuntimeDot() const { return true; }
bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
- if (dot_.shape().dimensions_size() != 2 ||
- ProfitableToImplementDotInUntiledLlvmIr(dot_) ==
- DotInLlvmIrProfitable::kYes) {
+ if (dot_.shape().dimensions_size() != 2) {
return false;
}
- if (!primitive_util::IsFloatingPointType(dot_.shape().element_type()) &&
- !primitive_util::IsIntegralType(dot_.shape().element_type())) {
+ PrimitiveType primitive_type = dot_.shape().element_type();
+
+ if (!primitive_util::IsFloatingPointType(primitive_type) &&
+ !primitive_util::IsIntegralType(primitive_type)) {
return false;
}
@@ -575,30 +575,50 @@ bool DotOpEmitter::EmitLlvmIrDotIfProfitable() {
int64 tiling_factor = GetGemvTilingFactor();
CHECK_GT(tiling_factor, 0);
+ llvm::Value* result_op = target_array_.GetBasePointer();
+ llvm::Value* lhs_op =
+ swap_operands ? rhs_array_.GetBasePointer() : lhs_array_.GetBasePointer();
+ llvm::Value* rhs_op =
+ swap_operands ? lhs_array_.GetBasePointer() : rhs_array_.GetBasePointer();
+
if (is_column_major_matrix_vector) {
VLOG(2) << "Emitting column major matrix-vector multiply with m = " << m
<< " and k = " << k;
- ColumnMajorMatrixVectorProductEmitter emitter(
- dot_.shape().element_type(), /*tile_rows=*/8,
- /*tile_cols=*/tiling_factor, m, k,
- swap_operands ? rhs_array_.GetBasePointer()
- : lhs_array_.GetBasePointer(),
- swap_operands ? lhs_array_.GetBasePointer()
- : rhs_array_.GetBasePointer(),
- target_array_.GetBasePointer(), ir_builder_);
- emitter.Emit();
+ int64 tile_rows = 8;
+ int64 tile_cols = tiling_factor;
+
+ string kernel_name = tensorflow::strings::StrCat(
+ "col_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
+ "_", tile_cols, "_", m, "_", k);
+
+ KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
+ [this, tile_rows, tile_cols, m, k, primitive_type](
+ llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
+ ColumnMajorMatrixVectorProductEmitter emitter(
+ primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
+ result_op, ir_builder_);
+ emitter.Emit();
+ });
} else {
VLOG(2) << "Emitting row major matrix-vector multiply with m = " << m
<< " and k = " << k;
- RowMajorMatrixVectorProductEmitter emitter(
- dot_.shape().element_type(), /*tile_rows=*/tiling_factor,
- /*tile_cols=*/8, m, k,
- swap_operands ? rhs_array_.GetBasePointer()
- : lhs_array_.GetBasePointer(),
- swap_operands ? lhs_array_.GetBasePointer()
- : rhs_array_.GetBasePointer(),
- target_array_.GetBasePointer(), ir_builder_);
- emitter.Emit();
+ int64 tile_rows = tiling_factor;
+ int64 tile_cols = 8;
+
+ string kernel_name = tensorflow::strings::StrCat(
+ "row_major_gemv_", PrimitiveType_Name(primitive_type), "_", tile_rows,
+ "_", tile_cols, "_", m, "_", k);
+
+ KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ ir_builder_, kernel_name, lhs_op, rhs_op, result_op,
+ [this, tile_rows, tile_cols, m, k, primitive_type](
+ llvm::Value* lhs_op, llvm::Value* rhs_op, llvm::Value* result_op) {
+ RowMajorMatrixVectorProductEmitter emitter(
+ primitive_type, tile_rows, tile_cols, m, k, lhs_op, rhs_op,
+ result_op, ir_builder_);
+ emitter.Emit();
+ });
}
return true;
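Editor's note: the GEMV emitters are now wrapped in KernelSupportLibrary::EmitAndCallOutlinedKernel, with a kernel name derived from the primitive type, tile sizes, and problem dimensions. A simplified standalone model of the outlining contract, under the assumption (not shown in this diff) that the helper emits the body once per unique kernel name and reuses it for later call sites:

#include <cstdint>
#include <functional>
#include <map>
#include <string>

using EmitBody = std::function<std::string()>;

// Hypothetical stand-in: the "module" is a name -> body map instead of an
// llvm::Module.
std::string EmitAndCallOutlinedKernelModel(
    std::map<std::string, std::string>* module_functions,
    const std::string& kernel_name, const EmitBody& emit_body) {
  auto it = module_functions->find(kernel_name);
  if (it == module_functions->end()) {
    // First request for this configuration: emit the body once.
    it = module_functions->emplace(kernel_name, emit_body()).first;
  }
  return it->first;  // The caller then emits a call to the named kernel.
}

int main() {
  std::map<std::string, std::string> module_functions;
  const int64_t tile_rows = 8, tile_cols = 4, m = 128, k = 256;
  const std::string name = "col_major_gemv_F32_" + std::to_string(tile_rows) +
                           "_" + std::to_string(tile_cols) + "_" +
                           std::to_string(m) + "_" + std::to_string(k);
  EmitAndCallOutlinedKernelModel(&module_functions, name,
                                 [] { return std::string("<gemv body>"); });
  EmitAndCallOutlinedKernelModel(&module_functions, name,
                                 [] { return std::string("<gemv body>"); });
  // Two call sites, one emitted kernel.
  return module_functions.size() == 1 ? 0 : 1;
}

Deriving the name from everything that determines the generated body is what makes this reuse safe: two dots with identical type, tiling, and shape get byte-identical kernels.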
@@ -977,9 +997,7 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false;
}
- if (ProfitableToImplementDotInUntiledLlvmIr(hlo) ==
- DotInLlvmIrProfitable::kYes ||
- ProfitableToImplementDotInTiledLlvmIr(hlo)) {
+ if (ProfitableToImplementDotInTiledLlvmIr(hlo)) {
return false;
}
@@ -1010,46 +1028,11 @@ bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo) {
return false;
}
-DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
- const HloInstruction& dot) {
- if (dot.opcode() == HloOpcode::kDot && dot.shape().dimensions_size() == 2) {
- const Shape& result_shape = dot.shape();
- // kReductionDimensionThresholdBytes was chosen to be 1/4 of a typical L1
- // cache line size, so that we can have the reduction dimension of both the
- // LHS and RHS matrices and still have some space "left over". This needs
- // to be tuned further.
- const int64 kReductionDimensionThresholdBytes = 8 * 1024;
- const bool single_threaded_eigen =
- !dot.GetModule()->config().debug_options().xla_cpu_multi_thread_eigen();
-
- // This is the point at which it is better to call into Eigen and shard the
- // dot across multiple worker threads. This is a rough estimate by running
- // a matmult benchmark on my local machine, and it can be tuned further.
- const int64 kMaxSingleThreadedFlops = 16 * 1024;
-
- const int64 M = result_shape.dimensions(0);
- const int64 N = result_shape.dimensions(1);
- const int64 K = dot.operand(1)->shape().dimensions(0);
- const int64 primitive_type_size =
- ShapeUtil::ByteSizeOfPrimitiveType(result_shape.element_type());
- if (M == 1 &&
- K * primitive_type_size <= kReductionDimensionThresholdBytes &&
- (single_threaded_eigen || M * K * N <= kMaxSingleThreadedFlops)) {
- // Heuristics:
- //
- // - Look for a configuration where we will likely be able to keep LHS in
- // L1 and do a cache-optimal traversal of RHS.
- //
- // - Bail out on matrices that are large enough that Eigen can profitably
- // shard the computation across multiple cores. This only applies when
- // multi-threading is enabled.
- return LayoutUtil::IsMonotonicWithDim0Major(
- dot.operand(1)->shape().layout())
- ? DotInLlvmIrProfitable::kWithColumnMajorRhs
- : DotInLlvmIrProfitable::kYes;
- }
- }
- return DotInLlvmIrProfitable::kNo;
+// For vector-matrix dot products, it is always profitable to make the RHS
+// column major.
+bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo) {
+ return hlo.opcode() == HloOpcode::kDot &&
+ hlo.shape().dimensions_size() == 2 && hlo.shape().dimensions(0) == 1;
}
bool ProfitableToImplementDotInTiledLlvmIr(const HloInstruction& dot) {
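Editor's note: the replacement heuristic is much narrower than the deleted one: no byte thresholds or FLOP estimates, just "rank-2 dot with a single output row". Restated as a standalone predicate over a hypothetical FakeDot type, with the cases the real function would see:

#include <cassert>
#include <cstdint>
#include <vector>

struct FakeDot {
  bool is_dot;
  std::vector<int64_t> result_dims;
};

// A dot is a vector-matrix product exactly when its rank-2 result has a
// single row; per the comment above, those always benefit from a
// column-major RHS.
bool ProfitableToMakeDotRhsColumnMajorModel(const FakeDot& hlo) {
  return hlo.is_dot && hlo.result_dims.size() == 2 && hlo.result_dims[0] == 1;
}

int main() {
  assert(ProfitableToMakeDotRhsColumnMajorModel({true, {1, 1024}}));   // [1,k]x[k,n]
  assert(!ProfitableToMakeDotRhsColumnMajorModel({true, {1024, 1}}));  // matrix-vector
  assert(!ProfitableToMakeDotRhsColumnMajorModel({false, {1, 1024}})); // not a dot
  return 0;
}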
diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
index c9168ccc0f..2badb26f90 100644
--- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.h
@@ -32,19 +32,9 @@ namespace cpu {
bool PotentiallyImplementedAsEigenDot(const HloInstruction& hlo);
-enum class DotInLlvmIrProfitable { kYes, kNo, kWithColumnMajorRhs };
-
-// Returns a value to indicate if (and under what conditions) will lowering
-// |dot| as a untiled LLVM IR dot operation be profitable over calling into
-// Eigen or emitting a tiled LLVM IR implementation. Possible return values
-// are:
-//
-// * DotInLlvmIrProfitable::kYes - always profitable.
-// * DotInLlvmIrProfitable::kNo - never profitable.
-// * DotInLlvmIrProfitable::kWithColumnMajorRhs - only if we can manage to make
-// the Rhs layout column major.
-DotInLlvmIrProfitable ProfitableToImplementDotInUntiledLlvmIr(
- const HloInstruction& dot);
+// Returns true if |hlo| is a dot and it is profitable to switch the layout
+// of |hlo|'s RHS operand to column major.
+bool ProfitableToMakeDotRhsColumnMajor(const HloInstruction& hlo);
// Returns true to indicate that we can generate a tiled LLVM IR implementation
// for |dot|.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
index 502dd2e738..bb75d3f49e 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.cc
@@ -24,6 +24,7 @@ limitations under the License.
#include <utility>
#include <vector>
+#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "llvm/CodeGen/TargetRegisterInfo.h"
@@ -42,6 +43,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
@@ -124,131 +127,27 @@ StatusOr<llvm::Function*> IrEmitter::EmitComputation(
} else {
TF_RETURN_IF_ERROR(computation->AcceptOrdered(this, *instruction_order));
}
- InsertOrDie(&emitted_functions_, computation, compute_function_);
-
- return compute_function_;
-}
-
-static llvm::Argument* GetArg(llvm::Function* f, int idx) {
- llvm::Function::arg_iterator arg_iter = f->arg_begin();
- std::advance(arg_iter, idx);
- return &*arg_iter;
+ llvm::Function* ir_function = compute_function_->function();
+ InsertOrDie(&emitted_functions_, computation, ir_function);
+ // Delete 'compute_function_', finalizing 'ir_function' and restoring the
+ // caller's IR insert point.
+ compute_function_.reset();
+ return ir_function;
}
void IrEmitter::InitializeIrFunction(const string& function_name) {
- // The function signature is:
- // void function(i8* retval, i8* run_options, i8** params, i8** temps,
- // i64* dynamic_loop_bounds, i64* prof_counters)
- //
- // retval: points to the returned value.
- // params: address of an array with pointers to parameters.
- // temps: address of an array with pointers to temporary buffers.
- //
- // Therefore, the generated function's signature (FunctionType) is statically
- // determined - parameter unpacking is done in code generated into the
- // function, rather than by a prologue dictated by the platform ABI.
- //
- // /--------------\
- // retval ----------> | return value |
- // \--------------/
- //
- // /-------------------------------\
- // run_options -----> | xla::ExecutableRunOptions |
- // \-------------------------------/
- //
- // /---------------------------------------------\
- // params --------> | param 0 | param 1 | ..... | param N-1 |
- // | addr | addr | | addr |
- // \---------------------------------------------/
- // | | |
- // | | |
- // V V V
- // /---------\ /---------\ /-----------\
- // | param 0 | | param 1 | | param N-1 |
- // \---------/ \---------/ \-----------/
- //
- // /---------------------------------------------\
- // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
- // | addr | addr | | addr |
- // \---------------------------------------------/
- // | | |
- // | | |
- // V V V
- // /---------\ /---------\ /-----------\
- // | temp 0 | | temp 1 | | temp N-1 |
- // \---------/ \---------/ \-----------/
- //
- // /--------------------------------------------\
- // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
- // (elided for aot) \--------------------------------------------/
- //
- // /---------------------------------------------\
- // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
- // (elided for aot) \---------------------------------------------/
-
- // Even though the type of params and temps is void** in the host's view, in
- // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
- // to use GEPs to unravel the indirection layers.
- llvm::FunctionType* compute_function_type = llvm::FunctionType::get(
- /*Result=*/llvm::Type::getVoidTy(module_->getContext()),
- /*Params=*/GetComputeFunctionParams(),
- /*isVarArg=*/false);
-
// Functions with local linkage get an inlining bonus. Because we know
// a-priori that embedded functions (non-entry functions) will not have its
// name resolved, give it local linkage.
llvm::Function::LinkageTypes linkage =
is_top_level_computation_ ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
- compute_function_ =
- llvm::Function::Create(/*Ty=*/compute_function_type,
- /*Linkage=*/linkage,
- /*Name=*/AsStringRef(function_name),
- /*Module=*/module_);
- compute_function_->setCallingConv(llvm::CallingConv::C);
-
- // Set meaningful names for the function's arguments: useful for debugging.
- llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin();
- arg_iter->setName("retval");
- (++arg_iter)->setName("run_options");
- (++arg_iter)->setName("params");
- (++arg_iter)->setName("temps");
- if (num_dynamic_loop_bounds_ > 0) {
- (++arg_iter)->setName("dynamic_loop_bounds");
- }
- (++arg_iter)->setName("prof_counters");
-
- // We know a-priori that the function arguments are guaranteed to point to
- // disjoint objects.
- llvm::Argument* retval = GetResultArgument();
- for (llvm::Argument& argument : compute_function_->args()) {
- // However, the return buffer aliases the temporaries and thus cannot be
- // marked noalias.
- if (&argument == retval) {
- continue;
- }
- compute_function_->addAttribute(argument.getArgNo() + 1,
- llvm::Attribute::NoAlias);
- }
-
- // Add the optize attribute to the function if optimizing for size. This
- // controls internal behavior of some optimization passes (e.g. loop
- // unrolling).
- if (options::OptimizeForSizeRequested(hlo_module_config_)) {
- compute_function_->addFnAttr(llvm::Attribute::OptimizeForSize);
- }
-
- if (hlo_module_config_.debug_options().xla_enable_fast_math()) {
- compute_function_->addFnAttr("unsafe-fp-math", "true");
- compute_function_->addFnAttr("no-infs-fp-math", "true");
- compute_function_->addFnAttr("no-nans-fp-math", "true");
- compute_function_->addFnAttr("no-signed-zeros-fp-math", "true");
- }
-
- ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
- /*Context=*/module_->getContext(),
- /*Name=*/"entry",
- /*Parent=*/compute_function_));
+ // Create and initialize new IrFunction.
+ compute_function_.reset(
+ new IrFunction(function_name, linkage,
+ options::OptimizeForSizeRequested(hlo_module_config_),
+ hlo_module_config_.debug_options().xla_enable_fast_math(),
+ module_, &ir_builder_, num_dynamic_loop_bounds_));
}
IrEmitter::~IrEmitter() {}
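Editor's note: the large signature diagram deleted above still documents the ABI that IrFunction now owns. Transcribed as a C-level function-pointer type (a sketch of the same contract, not code from this patch):

#include <cstdint>

// void function(i8* retval, i8* run_options, i8** params, i8** temps,
//               i64* dynamic_loop_bounds, i64* prof_counters)
//
// retval:              points to the returned value.
// run_options:         the xla::ExecutableRunOptions.
// params / temps:      arrays of pointers to parameters / temporary buffers.
// dynamic_loop_bounds: loop bounds for parallel outer dimensions
//                      (elided for ahead-of-time compilation).
// prof_counters:       profile counter array (also elided for AOT).
using ComputeFunctionType = void (*)(char* retval, char* run_options,
                                     char** params, char** temps,
                                     int64_t* dynamic_loop_bounds,
                                     int64_t* prof_counters);

int main() {
  ComputeFunctionType fn = nullptr;  // Would be resolved by the JIT at runtime.
  (void)fn;
  return 0;
}

Parameter unpacking is done by generated GEPs inside the function rather than by a platform-ABI prologue, which is why the type is fixed regardless of the computation's actual parameter list.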
@@ -898,6 +797,11 @@ Status IrEmitter::HandleDot(HloInstruction* dot) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*dot, /*operands=*/{lhs, rhs},
/*supported_types=*/{F32, F64, C64}));
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ if (dnums.lhs_batch_dimensions_size() > 0 ||
+ dnums.rhs_batch_dimensions_size() > 0) {
+ return Unimplemented("Dot with batch dimensions not implemented.");
+ }
llvm_ir::IrArray lhs_array(GetIrArrayFor(lhs));
llvm_ir::IrArray rhs_array(GetIrArrayFor(rhs));
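Editor's note: HandleDot now rejects dots carrying batch dimensions up front. A minimal model (hypothetical DotDimensionNumbersModel type) of the new guard:

#include <vector>

struct DotDimensionNumbersModel {
  std::vector<int> lhs_batch_dimensions;
  std::vector<int> rhs_batch_dimensions;
};

// The CPU emitter bails out with Unimplemented when either side has batch
// dimensions; only plain (possibly transposed) rank-2 dots proceed.
bool DotHasBatchDimensions(const DotDimensionNumbersModel& dnums) {
  return !dnums.lhs_batch_dimensions.empty() ||
         !dnums.rhs_batch_dimensions.empty();
}

int main() {
  DotDimensionNumbersModel plain_matmul;       // no batch dims: emitted
  DotDimensionNumbersModel batched{{0}, {0}};  // batch dims: rejected
  return (!DotHasBatchDimensions(plain_matmul) &&
          DotHasBatchDimensions(batched))
             ? 0
             : 1;
}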
@@ -1452,7 +1356,7 @@ Status IrEmitter::HandleParameter(HloInstruction* parameter) {
//
// Where Param is the actual element type of the underlying buffer (for
// example, float for an XLA F32 element type).
- llvm::Argument* params = GetArg(compute_function_, 2);
+ llvm::Argument* params = compute_function_->parameters_arg();
llvm::Value* param_address_offset =
llvm_ir::EmitBufferIndexingGEP(params, param_number, &ir_builder_);
llvm::LoadInst* param_address_untyped =
@@ -1590,7 +1494,7 @@ IrEmitter::ShardedVectorType IrEmitter::CreateShardedVectorType(
// Here we assume that the largest register is a vector register.
int max_vector_register_size_in_bytes =
target_machine_features_.largest_register_size_in_bytes(
- compute_function_);
+ compute_function_->function());
int vector_register_size_in_elements =
max_vector_register_size_in_bytes /
@@ -1748,19 +1652,6 @@ void IrEmitter::EmitShardedVectorStore(
}
}
-namespace {
-// TODO(sanjoy): This is duplicated in tensorflow/core/lib/core/arena.cc.
-// Extract out a common implementation to tensorflow/core/lib/math/math_util.h
-uint32 GCD(uint32 x, uint32 y) {
- while (y != 0) {
- uint32 r = x % y;
- x = y;
- y = r;
- }
- return x;
-}
-} // namespace
-
StatusOr<bool> IrEmitter::EmitVectorizedReduce(
HloInstruction* reduce, HloInstruction* arg, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions, HloComputation* function,
@@ -1783,9 +1674,9 @@ StatusOr<bool> IrEmitter::EmitVectorizedReduce(
std::find(dimensions.begin(), dimensions.end(),
arg->shape().layout().minor_to_major(0)) != dimensions.end();
- unsigned element_alignment =
- GCD(ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
- MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
+ unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
+ ShapeUtil::ByteSizeOfPrimitiveType(reduce->shape().element_type()),
+ MinimumAlignmentForPrimitiveType(reduce->shape().element_type()));
if (is_reduction_over_minor_dimension) {
// TODO(sanjoy): Implement vectorized reduction over the minor dimension.
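Editor's note: the local GCD() helper removed above is replaced by the shared tensorflow::MathUtil::GCD, with no behavioral change intended. For reference, the same Euclidean algorithm as a standalone function, applied the way EmitVectorizedReduce uses it to pick an element alignment:

#include <cassert>

// Euclid's algorithm, equivalent to the deleted helper and to
// tensorflow::MathUtil::GCD<unsigned>.
unsigned Gcd(unsigned x, unsigned y) {
  while (y != 0) {
    unsigned r = x % y;
    x = y;
    y = r;
  }
  return x;
}

int main() {
  // element_alignment: the largest alignment dividing both the element's
  // byte size and the minimum alignment chosen for the primitive type
  // (values here are illustrative).
  unsigned byte_size_f32 = 4, min_alignment = 16;
  assert(Gcd(byte_size_f32, min_alignment) == 4);
  assert(Gcd(8, 12) == 4);
  return 0;
}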
@@ -1995,7 +1886,7 @@ Status IrEmitter::HandleSlice(HloInstruction* slice) {
VLOG(2) << "HandleSlice: " << slice->ToString();
auto operand = slice->operand(0);
// The code below emits a sequential loop nest. For the parallel backend, use
- // EmitParallelTargetElementLoop() which respects dynamic loop bounds.
+ // ParallelLoopEmitter, which respects dynamic loop bounds.
if (ShouldEmitParallelLoopFor(*slice)) {
return DefaultAction(slice);
}
@@ -2410,7 +2301,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Terminates the current block with a branch to a while header.
llvm::BasicBlock* header_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "header")),
- compute_function_);
+ compute_function_->function());
ir_builder_.CreateBr(header_bb);
ir_builder_.SetInsertPoint(header_bb);
@@ -2427,7 +2318,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
// Branches to the body or to the while exit depending on the condition.
llvm::BasicBlock* body_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "body")),
- compute_function_);
+ compute_function_->function());
llvm::BasicBlock* exit_bb = llvm::BasicBlock::Create(
module_->getContext(), AsStringRef(IrName(xla_while, "exit")));
ir_builder_.CreateCondBr(while_predicate, body_bb, exit_bb);
@@ -2442,7 +2333,7 @@ Status IrEmitter::HandleWhile(HloInstruction* xla_while) {
ir_builder_.CreateBr(header_bb);
// Adds the exit block to the function and sets the insert point there.
- compute_function_->getBasicBlockList().push_back(exit_bb);
+ compute_function_->function()->getBasicBlockList().push_back(exit_bb);
ir_builder_.SetInsertPoint(exit_bb);
return Status::OK();
@@ -2560,7 +2451,7 @@ void IrEmitter::EmitTransferElements(llvm::Value* target, llvm::Value* source,
const llvm_ir::IrArray& source_array) {
unsigned primitive_type_size =
ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
- unsigned element_alignment = GCD(
+ unsigned element_alignment = tensorflow::MathUtil::GCD<unsigned>(
primitive_type_size, MinimumAlignmentForPrimitiveType(primitive_type));
llvm::Type* primitive_ptr_type = llvm::PointerType::getUnqual(
llvm_ir::PrimitiveTypeToIrType(primitive_type, module_));
@@ -2642,7 +2533,6 @@ Status IrEmitter::FinishVisit(HloInstruction* root) {
if (prof_counter) {
profiling_state_.RecordCompleteComputation(&ir_builder_, prof_counter);
}
- ir_builder_.CreateRetVoid();
return Status::OK();
}
@@ -2783,43 +2673,16 @@ llvm::Type* IrEmitter::IrShapeType(const Shape& shape) {
return llvm_ir::ShapeToIrType(shape, module_);
}
-std::vector<llvm::Type*> IrEmitter::GetComputeFunctionParams() {
- llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
- llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
- llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext());
- std::vector<llvm::Type*> compute_function_params(
- {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
- if (num_dynamic_loop_bounds_ > 0) {
- compute_function_params.push_back(i64_ptr_type);
- }
- compute_function_params.push_back(i64_ptr_type);
- return compute_function_params;
-}
-
-llvm::Argument* IrEmitter::GetResultArgument() {
- return GetArg(compute_function_, 0);
-}
-
llvm::Argument* IrEmitter::GetProfileCountersArgument() {
- const int64 arg_index = num_dynamic_loop_bounds_ > 0 ? 5 : 4;
- return GetArg(compute_function_, arg_index);
+ return compute_function_->profile_counters_arg();
}
llvm::Value* IrEmitter::GetTempBuffersArgument() {
- return GetArg(compute_function_, 3);
-}
-
-llvm::Value* IrEmitter::GetDynamicLoopBound(const int64 offset) {
- CHECK_GT(num_dynamic_loop_bounds_, 0);
- CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
- llvm::Argument* loop_bounds_arg = GetArg(compute_function_, 4);
- string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
- return ir_builder_.CreateLoad(ir_builder_.CreateGEP(
- loop_bounds_arg, ir_builder_.getInt64(offset), AsStringRef(name)));
+ return compute_function_->temp_buffers_arg();
}
llvm::Value* IrEmitter::GetExecutableRunOptionsArgument() {
- return GetArg(compute_function_, 1);
+ return compute_function_->exec_run_options_arg();
}
llvm::Value* IrEmitter::EmitTempBufferPointer(
@@ -2965,7 +2828,8 @@ Status IrEmitter::EmitParallelForkJoin(
HloInstruction* root = computation->root_instruction();
// Build ParallelForkJoin function type.
- std::vector<llvm::Type*> compute_function_params = GetComputeFunctionParams();
+ std::vector<llvm::Type*> compute_function_params =
+ compute_function_->GetComputeFunctionParams();
// Number of parallel compute functions.
compute_function_params.push_back(ir_builder_.getInt32Ty());
// Array of partitions. There is an array element for each
@@ -3066,7 +2930,7 @@ Status IrEmitter::EmitTargetAddressForOp(const HloInstruction* op) {
if (op == op->parent()->root_instruction()) {
// For the root node, we write directly to the output buffer of the
// function.
- llvm::Argument* retval = GetResultArgument();
+ llvm::Argument* retval = compute_function_->result_arg();
if (!ShapeUtil::IsNil(target_shape)) {
llvm::AttrBuilder attr_builder;
attr_builder.addAlignmentAttr(MinimumAlignmentForShape(target_shape));
@@ -3127,8 +2991,19 @@ Status IrEmitter::EmitTargetElementLoop(
} else {
if (ShouldEmitParallelLoopFor(*target_op)) {
- TF_RETURN_IF_ERROR(EmitParallelTargetElementLoop(
- target_shape, element_generator, IrName(target_op), &target_array));
+ // Emit code to read dynamic loop bounds from compute function argument.
+ ParallelLoopEmitter::LoopBounds dynamic_loop_bounds(
+ num_dynamic_loop_bounds_);
+ for (int i = 0; i < num_dynamic_loop_bounds_; ++i) {
+ dynamic_loop_bounds[i].first =
+ compute_function_->GetDynamicLoopBound(i * 2 + 0);
+ dynamic_loop_bounds[i].second =
+ compute_function_->GetDynamicLoopBound(i * 2 + 1);
+ }
+ // Emit parallel loop with dynamic loop bounds for most-major dimensions.
+ TF_RETURN_IF_ERROR(ParallelLoopEmitter(element_generator, target_array,
+ &dynamic_loop_bounds, &ir_builder_)
+ .EmitLoop(IrName(target_op)));
} else {
TF_RETURN_IF_ERROR(
llvm_ir::LoopEmitter(element_generator, target_array, &ir_builder_)
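Editor's note: the inlined replacement for EmitParallelTargetElementLoop reads its bounds through compute_function_->GetDynamicLoopBound(i * 2 + 0/1). A standalone sketch of the flat [start, limit) pair layout that indexing scheme implies, matching the dynamic_loop_bounds array documented in the deleted signature comment:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
  const int num_dynamic_loop_bounds = 2;
  // As if read through GetDynamicLoopBound(offset): one [start, limit) pair
  // per parallelized outer dimension, flattened.
  const std::vector<int64_t> flat_bounds = {0, 4, 4, 8};

  std::vector<std::pair<int64_t, int64_t>> dynamic_loop_bounds(
      num_dynamic_loop_bounds);
  for (int i = 0; i < num_dynamic_loop_bounds; ++i) {
    dynamic_loop_bounds[i].first = flat_bounds[i * 2 + 0];   // start
    dynamic_loop_bounds[i].second = flat_bounds[i * 2 + 1];  // limit
  }
  assert(dynamic_loop_bounds[1] == std::make_pair(int64_t{4}, int64_t{8}));
  return 0;
}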
@@ -3138,60 +3013,6 @@ Status IrEmitter::EmitTargetElementLoop(
return Status::OK();
}
-Status IrEmitter::EmitParallelTargetElementLoop(
- const Shape& target_shape,
- const llvm_ir::ElementGenerator& element_generator,
- tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array) {
- CHECK(!ShapeUtil::IsTuple(target_shape));
- CHECK(!ShapeUtil::IsScalar(target_shape));
-
- // Emit code to read dynamic loop bounds from function argument 4.
- std::vector<llvm::Value*> dynamic_loop_bounds(2 * num_dynamic_loop_bounds_);
- for (int i = 0; i < 2 * num_dynamic_loop_bounds_; ++i) {
- dynamic_loop_bounds[i] = GetDynamicLoopBound(i);
- }
-
- llvm_ir::ForLoopNest loop_nest(loop_name, &ir_builder_);
- const int64 num_dims = target_shape.dimensions_size();
- llvm_ir::IrArray::Index array_index(num_dims);
-
- // Add loops from outer-most to inner-most dimensions.
- for (int i = target_shape.layout().minor_to_major_size() - 1; i >= 0; --i) {
- const int64 dimension = target_shape.layout().minor_to_major(i);
- const int bounds_index = num_dims - 1 - i;
- if (bounds_index < num_dynamic_loop_bounds_) {
- // Emit dynamic loop bounds for this dimension. Dynamic loop bounds
- // are read from ir function dynamic loop bounds argument.
- llvm::Value* start_index = dynamic_loop_bounds[bounds_index * 2 + 0];
- llvm::Value* end_index = dynamic_loop_bounds[bounds_index * 2 + 1];
-
- std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension),
- start_index, end_index);
- array_index[dimension] = loop->GetIndVarValue();
- } else {
- // Emit static loop bounds for this dimension.
- std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
- /*start_index=*/0,
- /*end_index=*/target_shape.dimensions(dimension),
- /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
- array_index[dimension] = loop->GetIndVarValue();
- }
- }
- // Point IR builder at inner loop BB.
- SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(), &ir_builder_);
-
- // Emit loop body.
- TF_ASSIGN_OR_RETURN(llvm::Value * target_element,
- element_generator(array_index));
- target_array->EmitWriteArrayElement(array_index, target_element,
- &ir_builder_);
- // Point IR builder at outer loop exit BB.
- SetToFirstInsertPoint(loop_nest.GetOuterLoopExitBasicBlock(), &ir_builder_);
-
- return Status::OK();
-}
-
Status IrEmitter::EmitMemcpy(const HloInstruction& source,
const HloInstruction& destination) {
llvm::Value* source_value = GetEmittedValueFor(&source);
diff --git a/tensorflow/compiler/xla/service/cpu/ir_emitter.h b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
index 351c95278c..6b576d16bb 100644
--- a/tensorflow/compiler/xla/service/cpu/ir_emitter.h
+++ b/tensorflow/compiler/xla/service/cpu/ir_emitter.h
@@ -18,6 +18,7 @@ limitations under the License.
#include <stddef.h>
#include <map>
+#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
@@ -30,6 +31,7 @@ limitations under the License.
#include "llvm/Target/TargetMachine.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/external_constant_pool.h"
+#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -233,13 +235,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// Convenience function to get the IR type matching the given shape.
llvm::Type* IrShapeType(const Shape& shape);
- // Returns an array of compute function parameter types.
- std::vector<llvm::Type*> GetComputeFunctionParams();
-
- // Get the llvm::Value* that represents the "retval" argument of the
- // computation function being emitted by this emitter.
- llvm::Argument* GetResultArgument();
-
// Get the llvm::Value* that represents the "prof_counters" argument of the
// computation function being emitted by this emitter.
llvm::Argument* GetProfileCountersArgument();
@@ -252,11 +247,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
// computation function being emitted by this emitter.
llvm::Value* GetTempBuffersArgument();
- // Emit ir to read and return the ir value for the dynamic loop bound at
- // 'offset' from the "dynamic_loop_bounds" argument of the computation
- // function being emitted by this emitter.
- llvm::Value* GetDynamicLoopBound(const int64 offset);
-
// Emits code that computes the address of the given temporary buffer to the
// function. target_shape is the shape of this temporary buffer.
// The returned Value's type is a pointer to element_type.
@@ -346,15 +336,6 @@ class IrEmitter : public DfsHloVisitorWithDefault {
HloInstruction* target_op, tensorflow::StringPiece desc,
const llvm_ir::ElementGenerator& element_generator);
- // Emit IR to perform a computation for every element in a partition/slice of
- // 'target_shape'. The loop bounds for the outer-dimension partitions are
- // passed into the compute function as a runtime argument (accessible from
- // GetDynamicLoopBound).
- Status EmitParallelTargetElementLoop(
- const Shape& target_shape,
- const llvm_ir::ElementGenerator& element_generator,
- tensorflow::StringPiece loop_name, llvm_ir::IrArray* target_array);
-
// Emits a memcpy from the source instruction's result value to the
// destination's. Both source and destination must have an entry in the
// emitted_value_ table.
@@ -476,8 +457,10 @@ class IrEmitter : public DfsHloVisitorWithDefault {
thread_local_buffers_;
// The following fields track the IR emission state. According to LLVM memory
- // management rules, their memory is owned by the module.
- llvm::Function* compute_function_;
+ // management rules, their memory is owned by the module. (Note that
+ // IrFunction creates the encapsulated llvm::Function such that it is added
+ // to the llvm module's function list.)
+ std::unique_ptr<IrFunction> compute_function_;
llvm::IRBuilder<> ir_builder_;
// Maps HLOs to their index into the profile counter array.
@@ -490,7 +473,7 @@ class IrEmitter : public DfsHloVisitorWithDefault {
llvm_ir::AliasAnalysis alias_analysis_;
// The number of root instruction outer dimensions used in parallel loop
- // emission (EmitParallelTargetElementLoop).
+ // emission (ParallelLoopEmitter).
int64 num_dynamic_loop_bounds_ = 0;
// Returns whether the given instruction should be emitted as a parallel loop.
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.cc b/tensorflow/compiler/xla/service/cpu/ir_function.cc
new file mode 100644
index 0000000000..701bce2cbf
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.cc
@@ -0,0 +1,195 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <iterator>
+
+#include "tensorflow/compiler/xla/service/cpu/ir_function.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/compiler/xla/status_macros.h"
+
+namespace xla {
+
+namespace {
+using llvm_ir::AsStringRef;
+} // namespace
+
+namespace cpu {
+
+IrFunction::IrFunction(const string& function_name,
+ llvm::Function::LinkageTypes linkage,
+ const bool optimize_for_size_requested,
+ const bool enable_fast_math, llvm::Module* llvm_module,
+ llvm::IRBuilder<>* ir_builder,
+ int64 num_dynamic_loop_bounds)
+ : ir_builder_(ir_builder),
+ llvm_module_(llvm_module),
+ caller_insert_point_guard_(*ir_builder),
+ num_dynamic_loop_bounds_(num_dynamic_loop_bounds) {
+ Initialize(function_name, linkage, optimize_for_size_requested,
+ enable_fast_math);
+}
+
+IrFunction::~IrFunction() {
+ // Emit function return value.
+ ir_builder_->CreateRetVoid();
+}
+
+void IrFunction::Initialize(const string& function_name,
+ llvm::Function::LinkageTypes linkage,
+ const bool optimize_for_size_requested,
+ const bool enable_fast_math) {
+ // The function signature is:
+ // void function(i8* retval, i8* run_options, i8** params, i8** temps,
+ // i64* dynamic_loop_bounds, i64* prof_counters)
+ //
+ // retval: points to the returned value.
+ // params: address of an array with pointers to parameters.
+ // temps: address of an array with pointers to temporary buffers.
+ //
+ // Therefore, the generated function's signature (FunctionType) is statically
+ // determined - parameter unpacking is done in code generated into the
+ // function, rather than by a prologue dictated by the platform ABI.
+ //
+ // /--------------\
+ // retval ----------> | return value |
+ // \--------------/
+ //
+ // /-------------------------------\
+ // run_options -----> | xla::ExecutableRunOptions |
+ // \-------------------------------/
+ //
+ // /---------------------------------------------\
+ // params --------> | param 0 | param 1 | ..... | param N-1 |
+ // | addr | addr | | addr |
+ // \---------------------------------------------/
+ // | | |
+ // | | |
+ // V V V
+ // /---------\ /---------\ /-----------\
+ // | param 0 | | param 1 | | param N-1 |
+ // \---------/ \---------/ \-----------/
+ //
+ // /---------------------------------------------\
+ // temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
+ // | addr | addr | | addr |
+ // \---------------------------------------------/
+ // | | |
+ // | | |
+ // V V V
+ // /---------\ /---------\ /-----------\
+ // | temp 0 | | temp 1 | | temp N-1 |
+ // \---------/ \---------/ \-----------/
+ //
+ // /--------------------------------------------\
+ // dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
+ // (elided for aot) \--------------------------------------------/
+ //
+ // /---------------------------------------------\
+ // prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
+ // \---------------------------------------------/
+
+ // Even though the type of params and temps is void** in the host's view, in
+ // LLVM IR this is represented by i8*, similarly to void*. It's up to the code
+ // to use GEPs to unravel the indirection layers.
+ llvm::FunctionType* function_type = llvm::FunctionType::get(
+ /*Result=*/llvm::Type::getVoidTy(llvm_module_->getContext()),
+ /*Params=*/GetComputeFunctionParams(),
+ /*isVarArg=*/false);
+
+ // Functions with local linkage get an inlining bonus. Because we know
+ // a-priori that embedded functions (non-entry functions) will not have their
+ // names resolved, give them local linkage.
+ function_ = llvm::Function::Create(/*Ty=*/function_type,
+ /*Linkage=*/linkage,
+ /*N=*/AsStringRef(function_name),
+ /*M=*/llvm_module_);
+ function_->setCallingConv(llvm::CallingConv::C);
+
+ // Set meaningful names for the function's arguments: useful for debugging.
+ llvm::Function::arg_iterator arg_iter = function_->arg_begin();
+ arg_iter->setName("retval");
+ result_arg_ = &*arg_iter;
+ (++arg_iter)->setName("run_options");
+ exec_run_options_arg_ = &*arg_iter;
+ (++arg_iter)->setName("params");
+ parameters_arg_ = &*arg_iter;
+ (++arg_iter)->setName("temps");
+ temp_buffers_arg_ = &*arg_iter;
+ if (num_dynamic_loop_bounds_ > 0) {
+ (++arg_iter)->setName("dynamic_loop_bounds");
+ dynamic_loop_bounds_arg_ = &*arg_iter;
+ }
+ (++arg_iter)->setName("prof_counters");
+ profile_counters_arg_ = &*arg_iter;
+
+ // We know a-priori that the function arguments are guaranteed to point to
+ // disjoint objects.
+ llvm::Argument* retval = result_arg();
+ for (llvm::Argument& argument : function_->args()) {
+ // However, the return buffer aliases the temporaries and thus cannot be
+ // marked noalias.
+ if (&argument == retval) {
+ continue;
+ }
+ function_->addAttribute(argument.getArgNo() + 1, llvm::Attribute::NoAlias);
+ }
+
+ // Add the optimize attribute to the function if optimizing for size. This
+ // controls internal behavior of some optimization passes (e.g. loop
+ // unrolling).
+ if (optimize_for_size_requested) {
+ function_->addFnAttr(llvm::Attribute::OptimizeForSize);
+ }
+
+ if (enable_fast_math) {
+ function_->addFnAttr("unsafe-fp-math", "true");
+ function_->addFnAttr("no-infs-fp-math", "true");
+ function_->addFnAttr("no-nans-fp-math", "true");
+ function_->addFnAttr("no-signed-zeros-fp-math", "true");
+ }
+
+ ir_builder_->SetInsertPoint(llvm::BasicBlock::Create(
+ /*Context=*/llvm_module_->getContext(),
+ /*Name=*/"entry",
+ /*Parent=*/function_));
+}
+
+std::vector<llvm::Type*> IrFunction::GetComputeFunctionParams() {
+ llvm::Type* i8_ptr_type =
+ llvm::Type::getInt8PtrTy(llvm_module_->getContext());
+ llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
+ llvm::Type* i64_ptr_type =
+ llvm::Type::getInt64PtrTy(llvm_module_->getContext());
+ std::vector<llvm::Type*> compute_function_params(
+ {i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
+ if (num_dynamic_loop_bounds_ > 0) {
+ compute_function_params.push_back(i64_ptr_type);
+ }
+ compute_function_params.push_back(i64_ptr_type);
+ return compute_function_params;
+}
+
+llvm::Value* IrFunction::GetDynamicLoopBound(const int64 offset) {
+ CHECK_GT(num_dynamic_loop_bounds_, 0);
+ CHECK_LT(offset, num_dynamic_loop_bounds_ * 2);
+ string name = tensorflow::strings::StrCat("dynamic_loop_bound_", offset);
+ return ir_builder_->CreateLoad(
+ ir_builder_->CreateGEP(CHECK_NOTNULL(dynamic_loop_bounds_arg_),
+ ir_builder_->getInt64(offset), AsStringRef(name)));
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/ir_function.h b/tensorflow/compiler/xla/service/cpu/ir_function.h
new file mode 100644
index 0000000000..b7516b403e
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/ir_function.h
@@ -0,0 +1,109 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
+
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/statusor.h"
+#include "tensorflow/compiler/xla/types.h"
+
+namespace xla {
+namespace cpu {
+
+// IrFunction creates and encapsulates an llvm::Function, exposing methods to
+// emitters for function and function argument access.
+// The llvm::Function is created with the standard function signature
+// used in the XLA CPU backend (see ir_function.cc for argument details).
+// In addition, IrFunction saves the caller's IR insert point during
+// construction, and restores it after destruction.
+//
+// Example usage:
+//
+// // Create and initialize new IrFunction.
+// std::unique_ptr<IrFunction> compute_function(new IrFunction(...));
+// // Emit IR for function body using IrFunction helper methods.
+// ...
+// // Store reference to llvm::Function for future invocation.
+// ir_functions.push_back(compute_function->function());
+// // Delete IrFunction (finalizes IR function and restores caller insertion
+// // point).
+// compute_function.reset();
+//
+
+class IrFunction {
+ public:
+ IrFunction(const string& function_name, llvm::Function::LinkageTypes linkage,
+ const bool optimize_for_size_requested,
+ const bool enable_fast_math, llvm::Module* llvm_module,
+ llvm::IRBuilder<>* ir_builder, int64 num_dynamic_loop_bounds);
+ ~IrFunction();
+
+ // Returns an array of compute function parameter types.
+ std::vector<llvm::Type*> GetComputeFunctionParams();
+
+ // Emit ir to read and return the ir value for the dynamic loop bound at
+ // 'offset' from the "dynamic_loop_bounds" argument of this function.
+ llvm::Value* GetDynamicLoopBound(int64 offset);
+
+ // Returns the encapsulated llvm::Function.
+ llvm::Function* function() { return function_; }
+
+ // Get the llvm::Argument* that represents this function's "retval" argument.
+ llvm::Argument* result_arg() { return result_arg_; }
+
+ // Get the llvm::Value* that represents this function's "run_options"
+ // argument (which points to the xla::ExecutableRunOptions).
+ llvm::Value* exec_run_options_arg() { return exec_run_options_arg_; }
+
+ // Get the llvm::Argument* that represents this function's "params" argument.
+ llvm::Argument* parameters_arg() { return parameters_arg_; }
+
+ // Get the llvm::Value* that represents this function's "temps" argument.
+ llvm::Value* temp_buffers_arg() { return temp_buffers_arg_; }
+
+ // Get the llvm::Argument* that represents this function's "prof_counters"
+ // argument.
+ llvm::Argument* profile_counters_arg() { return profile_counters_arg_; }
+
+ private:
+ // Initialize an llvm::Function with standard signature based on arguments.
+ void Initialize(const string& function_name,
+ llvm::Function::LinkageTypes linkage,
+ bool optimize_for_size_requested, bool enable_fast_math);
+
+ llvm::IRBuilder<>* ir_builder_;
+ llvm::Module* llvm_module_;
+ llvm::IRBuilder<>::InsertPointGuard caller_insert_point_guard_;
+
+ int64 num_dynamic_loop_bounds_ = 0;
+ // Encapsulated llvm::Function.
+ llvm::Function* function_;
+ // Function argument IR values.
+ llvm::Argument* result_arg_;
+ llvm::Value* exec_run_options_arg_;
+ llvm::Argument* parameters_arg_;
+ llvm::Value* temp_buffers_arg_;
+ llvm::Argument* dynamic_loop_bounds_arg_ = nullptr;
+ llvm::Argument* profile_counters_arg_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_IR_FUNCTION_H_
diff --git a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
index 3f2d101959..69466fd32e 100644
--- a/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
+++ b/tensorflow/compiler/xla/service/cpu/layout_assignment.cc
@@ -52,8 +52,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
tensorflow::gtl::FlatMap<const HloInstruction*, bool>
should_make_rhs_col_major_cache;
auto should_make_rhs_col_major = [&](const HloInstruction& instruction) {
- if (ProfitableToImplementDotInUntiledLlvmIr(instruction) !=
- DotInLlvmIrProfitable::kWithColumnMajorRhs) {
+ if (!ProfitableToMakeDotRhsColumnMajor(instruction)) {
return false;
}
@@ -69,8 +68,7 @@ Status CpuLayoutAssignment::AddBackendConstraints(
bool result = std::all_of(
rhs->users().begin(), rhs->users().end(), [&](HloInstruction* user) {
- return ProfitableToImplementDotInUntiledLlvmIr(*user) ==
- DotInLlvmIrProfitable::kWithColumnMajorRhs &&
+ return ProfitableToMakeDotRhsColumnMajor(*user) &&
user->operand(0) != rhs;
});
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
new file mode 100644
index 0000000000..91e704e3d0
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.cc
@@ -0,0 +1,76 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h"
+
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
+
+namespace xla {
+namespace cpu {
+
+ParallelLoopEmitter::ParallelLoopEmitter(
+ const llvm_ir::ElementGenerator& target_element_generator,
+ const llvm_ir::IrArray& target_array, const LoopBounds* dynamic_loop_bounds,
+ llvm::IRBuilder<>* ir_builder)
+ : LoopEmitter(target_element_generator, target_array, ir_builder),
+ dynamic_loop_bounds_(dynamic_loop_bounds) {}
+
+llvm_ir::IrArray::Index ParallelLoopEmitter::EmitIndexAndSetExitBasicBlock(
+ tensorflow::StringPiece loop_name) {
+ CHECK(!ShapeUtil::IsTuple(shape_));
+ CHECK(!ShapeUtil::IsScalar(shape_));
+
+ llvm_ir::ForLoopNest loop_nest(loop_name, ir_builder_);
+ const int64 num_dims = shape_.dimensions_size();
+ llvm_ir::IrArray::Index array_index(num_dims);
+
+ // Add loops from outer-most to inner-most dimensions.
+ for (int i = shape_.layout().minor_to_major_size() - 1; i >= 0; --i) {
+ const int64 dimension = shape_.layout().minor_to_major(i);
+ const int bounds_index = num_dims - 1 - i;
+ if (bounds_index < dynamic_loop_bounds_->size()) {
+ // Emit dynamic loop bounds for this dimension. Dynamic loop bounds
+ // are read from ir function dynamic loop bounds argument.
+ llvm::Value* start_index = (*dynamic_loop_bounds_)[bounds_index].first;
+ llvm::Value* end_index = (*dynamic_loop_bounds_)[bounds_index].second;
+
+ std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
+ /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension),
+ start_index, end_index);
+ array_index[dimension] = loop->GetIndVarValue();
+ } else {
+ // Emit static loop bounds for this dimension.
+ std::unique_ptr<llvm_ir::ForLoop> loop = loop_nest.AddLoop(
+ /*start_index=*/0,
+ /*end_index=*/shape_.dimensions(dimension),
+ /*suffix=*/tensorflow::strings::Printf("dim.%lld", dimension));
+ array_index[dimension] = loop->GetIndVarValue();
+ }
+ }
+ // Point IR builder at inner loop BB.
+ llvm_ir::SetToFirstInsertPoint(loop_nest.GetInnerLoopBodyBasicBlock(),
+ ir_builder_);
+
+ // Set exit_bb_ to the exit block of the loop nest.
+ exit_bb_ = loop_nest.GetOuterLoopExitBasicBlock();
+ CHECK(exit_bb_ != nullptr);
+
+ return array_index;
+}
+
+} // namespace cpu
+} // namespace xla
diff --git a/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
new file mode 100644
index 0000000000..492d5953c4
--- /dev/null
+++ b/tensorflow/compiler/xla/service/cpu/parallel_loop_emitter.h
@@ -0,0 +1,75 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
+
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/ir_array.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/loop_emitter.h"
+
+namespace xla {
+namespace cpu {
+
+// ParallelLoopEmitter emits a loop nest for the target array shape.
+// The outer loop bounds of the loop nest are passed as ir values at runtime
+// (specified in 'dynamic_loop_bounds'), and the inner loop bounds are static.
+// Dynamic loop bounds are specified as an array of dimension index
+// [start, limit) pairs of ir values (one for each partitioned outer dimension).
+//
+// EX: Let 'shape' = [8, 16, 32], with the loop bounds of the two most-major
+// dimensions dynamic.
+// Then 'dynamic_loop_bounds' will contain the following ir values for
+// the two most-major dimensions:
+// [dim0_index_start_ir_value, dim0_index_limit_ir_value]
+// [dim1_index_start_ir_value, dim1_index_limit_ir_value]
+//
+// Code emitted by ParallelLoopEmitter will be called in a multi-threaded
+// context where each thread will be assigned a different set of outer dimension
+// partitions, and where all threads will collectively iterate over the
+// entire target array shape.
+//
+// Outer dimension partitions can be generated using the ShapePartitionAssigner
+// and ShapePartitionIterator utility classes from shape_partition.cc.
+//
+class ParallelLoopEmitter : public llvm_ir::LoopEmitter {
+ public:
+ using LoopBounds = std::vector<std::pair<llvm::Value*, llvm::Value*>>;
+
+ // Constructs a ParallelLoopEmitter which uses 'target_element_generator' to
+ // generate elements, 'dynamic_loop_bounds' to set the loop bounds of the
+// most-major dimensions, and the shape of 'target_array' to set the static loop
+ // bounds for the most-minor dimensions.
+ ParallelLoopEmitter(const llvm_ir::ElementGenerator& target_element_generator,
+ const llvm_ir::IrArray& target_array,
+ const LoopBounds* dynamic_loop_bounds,
+ llvm::IRBuilder<>* ir_builder);
+
+ ParallelLoopEmitter(const ParallelLoopEmitter&) = delete;
+ ParallelLoopEmitter& operator=(const ParallelLoopEmitter&) = delete;
+ ~ParallelLoopEmitter() override = default;
+
+ llvm_ir::IrArray::Index EmitIndexAndSetExitBasicBlock(
+ tensorflow::StringPiece loop_name) override;
+
+ private:
+ const LoopBounds* dynamic_loop_bounds_;
+};
+
+} // namespace cpu
+} // namespace xla
+
+#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_LOOP_EMITTER_H_
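Carrying the header's example one step further: if the most-major dimension (size 8) is split across two threads, thread 0 receives [0, 4) and thread 1 receives [4, 8). The sketch below shows a hypothetical even split; in the actual backend this assignment comes from the ShapePartitionAssigner/ShapePartitionIterator utilities in shape_partition.cc:

    #include <algorithm>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Hypothetical even split of one outer dimension across worker threads.
    std::vector<std::pair<int64_t, int64_t>> SplitOuterDim(int64_t size,
                                                           int64_t num_threads) {
      std::vector<std::pair<int64_t, int64_t>> bounds;
      const int64_t chunk = (size + num_threads - 1) / num_threads;
      for (int64_t start = 0; start < size; start += chunk) {
        bounds.push_back({start, std::min(start + chunk, size)});
      }
      return bounds;  // size=8, num_threads=2 -> {[0, 4), [4, 8)}.
    }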
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
index 828ae675d7..f198c4c08e 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding.cc
@@ -55,19 +55,7 @@ MatchBackwardFilter(HloInstruction* conv) {
// v v
// Convolution
// conv
- // |
- // v
- // Transpose (optional if identity transposition)
CHECK_EQ(HloOpcode::kConvolution, conv->opcode());
- // If the forward convolution is followed by a transpose, we can fuse the
- // transpose into the backward convolution as well.
- HloInstruction* transpose = nullptr;
- if (conv->user_count() == 1) {
- HloInstruction* single_user = *conv->users().begin();
- if (single_user->opcode() == HloOpcode::kTranspose) {
- transpose = single_user;
- }
- }
// Step 2: match paddings and dimension numbers of the forward convolution.
const ConvolutionDimensionNumbers& conv_dnums =
@@ -75,6 +63,9 @@ MatchBackwardFilter(HloInstruction* conv) {
auto input_batch_dim = conv_dnums.input_batch_dimension();
auto input_feature_dim = conv_dnums.input_feature_dimension();
auto input_spatial_dims = conv_dnums.input_spatial_dimensions();
+ auto kernel_input_feature_dim = conv_dnums.kernel_input_feature_dimension();
+ auto kernel_output_feature_dim = conv_dnums.kernel_output_feature_dimension();
+ auto kernel_spatial_dims = conv_dnums.kernel_spatial_dimensions();
auto output_batch_dim = conv_dnums.output_batch_dimension();
auto output_feature_dim = conv_dnums.output_feature_dimension();
auto output_spatial_dims = conv_dnums.output_spatial_dimensions();
@@ -98,7 +89,8 @@ MatchBackwardFilter(HloInstruction* conv) {
}
// Padding high will be checked in Step 3.
}
- if (transpose == nullptr && !window_util::HasWindowDilation(conv->window())) {
+ if (input_batch_dim == output_batch_dim &&
+ !window_util::HasWindowDilation(conv->window())) {
VLOG(1) << conv->ToString()
<< " is a regular forward convolution. No need "
"to fold it to a backward filter convolution.";
@@ -169,53 +161,32 @@ MatchBackwardFilter(HloInstruction* conv) {
}
}
- // To make future HLO passes easier, we canonicalize the fused expression by
- // adding an identity transposition if it's omitted in the pattern.
- if (transpose == nullptr) {
- // Create an identity transposition with the same rank as the forward
- // convolution.
- HloComputation* parent_computation = conv->parent();
- std::vector<int64> transpose_dimensions(ShapeUtil::Rank(conv->shape()));
- std::iota(transpose_dimensions.begin(), transpose_dimensions.end(), 0);
- transpose =
- parent_computation->AddInstruction(HloInstruction::CreateTranspose(
- conv->shape(), conv, transpose_dimensions));
- TF_CHECK_OK(conv->ReplaceAllUsesWith(transpose));
- }
-
// Restore the dimension numbers of the backward convolution from the forward
// convolution. The two activation dimensions are reversed (batch and
// feature).
ConvolutionDimensionNumbers backward_conv_dnums;
backward_conv_dnums.set_input_batch_dimension(input_feature_dim);
backward_conv_dnums.set_input_feature_dimension(input_batch_dim);
- backward_conv_dnums.set_output_batch_dimension(output_feature_dim);
- backward_conv_dnums.set_output_feature_dimension(output_batch_dim);
for (int i = 0; i < input_spatial_dims.size(); ++i) {
backward_conv_dnums.add_input_spatial_dimensions(input_spatial_dims[i]);
}
- for (int i = 0; i < output_spatial_dims.size(); ++i) {
- backward_conv_dnums.add_output_spatial_dimensions(output_spatial_dims[i]);
+ backward_conv_dnums.set_output_batch_dimension(kernel_input_feature_dim);
+ backward_conv_dnums.set_output_feature_dimension(kernel_output_feature_dim);
+ for (int i = 0; i < kernel_spatial_dims.size(); ++i) {
+ backward_conv_dnums.add_output_spatial_dimensions(kernel_spatial_dims[i]);
}
// The dimension numbering of the output of the forward convolution (before
// transposition) is the same as that of the activations (according to the
// semantics of kConvolution). The batch dimension of the activations should
// be treated as the input feature dimension, and the feature dimension should
// be treated as the output feature.
- //
- // The output of the forward convolution needs to be transposed to fit into
- // the dimension numbering of the weight gradients. This transposition maps
- // dimension i to PositionInContainer(transpose->dimensions(), i).
- backward_conv_dnums.set_kernel_input_feature_dimension(
- PositionInContainer(transpose->dimensions(), output_batch_dim));
- backward_conv_dnums.set_kernel_output_feature_dimension(
- PositionInContainer(transpose->dimensions(), output_feature_dim));
+ backward_conv_dnums.set_kernel_input_feature_dimension(output_batch_dim);
+ backward_conv_dnums.set_kernel_output_feature_dimension(output_feature_dim);
for (int i = 0; i < output_spatial_dims.size(); ++i) {
- backward_conv_dnums.add_kernel_spatial_dimensions(
- PositionInContainer(transpose->dimensions(), output_spatial_dims[i]));
+ backward_conv_dnums.add_kernel_spatial_dimensions(output_spatial_dims[i]);
}
- return std::make_tuple(true, std::vector<HloInstruction*>({transpose, conv}),
+ return std::make_tuple(true, std::vector<HloInstruction*>({conv}),
backward_conv_window, backward_conv_dnums);
}
diff --git a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
index 112c496e1f..34e6bdb117 100644
--- a/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
+++ b/tensorflow/compiler/xla/service/gpu/convolution_folding_test.cc
@@ -46,18 +46,18 @@ class ConvolutionFoldingTest : public HloTestBase {
//
// TODO(jingyue): Add more tests on NCHW input order which TF also supports.
tf_default_dnums_for_backward_filter_.set_input_batch_dimension(3);
- tf_default_dnums_for_backward_filter_.set_output_batch_dimension(3);
tf_default_dnums_for_backward_filter_.set_input_feature_dimension(0);
- tf_default_dnums_for_backward_filter_.set_output_feature_dimension(0);
tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(1);
- tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
tf_default_dnums_for_backward_filter_.add_input_spatial_dimensions(2);
- tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(2);
tf_default_dnums_for_backward_filter_.set_kernel_input_feature_dimension(0);
tf_default_dnums_for_backward_filter_.set_kernel_output_feature_dimension(
3);
tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(1);
tf_default_dnums_for_backward_filter_.add_kernel_spatial_dimensions(2);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(0);
+ tf_default_dnums_for_backward_filter_.add_output_spatial_dimensions(1);
+ tf_default_dnums_for_backward_filter_.set_output_batch_dimension(2);
+ tf_default_dnums_for_backward_filter_.set_output_feature_dimension(3);
tf_default_dnums_for_backward_input_.set_input_batch_dimension(0);
tf_default_dnums_for_backward_input_.set_output_batch_dimension(0);
@@ -86,7 +86,7 @@ class ConvolutionFoldingTest : public HloTestBase {
ConvolutionDimensionNumbers tf_default_dnums_for_backward_input_;
};
-TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithoutTranspose) {
+TEST_F(ConvolutionFoldingTest, BackwardFilterConvolve) {
HloComputation::Builder builder(TestName());
HloInstruction* activations =
builder.AddInstruction(HloInstruction::CreateParameter(
@@ -136,7 +136,7 @@ TEST_F(ConvolutionFoldingTest,
auto module = CreateNewModule();
module->AddEntryComputation(builder.Build());
- EXPECT_FALSE(FoldConvolution(module.get()));
+ EXPECT_TRUE(FoldConvolution(module.get()));
}
// Extracted from block35 training.
@@ -155,13 +155,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedActivations) {
conv_window.mutable_dimensions(i)->set_padding_low(1);
conv_window.mutable_dimensions(i)->set_padding_high(1);
}
- HloInstruction* convolution =
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
-
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 3, 32, 32}), convolution, {1, 2, 3, 0}));
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 3, 3, 32}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -189,13 +185,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithPaddedGradients) {
conv_window.mutable_dimensions(i)->set_padding_high(-1);
conv_window.mutable_dimensions(i)->set_window_dilation(2);
}
- HloInstruction* convolution =
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
-
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {3, 3, 192, 320}), convolution, {1, 2, 3, 0}));
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {320, 3, 3, 192}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
auto module = CreateNewModule();
HloComputation* entry_computation =
@@ -222,13 +214,9 @@ TEST_F(ConvolutionFoldingTest, BackwardFilterConvolveWithUnevenPadding) {
// Uneven padding: padding_low=0, padding_high=1
conv_window.mutable_dimensions(i)->set_padding_high(1);
}
- HloInstruction* convolution =
- builder.AddInstruction(HloInstruction::CreateConvolve(
- ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
- conv_window, tf_default_dnums_for_backward_filter_));
-
- builder.AddInstruction(HloInstruction::CreateTranspose(
- ShapeUtil::MakeShape(F32, {2, 2, 32, 32}), convolution, {1, 2, 3, 0}));
+ builder.AddInstruction(HloInstruction::CreateConvolve(
+ ShapeUtil::MakeShape(F32, {32, 2, 2, 32}), activations, gradients,
+ conv_window, tf_default_dnums_for_backward_filter_));
auto module = CreateNewModule();
HloComputation* entry_computation =
diff --git a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
index 1b863c9e3c..abc739d181 100644
--- a/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
+++ b/tensorflow/compiler/xla/service/gpu/ir_emitter_unnested.cc
@@ -246,6 +246,11 @@ Status IrEmitterUnnested::DefaultAction(HloInstruction* hlo) {
}
Status IrEmitterUnnested::HandleDot(HloInstruction* dot) {
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
+ if (dnums.lhs_batch_dimensions_size() > 0 ||
+ dnums.rhs_batch_dimensions_size() > 0) {
+ return Unimplemented("Dot with batch dimensions not implemented.");
+ }
if (ImplementedAsGemm(*dot)) {
thunk_sequence_->emplace_back(BuildGemmThunk(dot));
return Status::OK();
diff --git a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
index 11290eda4f..c29fee0879 100644
--- a/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
+++ b/tensorflow/compiler/xla/service/gpu/pad_insertion.cc
@@ -202,8 +202,7 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
// ABCD0 = Pad(ABCD, padding_high=1)
// BackwardFilterConv(ABCD0, xyz, padding_low=pading_high=1)
// We choose the lesser of padding_low and padding_high as the new padding.
- HloInstruction* transpose = backward_conv->fused_expression_root();
- HloInstruction* forward_conv = transpose->mutable_operand(0);
+ HloInstruction* forward_conv = backward_conv->fused_expression_root();
HloInstruction* input = backward_conv->mutable_operand(0);
Window new_forward_conv_window = forward_conv->window();
Window new_backward_conv_window = backward_conv->window();
@@ -269,19 +268,10 @@ bool PadInsertion::CanonicalizeBackwardFilterConvolution(
.ConsumeValueOrDie(),
padded_input, output, new_forward_conv_window, forward_conv_dnums));
- HloInstruction* new_transpose =
- computation->AddInstruction(HloInstruction::CreateTranspose(
- ShapeInference::InferTransposeShape(new_forward_conv->shape(),
- transpose->dimensions())
- .ConsumeValueOrDie(),
- new_forward_conv, transpose->dimensions()));
-
- // Fuse the new forward convolution and the new transpose to the new backward
- // convolution.
+ // Fuse the new forward convolution to the new backward convolution.
HloInstruction* new_backward_conv =
computation->CreateFusionInstructionForBackwardConvolution(
- {new_transpose, new_forward_conv},
- HloInstruction::FusionKind::kConvBackwardFilter,
+ {new_forward_conv}, HloInstruction::FusionKind::kConvBackwardFilter,
new_backward_conv_window, backward_conv_dnums);
VLOG(1) << "Canonicalizing backward filter conv";
diff --git a/tensorflow/compiler/xla/service/graphviz_example.cc b/tensorflow/compiler/xla/service/graphviz_example.cc
index 049e8d80d8..05017008e2 100644
--- a/tensorflow/compiler/xla/service/graphviz_example.cc
+++ b/tensorflow/compiler/xla/service/graphviz_example.cc
@@ -108,8 +108,11 @@ std::unique_ptr<HloModule> MakeBigGraph() {
HloInstruction::CreateUnary(vshape, HloOpcode::kCopy, param_v0));
auto clamp = builder.AddInstruction(HloInstruction::CreateTernary(
vshape, HloOpcode::kClamp, copy, param_v1, param_v2));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(vshape, HloOpcode::kDot, clamp, param_v0));
+ HloInstruction::CreateDot(vshape, clamp, param_v0, dot_dnums));
auto tuple = builder.AddInstruction(
HloInstruction::CreateTuple({dot, param_s, clamp}));
auto scalar = builder.AddInstruction(
diff --git a/tensorflow/compiler/xla/service/heap_simulator_test.cc b/tensorflow/compiler/xla/service/heap_simulator_test.cc
index 17b926c874..387b649a73 100644
--- a/tensorflow/compiler/xla/service/heap_simulator_test.cc
+++ b/tensorflow/compiler/xla/service/heap_simulator_test.cc
@@ -259,8 +259,11 @@ TEST_F(HeapSimulatorTest, MultiplyDot) {
HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+ HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
// The buffer for dot is the output, and it cannot be shared with the buffer
// for mul, since dot isn't elementwise.
@@ -292,8 +295,11 @@ TEST_F(HeapSimulatorTest, MultiplyDotAdd) {
HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+ HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
auto add = builder.AddInstruction(
HloInstruction::CreateBinary(f32vec4_, HloOpcode::kAdd, dot, paramA));
@@ -327,10 +333,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDot) {
HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot0 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+ HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
auto dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY));
+ HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
// The buffer for dot1 is the output. No buffers can be shared. The buffer
// for mul is freed before the end, since it's no longer used after dot0
@@ -365,10 +374,13 @@ TEST_F(HeapSimulatorTest, MultiplyDotDotTuple) {
HloInstruction::CreateParameter(2, f32scalar_, "paramY"));
auto mul = builder.AddInstruction(HloInstruction::CreateBinary(
f32vec4_, HloOpcode::kMultiply, paramA, paramX));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot0 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, mul, paramY));
+ HloInstruction::CreateDot(f32vec4_, mul, paramY, dot_dnums));
auto dot1 = builder.AddInstruction(
- HloInstruction::CreateBinary(f32vec4_, HloOpcode::kDot, dot0, paramY));
+ HloInstruction::CreateDot(f32vec4_, dot0, paramY, dot_dnums));
auto tuple =
builder.AddInstruction(HloInstruction::CreateTuple({dot0, dot1}));
diff --git a/tensorflow/compiler/xla/service/hlo.proto b/tensorflow/compiler/xla/service/hlo.proto
index e984bdb5f7..5d0cfba1fc 100644
--- a/tensorflow/compiler/xla/service/hlo.proto
+++ b/tensorflow/compiler/xla/service/hlo.proto
@@ -118,6 +118,9 @@ message HloInstructionProto {
// Shape of outfeed request.
xla.Shape outfeed_shape = 29;
+
+ // Describes the dimension numbers used for a dot operation
+ xla.DotDimensionNumbers dot_dimension_numbers = 30;
}
// Serialization of HloComputation.
diff --git a/tensorflow/compiler/xla/service/hlo_computation.cc b/tensorflow/compiler/xla/service/hlo_computation.cc
index c215cc48d6..014a851c96 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.cc
+++ b/tensorflow/compiler/xla/service/hlo_computation.cc
@@ -176,10 +176,6 @@ bool HloComputation::IsRemovable(const HloInstruction* instruction) {
return false;
}
- if (instruction->HasSideEffect()) {
- return false;
- }
-
return true;
}
@@ -207,7 +203,8 @@ Status HloComputation::RemoveInstructionAndUnusedOperands(
worklist.pop();
if (removed.count(item) != 0 || item->user_count() != 0 ||
- item == root_instruction() || !IsRemovable(item)) {
+ item == root_instruction() || !IsRemovable(item) ||
+ item->HasSideEffect()) {
continue;
}
for (int i = 0; i < item->operand_count(); ++i) {
diff --git a/tensorflow/compiler/xla/service/hlo_computation.h b/tensorflow/compiler/xla/service/hlo_computation.h
index 353b30bc69..ccedda2a03 100644
--- a/tensorflow/compiler/xla/service/hlo_computation.h
+++ b/tensorflow/compiler/xla/service/hlo_computation.h
@@ -313,11 +313,17 @@ class HloComputation {
replacements,
HloModule* module = nullptr, const string& suffix = "clone");
- // Returns true if the given instruction can be removed from the
- // computation. Instructions such as parameters and send/receive instructions
- // cannot be removed without violating invariants of the HLO computation or
- // module with the exception of fusion computation. A parameter instruction
- // is removable for a fusion computation.
+ // Returns true if the given instruction can be removed from the computation.
+ // Parameter instructions cannot be removed without violating invariants of
+ // the HLO computation with the exception of fusion computation. A parameter
+ // instruction is removable for a fusion computation.
+ //
+ // Note that IsRemovable() is a necessary condition for removing an
+ // instruction rather than a sufficient one. For example, instructions with
+ // side effects (e.g., Send, Infeed) may be removed from a computation, but the
+ // transformation must guarantee the invariants relevant to the instructions
+ // still hold (e.g., Send and Recv must be removed together to make each
+ // channel complete).
bool IsRemovable(const HloInstruction* instruction);
// Returns true if this computation has a side effect. A computation has a
diff --git a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
index 6fcc01dd64..0ed64e6779 100644
--- a/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
+++ b/tensorflow/compiler/xla/service/hlo_cost_analysis.cc
@@ -201,10 +201,11 @@ Status HloCostAnalysis::HandleCopy(const HloInstruction*) {
Status HloCostAnalysis::HandleDot(const HloInstruction* dot) {
const Shape& lhs_shape = dot->operand(0)->shape();
const Shape& rhs_shape = dot->operand(1)->shape();
+ const DotDimensionNumbers& dnums = dot->dot_dimension_numbers();
// Count of elements along the reduction dimension (last dimension for the
// rhs).
- int64 reduction_width = lhs_shape.dimensions(ShapeUtil::Rank(lhs_shape) - 1);
-
+ int64 reduction_width =
+ lhs_shape.dimensions(dnums.lhs_contracting_dimensions(0));
// First divide by reduction width before multiplying by rhs elements to avoid
// overflow.
int64 fma_count;
diff --git a/tensorflow/compiler/xla/service/hlo_dce.cc b/tensorflow/compiler/xla/service/hlo_dce.cc
index 40e67c8780..1e5f0f797a 100644
--- a/tensorflow/compiler/xla/service/hlo_dce.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce.cc
@@ -55,7 +55,8 @@ StatusOr<bool> HloDCE::Run(HloModule* module) {
for (auto* instruction : computation->instructions()) {
if (instruction->user_count() == 0 &&
live_instructions.count(instruction) == 0 &&
- computation->IsRemovable(instruction)) {
+ computation->IsRemovable(instruction) &&
+ !instruction->HasSideEffect()) {
dead_roots.push_back(instruction);
}
}
diff --git a/tensorflow/compiler/xla/service/hlo_dce_test.cc b/tensorflow/compiler/xla/service/hlo_dce_test.cc
index d54b9a2708..5a56607a66 100644
--- a/tensorflow/compiler/xla/service/hlo_dce_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_dce_test.cc
@@ -70,6 +70,26 @@ TEST_F(HloDceTest, NoDeadCode) {
EXPECT_EQ(3, computation->instruction_count());
}
+TEST_F(HloDceTest, InstructionsWithSideEffect) {
+ // Verify that side-effect instructions (Send in this test) are not removed.
+ auto builder = HloComputation::Builder(TestName());
+ auto constant = builder.AddInstruction(
+ HloInstruction::CreateConstant(Literal::CreateR0<float>(42.0f)));
+ builder.AddInstruction(
+ HloInstruction::CreateSend(constant, /*channel_id=*/0));
+ builder.AddInstruction(HloInstruction::CreateTuple({}));
+
+ auto module = CreateNewModule();
+ auto computation = module->AddEntryComputation(builder.Build());
+
+ EXPECT_EQ(3, computation->instruction_count());
+
+ HloDCE dce;
+ EXPECT_FALSE(dce.Run(module.get()).ValueOrDie());
+
+ EXPECT_EQ(3, computation->instruction_count());
+}
+
TEST_F(HloDceTest, DeadParameters) {
// Verify that dead parameters are not removed, but use of the dead parameters
// are.
diff --git a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
index b2c4351896..a5d39fe086 100644
--- a/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_evaluator_test.cc
@@ -621,8 +621,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank1) {
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
- b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
+ rhs_instruction, dot_dnums));
auto computation = module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
@@ -664,8 +667,11 @@ TEST_F(HloEvaluatorTest, DotRank1AndRank2) {
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
Shape shape = ShapeUtil::MakeShape(F32, {2});
- b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
+ rhs_instruction, dot_dnums));
auto computation = module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
@@ -705,8 +711,11 @@ TEST_F(HloEvaluatorTest, DotRank2AndRank2) {
b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
Shape shape = ShapeUtil::MakeShape(F32, {4, 2});
- b.AddInstruction(HloInstruction::CreateBinary(
- shape, HloOpcode::kDot, lhs_instruction, rhs_instruction));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ b.AddInstruction(HloInstruction::CreateDot(shape, lhs_instruction,
+ rhs_instruction, dot_dnums));
auto computation = module().AddEntryComputation(b.Build());
std::unique_ptr<Literal> result =
diff --git a/tensorflow/compiler/xla/service/hlo_execution_profile.cc b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
index ba75e2ef1b..0809fe780d 100644
--- a/tensorflow/compiler/xla/service/hlo_execution_profile.cc
+++ b/tensorflow/compiler/xla/service/hlo_execution_profile.cc
@@ -109,7 +109,8 @@ std::unique_ptr<HloProfilePrinter> CreateHloProfilePrinter(
};
return MakeUnique<HloProfilePrinter>(
- computation_infos, hlo_profile_index_map.computation_count(), deleter);
+ computation_infos, hlo_profile_index_map.computation_count(),
+ /*profile_counters_size=*/max_profile_index, deleter);
}
HloExecutionProfile::HloExecutionProfile(
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.cc b/tensorflow/compiler/xla/service/hlo_instruction.cc
index c30c432654..b4bac18bcd 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction.cc
@@ -118,6 +118,10 @@ StatusOr<std::unique_ptr<HloInstruction>> HloInstruction::CreateFromProto(
MakeUnique<ConvolutionDimensionNumbers>(
proto.convolution_dimension_numbers());
}
+ if (proto.has_dot_dimension_numbers()) {
+ instruction->dot_dimension_numbers_ =
+ MakeUnique<DotDimensionNumbers>(proto.dot_dimension_numbers());
+ }
for (const HloInstructionProto::SliceDimensions& slice_dimensions :
proto.slice_dimensions()) {
instruction->slice_starts_.push_back(slice_dimensions.start());
@@ -332,6 +336,17 @@ HloInstruction::CreateGetTupleElement(const Shape& shape,
return instruction;
}
+/* static */ std::unique_ptr<HloInstruction> HloInstruction::CreateDot(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ auto instruction = WrapUnique(new HloInstruction(HloOpcode::kDot, shape));
+ instruction->AppendOperand(lhs);
+ instruction->AppendOperand(rhs);
+ instruction->dot_dimension_numbers_ =
+ MakeUnique<DotDimensionNumbers>(dimension_numbers);
+ return instruction;
+}
+
/* static */ std::unique_ptr<HloInstruction>
HloInstruction::CreateReducePrecision(const Shape& shape,
HloInstruction* operand,
@@ -1086,7 +1101,6 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
case HloOpcode::kLe:
case HloOpcode::kLt:
case HloOpcode::kNe:
- case HloOpcode::kDot:
case HloOpcode::kMaximum:
case HloOpcode::kMinimum:
case HloOpcode::kPower:
@@ -1138,6 +1152,11 @@ std::unique_ptr<HloInstruction> HloInstruction::CloneWithNewOperands(
clone = CreateConvolve(shape, new_operands[0], new_operands[1], *window_,
*convolution_dimension_numbers_);
break;
+ case HloOpcode::kDot:
+ CHECK_EQ(new_operands.size(), 2);
+ clone = CreateDot(shape, new_operands[0], new_operands[1],
+ *dot_dimension_numbers_);
+ break;
case HloOpcode::kCrossReplicaSum:
CHECK_EQ(new_operands.size(), 1);
clone = CreateCrossReplicaSum(shape, new_operands[0]);
@@ -1509,7 +1528,6 @@ bool HloInstruction::IdenticalSlowPath(
case HloOpcode::kCos:
case HloOpcode::kCrossReplicaSum:
case HloOpcode::kDivide:
- case HloOpcode::kDot:
case HloOpcode::kEq:
case HloOpcode::kExp:
case HloOpcode::kFloor:
@@ -1582,6 +1600,10 @@ bool HloInstruction::IdenticalSlowPath(
protobuf_util::ProtobufEquals(
convolution_dimension_numbers(),
other.convolution_dimension_numbers());
+ // Check dot dimension numbers.
+ case HloOpcode::kDot:
+ return protobuf_util::ProtobufEquals(dot_dimension_numbers(),
+ other.dot_dimension_numbers());
// Reduction results are determined by the reduction dimension and the
// reduction computation.
@@ -1990,6 +2012,9 @@ std::vector<string> HloInstruction::ExtraAttributesToString() const {
if (convolution_dimension_numbers_ != nullptr) {
extra.push_back(ConvolutionDimensionNumbersToString());
}
+ if (dot_dimension_numbers_ != nullptr) {
+ extra.push_back(DotDimensionNumbersToString());
+ }
if (opcode() == HloOpcode::kWhile) {
extra.push_back(StrCat("condition=%", while_condition()->name()));
@@ -2086,6 +2111,9 @@ HloInstructionProto HloInstruction::ToProto() const {
*proto.mutable_convolution_dimension_numbers() =
*convolution_dimension_numbers_;
}
+ if (dot_dimension_numbers_ != nullptr) {
+ *proto.mutable_dot_dimension_numbers() = *dot_dimension_numbers_;
+ }
for (int i = 0; i < slice_starts_.size(); ++i) {
auto* slice_dimension = proto.add_slice_dimensions();
slice_dimension->set_start(slice_starts_[i]);
@@ -3051,6 +3079,30 @@ string HloInstruction::ConvolutionDimensionNumbersToString() const {
return result;
}
+string HloInstruction::DotDimensionNumbersToString() const {
+ string result;
+ if (dot_dimension_numbers_ == nullptr) {
+ return result;
+ }
+ const DotDimensionNumbers& dnums = *dot_dimension_numbers_;
+ if (!dnums.lhs_batch_dimensions().empty()) {
+ result += "lhs_batch_dims=";
+ StrAppend(&result, Join(dnums.lhs_batch_dimensions(), ","));
+ }
+ result += "lhs_contracting_dims=";
+ StrAppend(&result, Join(dnums.lhs_contracting_dimensions(), ","));
+
+ result += ",";
+ if (!dnums.rhs_batch_dimensions().empty()) {
+ result += "rhs_batch_dims=";
+ StrAppend(&result, Join(dnums.rhs_batch_dimensions(), ","));
+ }
+ result += "rhs_contracting_dims=";
+ StrAppend(&result, Join(dnums.rhs_contracting_dimensions(), ","));
+
+ return result;
+}
+
bool HloInstruction::CouldBeBitcast() const {
switch (opcode_) {
case HloOpcode::kTranspose:
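For the canonical rank-2 matmul dimension numbers (contract lhs dimension 1 against rhs dimension 0, no batch dimensions), DotDimensionNumbersToString() yields "lhs_contracting_dims=1,rhs_contracting_dims=0", which is exactly what the updated Stringification test in hlo_instruction_test.cc below expects. A minimal sketch, assuming DotDimensionNumbers comes from xla_data.pb.h:

    #include "tensorflow/compiler/xla/xla_data.pb.h"

    // Canonical rank-2 matmul dimension numbers.
    xla::DotDimensionNumbers MatmulDnums() {
      xla::DotDimensionNumbers dnums;
      dnums.add_lhs_contracting_dimensions(1);  // lhs columns
      dnums.add_rhs_contracting_dimensions(0);  // rhs rows
      return dnums;
    }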
diff --git a/tensorflow/compiler/xla/service/hlo_instruction.h b/tensorflow/compiler/xla/service/hlo_instruction.h
index cda8b07c61..768c027a42 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction.h
+++ b/tensorflow/compiler/xla/service/hlo_instruction.h
@@ -160,6 +160,12 @@ class HloInstruction {
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
+ // Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
+ // dimensions specified in 'dimension_numbers'.
+ static std::unique_ptr<HloInstruction> CreateDot(
+ const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
+ const DotDimensionNumbers& dimension_numbers);
+
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
// reduce it to.
@@ -915,6 +921,15 @@ class HloInstruction {
// Returns the dump string of the convolution dimension numbers.
string ConvolutionDimensionNumbersToString() const;
+ // Returns data on the dimension numbers used for a dot operation.
+ const DotDimensionNumbers& dot_dimension_numbers() const {
+ CHECK(dot_dimension_numbers_ != nullptr);
+ return *dot_dimension_numbers_;
+ }
+
+ // Returns the dump string of the dot dimension numbers.
+ string DotDimensionNumbersToString() const;
+
// Returns the random distribution for this rng node.
//
// Precondition: opcode() == HloOpcode::kRng
@@ -1173,6 +1188,9 @@ class HloInstruction {
// Describes the dimension numbers used for a convolution.
std::unique_ptr<ConvolutionDimensionNumbers> convolution_dimension_numbers_;
+ // Describes the dimension numbers used for a dot.
+ std::unique_ptr<DotDimensionNumbers> dot_dimension_numbers_;
+
// Describes the [begin, end) index range for a slice.
std::vector<int64> slice_starts_;
std::vector<int64> slice_limits_;
diff --git a/tensorflow/compiler/xla/service/hlo_instruction_test.cc b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
index 76b12fc8d3..11420cae63 100644
--- a/tensorflow/compiler/xla/service/hlo_instruction_test.cc
+++ b/tensorflow/compiler/xla/service/hlo_instruction_test.cc
@@ -1068,8 +1068,11 @@ TEST_F(HloInstructionTest, CloneOfFusionPreservesShape) {
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
@@ -1182,12 +1185,15 @@ TEST_F(HloInstructionTest, Stringification) {
builder.AddInstruction(HloInstruction::CreateParameter(1, s2, "y"));
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateTranspose(s2t, y, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(sout, HloOpcode::kDot, x, reshape));
+ HloInstruction::CreateDot(sout, x, reshape, dot_dnums));
EXPECT_EQ(dot->ToString(false, false),
"%dot = f32[5,20]{1,0} dot(f32[5,10]{1,0} %x, f32[10,20]{1,0} "
- "%transpose)");
+ "%transpose), lhs_contracting_dims=1,rhs_contracting_dims=0");
HloModule module(TestName());
auto* computation = module.AddEntryComputation(builder.Build());
diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc
index faaf73ea1c..6fe2134466 100644
--- a/tensorflow/compiler/xla/service/hlo_module.cc
+++ b/tensorflow/compiler/xla/service/hlo_module.cc
@@ -35,14 +35,15 @@ namespace xla {
HloModule::HloModule(const string& name,
const VersionedComputationHandle& entry_computation_handle,
const HloModuleConfig& config)
- : name_(name),
+ : name_(NameUniquer::GetSanitizedName(name)),
config_(config),
has_entry_computation_handle_(true),
entry_computation_handle_(entry_computation_handle) {}
-HloModule::HloModule(const string& name) : name_(name) {}
+HloModule::HloModule(const string& name)
+ : name_(NameUniquer::GetSanitizedName(name)) {}
HloModule::HloModule(const string& name, const HloModuleConfig& config)
- : name_(name), config_(config) {}
+ : name_(NameUniquer::GetSanitizedName(name)), config_(config) {}
HloComputation* HloModule::AddComputationInternal(
std::unique_ptr<HloComputation> computation, bool is_entry,
diff --git a/tensorflow/compiler/xla/service/hlo_profile_printer.h b/tensorflow/compiler/xla/service/hlo_profile_printer.h
index 316753a82a..2f056490ae 100644
--- a/tensorflow/compiler/xla/service/hlo_profile_printer.h
+++ b/tensorflow/compiler/xla/service/hlo_profile_printer.h
@@ -65,9 +65,11 @@ class HloProfilePrinter {
HloProfilePrinter(
HloComputationInfo* computation_infos, int64 computation_infos_size,
+ int64 profile_counters_size,
std::function<void(HloComputationInfo*, int64)> deleter = nullptr)
: computation_infos_(computation_infos),
computation_infos_size_(computation_infos_size),
+ profile_counters_size_(profile_counters_size),
deleter_(std::move(deleter)) {}
HloProfilePrinter(HloProfilePrinter&& other) {
@@ -79,10 +81,13 @@ class HloProfilePrinter {
HloProfilePrinter(const HloProfilePrinter&) = delete;
HloProfilePrinter& operator=(const HloProfilePrinter&) = delete;
- // Convert the profile counter sequence `counters` to a human readable string
+ // Converts the profile counter sequence `counters` to a human readable string
// representation.
string ToString(const int64* counters, double clock_rate_ghz) const;
+ // Returns the size of the profile buffer expected by this printer.
+ int64 profile_counters_size() const { return profile_counters_size_; }
+
~HloProfilePrinter();
private:
@@ -90,6 +95,7 @@ class HloProfilePrinter {
// is manifested as the deleter_ function.
HloComputationInfo* computation_infos_ = nullptr;
int64 computation_infos_size_ = 0;
+ int64 profile_counters_size_ = 0;
std::function<void(HloComputationInfo*, int64)> deleter_;
};
} // namespace xla
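
A sketch of how a caller might use the new accessor to size the counter buffer before rendering (the surrounding setup is hypothetical; only profile_counters_size() and ToString() come from the class above):

    // Sketch only: allocate the buffer at the size this printer expects,
    // let the profiled run fill it, then render a human-readable report.
    std::vector<int64> counters(printer.profile_counters_size(), 0);
    // ... profiled execution increments entries of counters ...
    string report = printer.ToString(counters.data(), /*clock_rate_ghz=*/2.0);
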
diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
index 017f996bc4..d09de7b528 100644
--- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc
+++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc
@@ -566,7 +566,9 @@ Status MemoryUsageTracker::BeginInstruction(Item* item) {
VLOG(3) << " memory usage = " << memory_usage_;
VLOG(10) << ToString();
- DCHECK(Check());
+ if (VLOG_IS_ON(1)) {
+ DCHECK(Check());
+ }
return Status::OK();
}
@@ -603,8 +605,9 @@ Status MemoryUsageTracker::EndInstruction() {
VLOG(3) << " memory usage = " << memory_usage_;
VLOG(10) << ToString();
- DCHECK(Check());
-
+ if (VLOG_IS_ON(1)) {
+ DCHECK(Check());
+ }
return Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/hlo_sharding.cc b/tensorflow/compiler/xla/service/hlo_sharding.cc
index d1adec31c2..447c244666 100644
--- a/tensorflow/compiler/xla/service/hlo_sharding.cc
+++ b/tensorflow/compiler/xla/service/hlo_sharding.cc
@@ -246,7 +246,8 @@ Status HloSharding::ValidateNonTuple(const Shape& shape,
// The tile rank must be the same as the input rank.
if (ShapeUtil::Rank(shape) != ShapeUtil::Rank(tile_shape_)) {
return tensorflow::errors::InvalidArgument(
- "Tile rank is different to the input rank");
+ "Tile rank is different to the input rank. sharding=", ToString(),
+ ", input_shape=", ShapeUtil::HumanString(shape));
}
// The tile shape must not be the same as the input shape without maximal_
diff --git a/tensorflow/compiler/xla/service/hlo_verifier.cc b/tensorflow/compiler/xla/service/hlo_verifier.cc
index 15188c4057..ea7775b18a 100644
--- a/tensorflow/compiler/xla/service/hlo_verifier.cc
+++ b/tensorflow/compiler/xla/service/hlo_verifier.cc
@@ -75,7 +75,11 @@ class ShapeVerifier : public DfsHloVisitor {
}
Status HandleDot(HloInstruction* dot) override {
- return CheckBinaryShape(dot);
+ TF_ASSIGN_OR_RETURN(const Shape expected,
+ ShapeInference::InferDotOpShape(
+ dot->operand(0)->shape(), dot->operand(1)->shape(),
+ dot->dot_dimension_numbers()));
+ return CheckShape(dot, expected);
}
Status HandleConvolution(HloInstruction* convolution) override {
@@ -143,9 +147,13 @@ class ShapeVerifier : public DfsHloVisitor {
}
Status HandleBitcast(HloInstruction* bitcast) override {
- // Bitcasts can be any shape, as long as the size matches the operand size.
- TF_RET_CHECK(shape_size_fn_(bitcast->shape()) ==
- shape_size_fn_(bitcast->operand(0)->shape()));
+ // Bitcasts that are not the root of a computation can be any shape.
+ // Bitcasts that are the root of a computation must have the same shape
+ // byte size as their operand.
+ if (bitcast->parent()->root_instruction() == bitcast) {
+ TF_RET_CHECK(shape_size_fn_(bitcast->shape()) ==
+ shape_size_fn_(bitcast->operand(0)->shape()));
+ }
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/xla/service/liveness_util_test.cc b/tensorflow/compiler/xla/service/liveness_util_test.cc
index 476e86fa72..2c2a02f637 100644
--- a/tensorflow/compiler/xla/service/liveness_util_test.cc
+++ b/tensorflow/compiler/xla/service/liveness_util_test.cc
@@ -277,8 +277,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedDotAdd) {
auto b = builder.AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR2<float>({{2.0, 2.0}, {2.0, 2.0}})));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b));
+ HloInstruction::CreateDot(data_shape, a, b, dot_dnums));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
@@ -312,8 +315,11 @@ TEST_F(CanShareOperandBufferWithUserTest, FusedTransposeDotAdd) {
auto b_t = builder.AddInstruction(
HloInstruction::CreateTranspose(data_shape, b, {1, 0}));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto dot = builder.AddInstruction(
- HloInstruction::CreateBinary(data_shape, HloOpcode::kDot, a, b_t));
+ HloInstruction::CreateDot(data_shape, a, b_t, dot_dnums));
auto one = builder.AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<float>(1.0)));
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
index 29cc0f81bd..d951a37d5d 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.cc
@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
+#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
namespace xla {
void KernelSupportLibrary::For(
@@ -62,4 +63,47 @@ void KernelSupportLibrary::If(
false_block_generator();
llvm_ir::SetToLastInsertPoint(if_data.after_block, ir_builder_);
}
+
+void KernelSupportLibrary::EmitAndCallOutlinedKernel(
+ llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+ KernelSupportLibrary::ArgumentVector arguments,
+ const std::function<void(KernelSupportLibrary::ArgumentVector)>&
+ kernel_body_generator) {
+ llvm::Module* module = ir_builder->GetInsertBlock()->getModule();
+ llvm::Function* function =
+ module->getFunction(llvm_ir::AsStringRef(kernel_name));
+ if (!function) {
+ VLOG(2) << "Generating kernel for " << kernel_name;
+ std::vector<llvm::Type*> arg_types;
+ std::transform(arguments.begin(), arguments.end(),
+ std::back_inserter(arg_types),
+ [](llvm::Value* arg) { return arg->getType(); });
+
+ auto* function_type = llvm::FunctionType::get(
+ ir_builder->getVoidTy(), arg_types, /*isVarArg=*/false);
+
+ function = llvm::Function::Create(
+ function_type, llvm::GlobalValue::InternalLinkage,
+ llvm_ir::AsStringRef(kernel_name), module);
+
+ llvm::IRBuilder<>::InsertPointGuard guard(*ir_builder);
+
+ auto* entry_bb =
+ llvm::BasicBlock::Create(ir_builder->getContext(), "entry", function);
+ auto* return_inst = llvm::ReturnInst::Create(ir_builder->getContext(),
+ /*retVal=*/nullptr, entry_bb);
+ // Set the insert point to before return_inst.
+ ir_builder->SetInsertPoint(return_inst);
+
+ std::vector<llvm::Value*> arg_values;
+ std::transform(function->arg_begin(), function->arg_end(),
+ std::back_inserter(arg_values), std::addressof<llvm::Value>);
+ kernel_body_generator(arg_values);
+ } else {
+ VLOG(3) << "Re-using kernel for " << kernel_name;
+ }
+
+ ir_builder->CreateCall(function, llvm_ir::AsArrayRef(arguments));
+}
+
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
index 9bafb7b577..997b84bb27 100644
--- a/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
+++ b/tensorflow/compiler/xla/service/llvm_ir/kernel_support_library.h
@@ -118,6 +118,38 @@ class KernelSupportLibrary {
const std::function<void()>& true_block_generator,
const std::function<void()>& false_block_generator = []() {});
+ using ArgumentVector = tensorflow::gtl::ArraySlice<llvm::Value*>;
+
+ // Generates the following control flow structure:
+ //
+ // define @`kernel_name`(arg0, arg1, ... arg`arguments.size()`) {
+ // kernel_body_generator({arg0, arg1, ... arg`arguments.size()`});
+ // }
+ //
+ // ...
+ // call @`kernel_name`(arguments[0], arguments[1] ...)
+ // ...
+ //
+ // If a function called `kernel_name` is already present in the module then
+ // that function is re-used. In that sense we're using the llvm::Module as a
+ // cache of outlined kernels, keyed by function name.
+ static void EmitAndCallOutlinedKernel(
+ llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+ ArgumentVector arguments,
+ const std::function<void(ArgumentVector)>& kernel_body_generator);
+
+ // Thin wrapper around the more general EmitAndCallOutlinedKernel above.
+ static void EmitAndCallOutlinedKernel(
+ llvm::IRBuilder<>* ir_builder, tensorflow::StringPiece kernel_name,
+ llvm::Value* arg0, llvm::Value* arg1, llvm::Value* arg2,
+ const std::function<void(llvm::Value*, llvm::Value*, llvm::Value*)>&
+ kernel_body_generator) {
+ EmitAndCallOutlinedKernel(
+ ir_builder, kernel_name, {arg0, arg1, arg2}, [&](ArgumentVector args) {
+ kernel_body_generator(args[0], args[1], args[2]);
+ });
+ }
+
private:
llvm::IRBuilder<>* ir_builder_;
bool prevent_unrolling_;
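
A sketch of the three-argument wrapper in use (the kernel name and value names are hypothetical; per the .cc change above, the body generator runs with the builder positioned inside the outlined function, and a second call with the same name re-uses the emitted function):

    // Sketch only: outline a small body into "add_kernel" and call it.
    KernelSupportLibrary::EmitAndCallOutlinedKernel(
        ir_builder, "add_kernel", lhs_ptr, rhs_ptr, out_ptr,
        [&](llvm::Value* lhs, llvm::Value* rhs, llvm::Value* out) {
          // Emit the kernel body here; lhs/rhs/out are the outlined
          // function's arguments, not the values captured at the call site.
        });
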
diff --git a/tensorflow/compiler/xla/service/name_uniquer.cc b/tensorflow/compiler/xla/service/name_uniquer.cc
index a0d08c288d..7d8c05fffa 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer.cc
@@ -17,12 +17,44 @@ limitations under the License.
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/lib/strings/strcat.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
+namespace {
+
+bool IsAllowed(char character) {
+ auto c = static_cast<unsigned char>(character);
+ return (isalnum(c) != 0) || c == '_' || c == '.' || c == '-';
+}
+
+} // namespace
+
+NameUniquer::NameUniquer(const string& separator) {
+ CHECK(std::all_of(separator.begin(), separator.end(), IsAllowed))
+ << "separator should comprises allowed characters only";
+ separator_ = separator;
+}
+
+/*static*/ string NameUniquer::GetSanitizedName(const string& name) {
+ string result = name;
+ CHECK(!result.empty()) << "name should not be empty";
+ char c = static_cast<unsigned char>(result[0]);
+ if (!isalpha(c) && c != '_') {
+ result[0] = '_';
+ }
+ for (int i = 1; i < result.length(); i++) {
+ if (!IsAllowed(result[i])) {
+ result[i] = '_';
+ }
+ }
+ return result;
+}
+
string NameUniquer::GetUniqueName(tensorflow::StringPiece prefix) {
string root = prefix.empty() ? "name" : prefix.ToString();
+ root = GetSanitizedName(root);
// Strip away numeric suffix (if any). Only recognize separator if it is in
// the middle of the name.
diff --git a/tensorflow/compiler/xla/service/name_uniquer.h b/tensorflow/compiler/xla/service/name_uniquer.h
index ed379b5225..4139c2700b 100644
--- a/tensorflow/compiler/xla/service/name_uniquer.h
+++ b/tensorflow/compiler/xla/service/name_uniquer.h
@@ -28,14 +28,21 @@ namespace xla {
// Simple stateful class that helps generate "unique" names. To use it, simply
// call GetUniqueName as many times as needed. The names returned by
// GetUniqueName are guaranteed to be distinct for this instance of the class.
+// Note that the names will be sanitized to match regexp
+// "[a-zA-Z_][a-zA-Z0-9_.-]*".
class NameUniquer {
public:
- explicit NameUniquer(const string& separator = "__")
- : separator_(separator) {}
+ // The separator must contain allowed characters only: "[a-zA-Z0-9_.-]".
+ explicit NameUniquer(const string& separator = "__");
- // Get a unique name in a string, with an optional prefix for convenience.
+ // Get a sanitized unique name in a string, with an optional prefix for
+ // convenience.
string GetUniqueName(tensorflow::StringPiece prefix = "");
+ // Sanitizes and returns the name. Unallowed characters will be replaced with
+ // '_'. The result will match the regexp "[a-zA-Z_][a-zA-Z0-9_.-]*".
+ static string GetSanitizedName(const string& name);
+
private:
// The string to use to separate the prefix of the name from the uniquing
// integer value.
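
The sanitization contract, illustrated on a few inputs (expected results follow from the regexp above and agree with the test below):

    // Sketch only: disallowed characters become '_', and the first
    // character is forced into [a-zA-Z_].
    NameUniquer::GetSanitizedName("while.1");   // "while.1" (already valid)
    NameUniquer::GetSanitizedName("foo<bar>");  // "foo_bar_"
    NameUniquer::GetSanitizedName("1st");       // "_st"
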
diff --git a/tensorflow/compiler/xla/service/name_uniquer_test.cc b/tensorflow/compiler/xla/service/name_uniquer_test.cc
index 9f0747a6e2..4258cf1687 100644
--- a/tensorflow/compiler/xla/service/name_uniquer_test.cc
+++ b/tensorflow/compiler/xla/service/name_uniquer_test.cc
@@ -60,12 +60,30 @@ TEST_F(NameUniquerTest, NumericSuffixes) {
EXPECT_EQ("bar", uniquer.GetUniqueName("bar.-1000"));
EXPECT_EQ("bar.1", uniquer.GetUniqueName("bar.-2000"));
EXPECT_EQ("bar.2", uniquer.GetUniqueName("bar.1"));
+}
+
+TEST_F(NameUniquerTest, Sanitize) {
+ NameUniquer uniquer("_");
+
+ EXPECT_EQ("foo", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo_1", uniquer.GetUniqueName("foo"));
+ EXPECT_EQ("foo.54", uniquer.GetUniqueName("foo.54"));
+ EXPECT_EQ("foo_54", uniquer.GetUniqueName("foo_54"));
+ EXPECT_EQ("foo_54.1", uniquer.GetUniqueName("foo_54.1"));
+ EXPECT_EQ("foo_55", uniquer.GetUniqueName("foo"));
+
+ // Invalid characters will be replaced with '_'.
+ EXPECT_EQ("bar", uniquer.GetUniqueName("bar<-1000"));
+ EXPECT_EQ("bar_1", uniquer.GetUniqueName("bar<-2000"));
+ EXPECT_EQ("bar_2", uniquer.GetUniqueName("bar_1"));
// Separator is only recognized in the middle of the prefix.
- EXPECT_EQ(".10", uniquer.GetUniqueName(".10"));
- EXPECT_EQ(".10.1", uniquer.GetUniqueName(".10"));
- EXPECT_EQ("foobar.", uniquer.GetUniqueName("foobar."));
- EXPECT_EQ("foobar..1", uniquer.GetUniqueName("foobar."));
+ EXPECT_EQ("_10", uniquer.GetUniqueName(
+ ".10")); // the leading '.' is replaced with '_'.
+ EXPECT_EQ("_10_1", uniquer.GetUniqueName(".10"));
+ EXPECT_EQ("_10_2", uniquer.GetUniqueName("_10"));
+ EXPECT_EQ("foobar_", uniquer.GetUniqueName("foobar_"));
+ EXPECT_EQ("foobar__1", uniquer.GetUniqueName("foobar_"));
}
} // namespace
diff --git a/tensorflow/compiler/xla/service/service.cc b/tensorflow/compiler/xla/service/service.cc
index d997cab83f..fa62080be4 100644
--- a/tensorflow/compiler/xla/service/service.cc
+++ b/tensorflow/compiler/xla/service/service.cc
@@ -1381,6 +1381,9 @@ tensorflow::Status Service::Op(const OpRequest* arg, OpResponse* result) {
handle_status =
computation->AddCustomCallInstruction(arg->custom_call_request());
break;
+ case OpRequest::kDotRequest:
+ handle_status = computation->AddDotInstruction(arg->dot_request());
+ break;
case OpRequest::kDynamicSliceRequest:
handle_status =
computation->AddDynamicSliceInstruction(arg->dynamic_slice_request());
diff --git a/tensorflow/compiler/xla/service/shape_inference.cc b/tensorflow/compiler/xla/service/shape_inference.cc
index 3df1911d07..7178eb40dd 100644
--- a/tensorflow/compiler/xla/service/shape_inference.cc
+++ b/tensorflow/compiler/xla/service/shape_inference.cc
@@ -29,6 +29,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/stringpiece.h"
+#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
@@ -90,8 +91,6 @@ BinaryOperation OpcodeToBinaryOperation(HloOpcode opcode) {
return BINOP_ATAN2;
case HloOpcode::kComplex:
return BINOP_COMPLEX;
- case HloOpcode::kDot:
- return BINOP_DOT;
case HloOpcode::kMultiply:
return BINOP_MUL;
case HloOpcode::kAdd:
@@ -549,8 +548,98 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return ShapeUtil::MakeShape(operand_shape.element_type(), dimensions);
}
-/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(const Shape& lhs,
- const Shape& rhs) {
+// Current DotDimensionNumbers Requirements:
+//
+// Contracting Dimensions:
+// *) Exactly one contracting dimension on both lhs and rhs.
+// *) Contracting dimension size must be the same on both lhs and rhs.
+// *) Contracting dimension numbers do not need to be the same (i.e. transposes
+// are passed on to emitter implementations).
+//
+// Batch Dimensions:
+// *) Same number of batch dimensions on both lhs and rhs.
+// *) Same batch dimension numbers (and sizes) on both lhs and rhs.
+//
+// Non-Contracting-Non-Batch Dimensions:
+// *) Can be 0 (matrix-vector) or 1 (matrix-matrix).
+//
+
+namespace {
+
+Status ValidateDotDimensionNumbers(
+ const Shape& lhs, const Shape& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
+ // Check that dimension numbers are in range.
+ auto dims_in_range =
+ [](const int64 rank, tensorflow::gtl::ArraySlice<int64> contracting_dims,
+ tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ auto in_range = [&rank](int64 i) -> bool { return 0 <= i && i < rank; };
+ return std::all_of(contracting_dims.begin(), contracting_dims.end(),
+ in_range) &&
+ std::all_of(batch_dims.begin(), batch_dims.end(), in_range);
+ };
+
+ tensorflow::gtl::ArraySlice<int64> lhs_contracting_dimensions =
+ AsInt64Slice(dimension_numbers.lhs_contracting_dimensions());
+ tensorflow::gtl::ArraySlice<int64> rhs_contracting_dimensions =
+ AsInt64Slice(dimension_numbers.rhs_contracting_dimensions());
+ tensorflow::gtl::ArraySlice<int64> lhs_batch_dimensions =
+ AsInt64Slice(dimension_numbers.lhs_batch_dimensions());
+ tensorflow::gtl::ArraySlice<int64> rhs_batch_dimensions =
+ AsInt64Slice(dimension_numbers.rhs_batch_dimensions());
+
+ if (!dims_in_range(ShapeUtil::Rank(lhs), lhs_contracting_dimensions,
+ lhs_batch_dimensions) ||
+ !dims_in_range(ShapeUtil::Rank(rhs), rhs_contracting_dimensions,
+ rhs_batch_dimensions)) {
+ return InvalidArgument("A dimension number is out of range in dot: %s",
+ dimension_numbers.DebugString().c_str());
+ }
+
+ // Check that dimension numbers are unique.
+ auto dims_unique = [](tensorflow::gtl::ArraySlice<int64> contracting_dims,
+ tensorflow::gtl::ArraySlice<int64> batch_dims) -> bool {
+ tensorflow::gtl::FlatSet<int64> dim_set;
+ auto is_unique = [&dim_set](int64 i) -> bool {
+ return dim_set.insert(i).second;
+ };
+ return std::all_of(contracting_dims.begin(), contracting_dims.end(),
+ is_unique) &&
+ std::all_of(batch_dims.begin(), batch_dims.end(), is_unique);
+ };
+
+ if (!dims_unique(lhs_contracting_dimensions, lhs_batch_dimensions) ||
+ !dims_unique(rhs_contracting_dimensions, rhs_batch_dimensions)) {
+ return InvalidArgument("A dimension number is not unique in dot: %s",
+ dimension_numbers.DebugString().c_str());
+ }
+
+ // Check that the count of non-contracting-non-batch dimensions is in {0, 1}.
+ const int64 lhs_non_contracting_non_batch_dims =
+ ShapeUtil::Rank(lhs) -
+ dimension_numbers.lhs_contracting_dimensions_size() -
+ dimension_numbers.lhs_batch_dimensions_size();
+ const int64 rhs_non_contracting_non_batch_dims =
+ ShapeUtil::Rank(rhs) -
+ dimension_numbers.rhs_contracting_dimensions_size() -
+ dimension_numbers.rhs_batch_dimensions_size();
+ if (lhs_non_contracting_non_batch_dims < 0 ||
+ lhs_non_contracting_non_batch_dims > 1 ||
+ rhs_non_contracting_non_batch_dims < 0 ||
+ rhs_non_contracting_non_batch_dims > 1) {
+ return InvalidArgument(
+ "batch and contracting dimension number mismatch with rank: lhs and rhs "
+ "must each have at most one non-contracting, non-batch dimension");
+ }
+
+ return Status::OK();
+}
+
+} // namespace
+
+/* static */ StatusOr<Shape> ShapeInference::InferDotOpShape(
+ const Shape& lhs, const Shape& rhs,
+ const DotDimensionNumbers& dimension_numbers) {
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(lhs, "lhs of dot"));
TF_RETURN_IF_ERROR(ExpectNotTupleOrOpaque(rhs, "rhs of dot"));
@@ -570,37 +659,62 @@ StatusOr<Shape> InferWindowOutputShape(const Shape& base_shape,
return fail("element types do not match");
}
- if (ShapeUtil::Rank(lhs) < 1 || ShapeUtil::Rank(lhs) > 2 ||
- ShapeUtil::Rank(rhs) < 1 || ShapeUtil::Rank(rhs) > 2) {
- return fail("dot only supports rank 1 or 2");
+ if ((ShapeUtil::Rank(lhs) < 1) || (ShapeUtil::Rank(rhs) < 1)) {
+ return fail("dot only supports rank 1 or above.");
}
- // Determine the index of the contracted dimensions for input tensors.
- // dimensions -1 of lhs and dimension 0 of rhs are contracted.
- int64 lhs_contracted_dimension = ShapeUtil::GetDimensionNumber(lhs, -1);
- int64 rhs_contracted_dimension = 0;
+ // Validate basic properties of dot dimension numbers.
+ TF_RETURN_IF_ERROR(ValidateDotDimensionNumbers(lhs, rhs, dimension_numbers));
+
+ // Check that there is only one contracting dimension for both lhs and rhs.
+ if (dimension_numbers.lhs_contracting_dimensions_size() !=
+ dimension_numbers.rhs_contracting_dimensions_size() ||
+ dimension_numbers.lhs_contracting_dimensions_size() != 1) {
+ return fail("must specify one contracting dimension for both lhs and rhs.");
+ }
- // Check if the contracted dimension sizes are the same.
- if ((lhs_contracted_dimension < ShapeUtil::Rank(lhs) &&
- rhs_contracted_dimension < ShapeUtil::Rank(rhs)) &&
- lhs.dimensions(lhs_contracted_dimension) !=
- rhs.dimensions(rhs_contracted_dimension)) {
- return fail("contracted dimensions mismatch");
+ // Check that contracting dimension sizes match.
+ const int64 lhs_contracting_dimension =
+ dimension_numbers.lhs_contracting_dimensions(0);
+ const int64 rhs_contracting_dimension =
+ dimension_numbers.rhs_contracting_dimensions(0);
+ if (lhs.dimensions(lhs_contracting_dimension) !=
+ rhs.dimensions(rhs_contracting_dimension)) {
+ return fail("contracting dimension sizes do not match.");
+ }
+
+ // Check that number of batch dimensions match.
+ if (dimension_numbers.lhs_batch_dimensions_size() !=
+ dimension_numbers.rhs_batch_dimensions_size()) {
+ return fail("must the same number of batch dimensions for lhs and rhs.");
+ }
+
+ // Check that batch dimension numbers and sizes match.
+ for (int64 i = 0; i < dimension_numbers.lhs_batch_dimensions_size(); ++i) {
+ if (dimension_numbers.lhs_batch_dimensions(i) !=
+ dimension_numbers.rhs_batch_dimensions(i) ||
+ lhs.dimensions(dimension_numbers.lhs_batch_dimensions(i)) !=
+ rhs.dimensions(dimension_numbers.rhs_batch_dimensions(i))) {
+ return fail("batch dimension numbers and sizes must match for lhs/rhs.");
+ }
}
// The ranks of lhs and rhs are decremented by 1 respectively due to the
// contraction, and added for the rank of the result. When an input tensor is
// a scalar, its contribution to the rank of the result is 0.
// Generate the result dimensions in order, rhs dimensions followed by lhs
- // dimensions except the contracted dimensions.
+ // dimensions except the contracted and batch dimensions.
std::vector<int64> dimensions;
+ std::unordered_set<int64> rhs_batch_dims(
+ dimension_numbers.rhs_batch_dimensions().begin(),
+ dimension_numbers.rhs_batch_dimensions().end());
for (int64 i = 0; i < ShapeUtil::Rank(lhs); i++) {
- if (i != lhs_contracted_dimension) {
+ if (i != lhs_contracting_dimension) {
dimensions.push_back(lhs.dimensions(i));
}
}
for (int64 i = 0; i < ShapeUtil::Rank(rhs); i++) {
- if (i != rhs_contracted_dimension) {
+ if (i != rhs_contracting_dimension && rhs_batch_dims.count(i) == 0) {
dimensions.push_back(rhs.dimensions(i));
}
}
@@ -816,8 +930,6 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
rhs, tensorflow::strings::StrCat("rhs of binary operation ",
BinaryOperation_Name(operation))));
switch (operation) {
- case BINOP_DOT:
- return InferDotOpShape(lhs, rhs);
case BINOP_MAX:
case BINOP_MIN:
case BINOP_SUB:
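
Pulling the rules above together, a sketch of the new entry point on a batch matmul (shapes mirror the DotGeneral test added below):

    // Sketch only: two batch dimensions plus one contracting dimension
    // per operand.
    Shape lhs = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
    Shape rhs = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(3);
    dnums.add_rhs_contracting_dimensions(2);
    dnums.add_lhs_batch_dimensions(0);
    dnums.add_lhs_batch_dimensions(1);
    dnums.add_rhs_batch_dimensions(0);
    dnums.add_rhs_batch_dimensions(1);
    StatusOr<Shape> inferred = ShapeInference::InferDotOpShape(lhs, rhs, dnums);
    // On success the result is f32[5,2,11,14]: lhs dims minus the contracted
    // one, then rhs dims minus its contracted and batch dims.
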
diff --git a/tensorflow/compiler/xla/service/shape_inference.h b/tensorflow/compiler/xla/service/shape_inference.h
index 0aadb98a40..382c4f8abc 100644
--- a/tensorflow/compiler/xla/service/shape_inference.h
+++ b/tensorflow/compiler/xla/service/shape_inference.h
@@ -229,11 +229,13 @@ class ShapeInference {
tensorflow::gtl::ArraySlice<const Shape*> arg_shapes,
const ProgramShape& to_apply);
- private:
// Helper that infers the shape produced by performing a dot operation with
// the given LHS and RHS shapes.
- static StatusOr<Shape> InferDotOpShape(const Shape& lhs, const Shape& rhs);
+ static StatusOr<Shape> InferDotOpShape(
+ const Shape& lhs, const Shape& rhs,
+ const DotDimensionNumbers& dimension_numbers);
+ private:
// Helper that infers the shape produced by performing an element-wise binary
// operation with the given LHS and RHS shapes.
// Note: By "element-wise" we mean operations that look at a single element in
diff --git a/tensorflow/compiler/xla/service/shape_inference_test.cc b/tensorflow/compiler/xla/service/shape_inference_test.cc
index be93c879c0..6e53d2d609 100644
--- a/tensorflow/compiler/xla/service/shape_inference_test.cc
+++ b/tensorflow/compiler/xla/service/shape_inference_test.cc
@@ -898,8 +898,11 @@ TEST_F(ShapeInferenceTest, BroadcastScalar) {
// scalar <dot> vector: error
TEST_F(ShapeInferenceTest, ScalarDotVector) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, f32_, vector_32_, {});
+ ShapeInference::InferDotOpShape(f32_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
HasSubstr("dot only supports rank"));
@@ -907,61 +910,199 @@ TEST_F(ShapeInferenceTest, ScalarDotVector) {
// 3D <dot> 2D: error
TEST_F(ShapeInferenceTest, DotWithRankHigherThanTwo) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BINOP_DOT, ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status = ShapeInference::InferDotOpShape(
+ ShapeUtil::MakeShape(F32, {32, 32, 32}), matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status.ok());
ASSERT_THAT(inferred_status.status().error_message(),
- HasSubstr("dot only supports rank"));
+ HasSubstr("batch and contracting dimension number mismatch"));
}
// vector <dot> vector -> scalar
TEST_F(ShapeInferenceTest, VectorDotVector) {
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
auto inferred_status =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_64_, {});
+ ShapeInference::InferDotOpShape(vector_64_, vector_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(f32_, inferred_status.ValueOrDie()));
auto inferred_status_mismatch =
- ShapeInference::InferBinaryOpShape(BINOP_DOT, vector_64_, vector_32_, {});
+ ShapeInference::InferDotOpShape(vector_64_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// matrix <dot> vector -> vector
TEST_F(ShapeInferenceTest, MatrixDotVector) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, vector_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(matrix_32_64_, vector_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_32_));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, vector_32_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(matrix_32_64_, vector_32_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// vector <dot> matrix -> vector
TEST_F(ShapeInferenceTest, VectorDotMatrix) {
- auto inferred_status = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, vector_32_, matrix_32_64_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(vector_32_, matrix_32_64_, dot_dnums);
ASSERT_IS_OK(inferred_status.status());
ASSERT_TRUE(ShapeUtil::Equal(inferred_status.ValueOrDie(), vector_64_));
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, vector_64_, matrix_32_64_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(vector_64_, matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
// matrix <dot> matrix -> matrix
TEST_F(ShapeInferenceTest, MatrixDotMatrix) {
- auto inferred_status_match = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_64_48_, {});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ auto inferred_status_match =
+ ShapeInference::InferDotOpShape(matrix_32_64_, matrix_64_48_, dot_dnums);
ASSERT_IS_OK(inferred_status_match.status());
ASSERT_TRUE(
ShapeUtil::Equal(inferred_status_match.ValueOrDie(), matrix_32_48_))
<< "inferred: "
<< ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
<< " expected: " << ShapeUtil::HumanString(matrix_64_48_);
- auto inferred_status_mismatch = ShapeInference::InferBinaryOpShape(
- BinaryOperation::BINOP_DOT, matrix_32_64_, matrix_32_64_, {});
+ auto inferred_status_mismatch =
+ ShapeInference::InferDotOpShape(matrix_32_64_, matrix_32_64_, dot_dnums);
ASSERT_FALSE(inferred_status_mismatch.ok());
}
+// BatchMatMul with two batch dimensions and one contracting dimension.
+TEST_F(ShapeInferenceTest, DotGeneral) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {5, 2, 3, 14});
+ Shape output_shape = ShapeUtil::MakeShape(F32, {5, 2, 11, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+ dot_dnums.add_lhs_batch_dimensions(1);
+
+ dot_dnums.add_rhs_contracting_dimensions(2);
+ dot_dnums.add_rhs_batch_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status_match =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_IS_OK(inferred_status_match.status());
+ ASSERT_TRUE(
+ ShapeUtil::Equal(inferred_status_match.ValueOrDie(), output_shape))
+ << "inferred: "
+ << ShapeUtil::HumanString(inferred_status_match.ValueOrDie())
+ << " expected: " << ShapeUtil::HumanString(output_shape);
+}
+
+// BatchMatMul with two contracting dimensions fails.
+TEST_F(ShapeInferenceTest, DotWithTwoContractingDimsFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3, 2});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+ Shape output_shape = ShapeUtil::MakeShape(F32, {2, 11, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_batch_dimensions(0);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("must specify one contracting dimension for both "
+ "lhs and rhs"));
+}
+
+// BatchMatMul with different batch dimension sizes fails.
+TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimSizesFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_batch_dimensions(0);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("batch dimension numbers and sizes must match"));
+}
+
+// BatchMatMul with different batch dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithMismatchedBatchDimNumbersFails) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {3, 2, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("batch dimension numbers and sizes must match"));
+}
+
+// BatchMatMul with out-of-range dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithContractingDimNumberOutOfRange) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(3);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("A dimension number is out of range"));
+}
+
+// BatchMatMul with non-unique dimension numbers fails.
+TEST_F(ShapeInferenceTest, DotWithContractingNonUniqueDimNumber) {
+ Shape lhs_shape = ShapeUtil::MakeShape(F32, {2, 11, 3});
+ Shape rhs_shape = ShapeUtil::MakeShape(F32, {2, 3, 14});
+
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_lhs_batch_dimensions(0);
+
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(1);
+
+ auto inferred_status =
+ ShapeInference::InferDotOpShape(lhs_shape, rhs_shape, dot_dnums);
+ ASSERT_FALSE(inferred_status.ok());
+ ASSERT_THAT(inferred_status.status().error_message(),
+ HasSubstr("A dimension number is not unique"));
+}
+
TEST_F(ShapeInferenceTest, BinOpBroadcastMatrixVector) {
// Test variations of broadcasting a vector for a binary add with a
// matrix.
diff --git a/tensorflow/compiler/xla/service/transpose_folding.cc b/tensorflow/compiler/xla/service/transpose_folding.cc
index fb55d4e543..42b616f4c3 100644
--- a/tensorflow/compiler/xla/service/transpose_folding.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding.cc
@@ -102,6 +102,10 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
auto& convolution = *pair.first;
auto& operand_indices = pair.second;
+ if (operand_indices.empty()) {
+ return false;
+ }
+
const ConvolutionDimensionNumbers& dnums =
convolution.convolution_dimension_numbers();
ConvolutionDimensionNumbers new_dnums = dnums;
@@ -121,8 +125,9 @@ bool FoldTransposeIntoConvolution(InstructionOperandsPair pair) {
transpose_dimensions[dnums.input_batch_dimension()]);
new_dnums.set_input_feature_dimension(
transpose_dimensions[dnums.input_feature_dimension()]);
- for (const auto& spatial_dimension : dnums.input_spatial_dimensions()) {
- CHECK_EQ(spatial_dimension, transpose_dimensions[spatial_dimension]);
+ for (auto& input_spatial_dimension :
+ *new_dnums.mutable_input_spatial_dimensions()) {
+ input_spatial_dimension = transpose_dimensions[input_spatial_dimension];
}
new_lhs = &transpose_operand;
} else {
diff --git a/tensorflow/compiler/xla/service/transpose_folding_test.cc b/tensorflow/compiler/xla/service/transpose_folding_test.cc
index 6ac32e88f1..caa1a111ad 100644
--- a/tensorflow/compiler/xla/service/transpose_folding_test.cc
+++ b/tensorflow/compiler/xla/service/transpose_folding_test.cc
@@ -64,9 +64,12 @@ TEST_F(TransposeFoldingTest, FoldDotTranspose) {
HloInstruction* transpose_y =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot,
- /*lhs=*/x, /*rhs=*/transpose_y));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
+ /*rhs=*/transpose_y, dot_dnums));
HloModule module("test_module");
HloComputation* entry_computation =
@@ -104,9 +107,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeConstant) {
HloInstruction* transpose1 =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {2, 3}), const1, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1, 3}), /*opcode=*/HloOpcode::kDot,
- /*lhs=*/transpose0, /*rhs=*/transpose1));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ ShapeUtil::MakeShape(F32, {1, 3}),
+ /*lhs=*/transpose0, /*rhs=*/transpose1, dot_dnums));
HloModule module("test_module");
HloComputation* entry_computation =
@@ -169,9 +175,12 @@ TEST_F(TransposeFoldingTest, FoldDotTransposeInWhile) {
HloInstruction* transpose_y =
builder.AddInstruction(HloInstruction::CreateTranspose(
ShapeUtil::MakeShape(F32, {3, 2}), y, {1, 0}));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {2, 2}), /*opcode=*/HloOpcode::kDot,
- /*lhs=*/x, /*rhs=*/transpose_y));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(
+ HloInstruction::CreateDot(ShapeUtil::MakeShape(F32, {2, 2}), /*lhs=*/x,
+ /*rhs=*/transpose_y, dot_dnums));
HloModule module("test_module");
HloComputation* entry_computation =
@@ -376,5 +385,69 @@ TEST_F(TransposeFoldingTest, FoldConvTransposeLhs) {
new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
}
+// Test that a transpose of every dimension in the activations gets folded into
+// convolution.
+TEST_F(TransposeFoldingTest, FoldConvComplexTransposeLhs) {
+ auto builder = HloComputation::Builder("entry_computation");
+ HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/0, ShapeUtil::MakeShape(F32, {3, 2, 1, 1}),
+ /*name=*/"x"));
+ HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
+ /*parameter_number=*/1, ShapeUtil::MakeShape(F32, {2, 3, 1, 1}),
+ /*name=*/"y"));
+ HloInstruction* transpose_x =
+ builder.AddInstruction(HloInstruction::CreateTranspose(
+ ShapeUtil::MakeShape(F32, {2, 3, 1, 1}), x, {1, 0, 3, 2}));
+ auto dnums = ComputationBuilder::CreateDefaultConvDimensionNumbers();
+ Window window;
+ for (int i = 0; i < 2; ++i) {
+ WindowDimension* dim = window.add_dimensions();
+ dim->set_padding_low(0);
+ dim->set_padding_high(0);
+ dim->set_base_dilation(1);
+ dim->set_window_dilation(1);
+ dim->set_stride(1);
+ dim->set_size(y->shape().dimensions(dnums.kernel_spatial_dimensions(i)));
+ }
+ StatusOr<Shape> conv_shape = ShapeInference::InferConvolveShape(
+ transpose_x->shape(), y->shape(), window, dnums);
+ EXPECT_IS_OK(conv_shape);
+ HloInstruction* conv = builder.AddInstruction(HloInstruction::CreateConvolve(
+ conv_shape.ValueOrDie(), transpose_x, y, window, dnums));
+
+ HloModule module("test_module");
+ HloComputation* entry_computation =
+ module.AddEntryComputation(builder.Build(conv));
+ FoldTranspose(&module);
+
+ // Instructions after folding: x, y, and the convolution.
+ std::unordered_set<HloInstruction*> instruction_set(
+ entry_computation->instructions().begin(),
+ entry_computation->instructions().end());
+ EXPECT_EQ(1, instruction_set.erase(x)) << "x is not in entry_computation.";
+ EXPECT_EQ(1, instruction_set.erase(y)) << "y is not in entry_computation.";
+ EXPECT_EQ(1, instruction_set.size())
+ << "entry_computation should contain exactly 3 instructions.";
+ HloInstruction* new_conv = *instruction_set.begin();
+ EXPECT_EQ(HloOpcode::kConvolution, new_conv->opcode());
+ EXPECT_EQ(dnums.input_feature_dimension(),
+ new_conv->convolution_dimension_numbers().input_batch_dimension());
+ EXPECT_EQ(
+ dnums.input_batch_dimension(),
+ new_conv->convolution_dimension_numbers().input_feature_dimension());
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(1));
+ EXPECT_EQ(
+ dnums.input_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().input_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(0),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(0));
+ EXPECT_EQ(
+ dnums.output_spatial_dimensions(1),
+ new_conv->convolution_dimension_numbers().output_spatial_dimensions(1));
+}
+
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/user_computation.cc b/tensorflow/compiler/xla/service/user_computation.cc
index 4e90491b55..6d0d367981 100644
--- a/tensorflow/compiler/xla/service/user_computation.cc
+++ b/tensorflow/compiler/xla/service/user_computation.cc
@@ -88,8 +88,6 @@ HloOpcode BinaryOperationToHloOpcode(BinaryOperation binop) {
return HloOpcode::kAtan2;
case BINOP_COMPLEX:
return HloOpcode::kComplex;
- case BINOP_DOT:
- return HloOpcode::kDot;
case BINOP_MUL:
return HloOpcode::kMultiply;
case BINOP_ADD:
@@ -1207,6 +1205,33 @@ StatusOr<ComputationDataHandle> UserComputation::AddCustomCallInstruction(
return handle;
}
+StatusOr<ComputationDataHandle> UserComputation::AddDotInstruction(
+ const DotRequest& dot_request) {
+ tensorflow::mutex_lock lock(mutex_);
+
+ TF_ASSIGN_OR_RETURN(const OperationRequest* lhs,
+ LookUpRequest(dot_request.lhs()));
+ TF_ASSIGN_OR_RETURN(const OperationRequest* rhs,
+ LookUpRequest(dot_request.rhs()));
+
+ TF_ASSIGN_OR_RETURN(Shape shape, ShapeInference::InferDotOpShape(
+ lhs->output_shape(), rhs->output_shape(),
+ dot_request.dimension_numbers()));
+
+ const ComputationDataHandle handle = CreateComputationDataHandle();
+
+ OperationRequest& request =
+ (*session_computation_.mutable_requests())[handle.handle()];
+ *request.mutable_output_handle() = handle;
+ *request.mutable_output_shape() = shape;
+ *request.mutable_request()->mutable_dot_request() = dot_request;
+
+ VLOG(1) << "AddDotInstruction (" << GetVersionedHandleInternal()
+ << "), data handle " << handle.handle() << ": "
+ << dot_request.ShortDebugString();
+ return handle;
+}
+
StatusOr<ComputationDataHandle> UserComputation::AddUnaryInstruction(
const UnaryOpRequest& unary_request) {
tensorflow::mutex_lock lock(mutex_);
@@ -1629,6 +1654,15 @@ void PureFunctionalVisitor(const SessionComputation& session_computation,
break;
}
+ case OpRequest::kDotRequest: {
+ const DotRequest& dot_request = request.request().dot_request();
+ PureFunctionalVisitor(session_computation, dot_request.lhs(),
+ num_parameters, visited, is_functional);
+ PureFunctionalVisitor(session_computation, dot_request.rhs(),
+ num_parameters, visited, is_functional);
+ break;
+ }
+
case OpRequest::kSendRequest: {
*is_functional = false;
break;
@@ -2453,6 +2487,13 @@ static void ForEachOperand(
break;
}
+ case OpRequest::kDotRequest: {
+ const DotRequest& dot_request = request.request().dot_request();
+ apply(dot_request.rhs());
+ apply(dot_request.lhs());
+ break;
+ }
+
case OpRequest::kUnaryOpRequest: {
const UnaryOpRequest& unary_op_request =
request.request().unary_op_request();
@@ -2732,6 +2773,15 @@ void ComputationLowerer::Visit(
break;
}
+ case OpRequest::kDotRequest: {
+ const DotRequest& dot_request = request.request().dot_request();
+ HloInstruction* lhs = lookup_instruction(dot_request.lhs());
+ HloInstruction* rhs = lookup_instruction(dot_request.rhs());
+ hlo_instruction = add_instruction(HloInstruction::CreateDot(
+ request.output_shape(), lhs, rhs, dot_request.dimension_numbers()));
+ break;
+ }
+
case OpRequest::kCrossReplicaSumRequest: {
const CrossReplicaSumRequest& cross_replica_sum_request =
request.request().cross_replica_sum_request();
@@ -3151,8 +3201,7 @@ void ComputationLowerer::Visit(
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
}
- if (debug_options_.xla_eliminate_hlo_implicit_broadcast() &&
- binary_op_request.binop() != BINOP_DOT) {
+ if (debug_options_.xla_eliminate_hlo_implicit_broadcast()) {
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// lhs side is being implicitly broadcast. Change to explicit.
lhs =
diff --git a/tensorflow/compiler/xla/service/user_computation.h b/tensorflow/compiler/xla/service/user_computation.h
index 317c631dca..b6686c3f1a 100644
--- a/tensorflow/compiler/xla/service/user_computation.h
+++ b/tensorflow/compiler/xla/service/user_computation.h
@@ -153,6 +153,10 @@ class UserComputation {
StatusOr<ComputationDataHandle> AddCustomCallInstruction(
const CustomCallRequest& custom_call_request);
+ // Enqueues a dot instruction onto this user computation.
+ StatusOr<ComputationDataHandle> AddDotInstruction(
+ const DotRequest& dot_request);
+
// Enqueues a broadcast instruction onto this user computation.
StatusOr<ComputationDataHandle> AddBroadcastInstruction(
const BroadcastRequest& broadcast_request);
diff --git a/tensorflow/compiler/xla/service/user_computation_test.cc b/tensorflow/compiler/xla/service/user_computation_test.cc
index 5afaf226ae..e45673300b 100644
--- a/tensorflow/compiler/xla/service/user_computation_test.cc
+++ b/tensorflow/compiler/xla/service/user_computation_test.cc
@@ -334,50 +334,5 @@ TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
operands[1]->opcode() == HloOpcode::kBroadcast);
}
-TEST_F(UserComputationTest, SkipDotInEliminatingImplicitBroadcast) {
- auto debug_options = DebugOptions();
- debug_options.set_xla_eliminate_hlo_implicit_broadcast(true);
-
- // %a = Param({1, 3});
- // %b = Param({3, 1});
- // %dot = Dot(%a, %b);
- ComputationHandle handle;
- handle.set_handle(123);
- UserComputation computation("TheComputation", handle);
-
- ParameterRequest a_request;
- *a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {1, 3});
- a_request.set_name("a");
- a_request.set_parameter(0);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle a_handle,
- computation.AddParameterInstruction(a_request));
-
- ParameterRequest b_request;
- *b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {3, 1});
- b_request.set_name("b");
- b_request.set_parameter(1);
- TF_ASSERT_OK_AND_ASSIGN(ComputationDataHandle b_handle,
- computation.AddParameterInstruction(b_request));
-
- BinaryOpRequest dot;
- dot.set_binop(BINOP_DOT);
- *dot.mutable_lhs() = a_handle;
- *dot.mutable_rhs() = b_handle;
- TF_ASSERT_OK(computation.AddBinaryInstruction(dot).status());
-
- auto hlo_resolver = [](const VersionedComputationHandle& handle) {
- return nullptr;
- };
- VersionedComputationHandle latest_version = computation.GetVersionedHandle();
-
- // Build the HLO computation.
- TF_ASSERT_OK_AND_ASSIGN(
- std::unique_ptr<HloComputation> hlo_computation,
- computation.BuildHloComputation(latest_version.version, hlo_resolver,
- debug_options));
-
- EXPECT_EQ(3, hlo_computation->instruction_count());
-}
-
} // namespace
} // namespace xla
diff --git a/tensorflow/compiler/xla/service/while_loop_simplifier.cc b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
index b38ee907d7..b2fd64a4d9 100644
--- a/tensorflow/compiler/xla/service/while_loop_simplifier.cc
+++ b/tensorflow/compiler/xla/service/while_loop_simplifier.cc
@@ -289,7 +289,7 @@ static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) {
// Don't try this transformation if the while loop isn't removable, since if
// it succeeds ultimately we're going to have to replace the old while loop
// with a new one.
- if (!while_op->parent()->IsRemovable(while_op)) {
+ if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
VLOG(2) << "Can't remove dead parameters from non-removable while op.";
return false;
}
@@ -558,7 +558,7 @@ static StatusOr<bool> TryRemoveWhileLoop(HloInstruction* while_op) {
// the loop aren't removed, just cloned and added back to the loop.
// Nevertheless our infrastructure sees loop simplification as removal of
// these nodes and currently doesn't allow it.
- if (!while_op->parent()->IsRemovable(while_op)) {
+ if (!while_op->parent()->IsRemovable(while_op) || while_op->HasSideEffect()) {
VLOG(2) << "Not attempting to remove while loop it is not removable: "
<< while_op->ToShortString();
return false;
diff --git a/tensorflow/compiler/xla/shape_layout.cc b/tensorflow/compiler/xla/shape_layout.cc
index 5bf9842a6c..789eba5780 100644
--- a/tensorflow/compiler/xla/shape_layout.cc
+++ b/tensorflow/compiler/xla/shape_layout.cc
@@ -32,13 +32,13 @@ tensorflow::Status ShapeLayout::CopyLayoutFromShape(const Shape& other_shape) {
return tensorflow::Status::OK();
}
-tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* other_shape) const {
- if (!ShapeUtil::Compatible(*other_shape, shape_)) {
+tensorflow::Status ShapeLayout::AssignLayoutToShape(Shape* to_shape) const {
+ if (!ShapeUtil::Compatible(*to_shape, shape_)) {
return InvalidArgument("Shape %s is not compatible with shape %s",
- ShapeUtil::HumanString(*other_shape).c_str(),
+ ShapeUtil::HumanString(*to_shape).c_str(),
ShapeUtil::HumanString(shape()).c_str());
}
- *other_shape = shape_;
+ *to_shape = shape_;
return tensorflow::Status::OK();
}
diff --git a/tensorflow/compiler/xla/shape_layout.h b/tensorflow/compiler/xla/shape_layout.h
index 92564660f2..4c83750f3e 100644
--- a/tensorflow/compiler/xla/shape_layout.h
+++ b/tensorflow/compiler/xla/shape_layout.h
@@ -38,18 +38,19 @@ class ShapeLayout {
explicit ShapeLayout(const Shape& shape) : shape_(shape) {}
// Assigns the layouts in this ShapeLayout to the Layout fields of the given
- // shape. 'shape' and the shape of the ShapeLayout object must be compatible.
- tensorflow::Status AssignLayoutToShape(Shape* shape) const;
+ // shape. 'to_shape' and the shape of the ShapeLayout object must be
+ // compatible.
+ tensorflow::Status AssignLayoutToShape(Shape* to_shape) const;
// Returns true if the Layouts in this ShapeLayout match the layouts in the
// given shape. Returns false otherwise. If the given shape is not compatible
// with the ShapeLayout's shape, then false is returned.
bool MatchesLayoutInShape(const Shape& shape) const;
- // Copies the layout from the given shape into this ShapeLayout. 'shape' must
- // be compatible with the ShapeLayout's shape, and 'shape' must have a layout
- // (LayoutUtil::HasLayout).
- tensorflow::Status CopyLayoutFromShape(const Shape& shape);
+ // Copies the layout from the given shape into this ShapeLayout. 'other_shape'
+ // must be compatible with the ShapeLayout's shape, and 'other_shape' must
+ // have a layout (LayoutUtil::HasLayout).
+ tensorflow::Status CopyLayoutFromShape(const Shape& other_shape);
// Clears (Layout::Clear) all the Layouts stored in this object.
void Clear();
diff --git a/tensorflow/compiler/xla/shape_util.cc b/tensorflow/compiler/xla/shape_util.cc
index 74fa0b2f2e..9e3f06e527 100644
--- a/tensorflow/compiler/xla/shape_util.cc
+++ b/tensorflow/compiler/xla/shape_util.cc
@@ -694,9 +694,9 @@ StatusOr<Shape> ParseShapeStringInternal(tensorflow::StringPiece* s) {
return LayoutUtil::ValidateLayoutInShape(shape);
}
-/* static */ Shape ShapeUtil::ChangeElementType(const Shape& shape,
+/* static */ Shape ShapeUtil::ChangeElementType(const Shape& original,
PrimitiveType type) {
- Shape new_shape = shape;
+ Shape new_shape = original;
new_shape.set_element_type(type);
return new_shape;
}
diff --git a/tensorflow/compiler/xla/shape_util.h b/tensorflow/compiler/xla/shape_util.h
index 2ea1bd95cb..df5b450438 100644
--- a/tensorflow/compiler/xla/shape_util.h
+++ b/tensorflow/compiler/xla/shape_util.h
@@ -170,7 +170,7 @@ class ShapeUtil {
// As above, but for program shapes, returns a string for the form:
//
// (param_name: f32[42x12], ...) -> f32[24x42]
- static string HumanString(const ProgramShape& shape);
+ static string HumanString(const ProgramShape& program_shape);
// Parses a ShapeUtil::HumanString-format shape string back into a shape
// object.
diff --git a/tensorflow/compiler/xla/statusor_test.cc b/tensorflow/compiler/xla/statusor_test.cc
index 5fa2211ac6..f9d25945bc 100644
--- a/tensorflow/compiler/xla/statusor_test.cc
+++ b/tensorflow/compiler/xla/statusor_test.cc
@@ -32,26 +32,26 @@ namespace {
class Base1 {
public:
virtual ~Base1() {}
- int pad;
+ int pad_;
};
class Base2 {
public:
virtual ~Base2() {}
- int yetotherpad;
+ int yetotherpad_;
};
class Derived : public Base1, public Base2 {
public:
~Derived() override {}
- int evenmorepad;
+ int evenmorepad_;
};
class CopyNoAssign {
public:
- explicit CopyNoAssign(int value) : foo(value) {}
- CopyNoAssign(const CopyNoAssign& other) : foo(other.foo) {}
- int foo;
+ explicit CopyNoAssign(int value) : foo_(value) {}
+ CopyNoAssign(const CopyNoAssign& other) : foo_(other.foo_) {}
+ int foo_;
private:
const CopyNoAssign& operator=(const CopyNoAssign&);
@@ -253,7 +253,7 @@ TEST(StatusOr, TestCopyCtorNonAssignable) {
StatusOr<CopyNoAssign> original(value);
StatusOr<CopyNoAssign> copy(original);
EXPECT_EQ(copy.status(), original.status());
- EXPECT_EQ(original.ValueOrDie().foo, copy.ValueOrDie().foo);
+ EXPECT_EQ(original.ValueOrDie().foo_, copy.ValueOrDie().foo_);
}
TEST(StatusOr, TestCopyCtorStatusOKConverting) {
diff --git a/tensorflow/compiler/xla/tests/bfloat16_test.cc b/tensorflow/compiler/xla/tests/bfloat16_test.cc
index a1c53ef2aa..ac3f3f4c9d 100644
--- a/tensorflow/compiler/xla/tests/bfloat16_test.cc
+++ b/tensorflow/compiler/xla/tests/bfloat16_test.cc
@@ -61,6 +61,15 @@ XLA_TEST_F(Bfloat16Test, ScalarOperation) {
error_spec_);
}
+XLA_TEST_F(Bfloat16Test, LogOperation) {
+ ComputationBuilder builder(client_, TestName());
+ auto x = builder.ConstantR0<bfloat16>(static_cast<bfloat16>(4.0f));
+ builder.Log(x);
+
+ ComputeAndCompareR0<bfloat16>(&builder, static_cast<bfloat16>(1.387f), {},
+ error_spec_);
+}
+
XLA_TEST_F(Bfloat16Test, NegateScalarF16) {
ComputationBuilder builder(client_, TestName());
builder.Neg(builder.ConstantR0<bfloat16>(static_cast<bfloat16>(2.1f)));
diff --git a/tensorflow/compiler/xla/tests/client_library_test_base.h b/tensorflow/compiler/xla/tests/client_library_test_base.h
index 1d27880fb1..d8fe12a72d 100644
--- a/tensorflow/compiler/xla/tests/client_library_test_base.h
+++ b/tensorflow/compiler/xla/tests/client_library_test_base.h
@@ -194,7 +194,7 @@ class ClientLibraryTestBase : public ::testing::Test {
tensorflow::gtl::ArraySlice<GlobalData*> arguments);
void ComputeAndCompareTuple(
ComputationBuilder* builder, const Literal& expected,
- tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec abs_error);
+ tensorflow::gtl::ArraySlice<GlobalData*> arguments, ErrorSpec error);
// Convenience method for running a built computation and comparing the result
// with the HloEvaluator.
diff --git a/tensorflow/compiler/xla/tests/dot_operation_test.cc b/tensorflow/compiler/xla/tests/dot_operation_test.cc
index bfb04fd9f9..680d790b57 100644
--- a/tensorflow/compiler/xla/tests/dot_operation_test.cc
+++ b/tensorflow/compiler/xla/tests/dot_operation_test.cc
@@ -561,5 +561,25 @@ TEST_F(DotOperationTest, TransposeFolding) {
}
}
+XLA_TEST_F(DotOperationTest, DotGeneralUnimplemented) {
+ ComputationBuilder builder(client_, TestName());
+ auto lhs = builder.ConstantR3FromArray3D<float>(
+ {{{1.0, 2.0}, {3.0, 4.0}}, {{5.0, 6.0}, {7.0, 8.0}}});
+ auto rhs = builder.ConstantR3FromArray3D<float>(
+ {{{1.0, 0.0}, {0.0, 1.0}}, {{0.0, 1.0}, {1.0, 0.0}}});
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(2);
+ dot_dnums.add_rhs_contracting_dimensions(1);
+ dot_dnums.add_lhs_batch_dimensions(0);
+ dot_dnums.add_rhs_batch_dimensions(0);
+ builder.DotGeneral(lhs, rhs, dot_dnums);
+
+ auto status = Execute(&builder, {}).status();
+ EXPECT_FALSE(status.ok());
+ EXPECT_THAT(
+ status.error_message(),
+ ::testing::HasSubstr("Dot with batch dimensions not implemented."));
+}
+
} // namespace
} // namespace xla
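The DotGeneralUnimplemented test above pins down the semantics that batch-dimension support will eventually have to provide. As a rough sketch of those semantics — numpy stands in for XLA here and is only an illustration:

```python
import numpy as np

# Operands from the DotGeneralUnimplemented test above.
lhs = np.array([[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]])
rhs = np.array([[[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]]])

# lhs_batch=0 / rhs_batch=0 pair up the leading dimension; lhs dim 2
# contracts against rhs dim 1, i.e. one matmul per batch element.
result = np.einsum("bij,bjk->bik", lhs, rhs)
# result[0] == [[1., 2.], [3., 4.]]  (rhs[0] is the identity)
# result[1] == [[6., 5.], [8., 7.]]  (rhs[1] swaps the columns)
```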
diff --git a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
index 22d2b917a1..89fa6ed9f7 100644
--- a/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
+++ b/tensorflow/compiler/xla/tests/multioutput_fusion_test.cc
@@ -76,8 +76,11 @@ class MultiOutputFusionTest : public HloTestBase {
elem_shape2, HloOpcode::kAdd, broadcast, param1));
HloInstruction* sub = builder.AddInstruction(HloInstruction::CreateBinary(
elem_shape2, HloOpcode::kSubtract, param1, broadcast));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(1);
+ dot_dnums.add_rhs_contracting_dimensions(0);
HloInstruction* dot = builder.AddInstruction(
- HloInstruction::CreateBinary(elem_shape2, HloOpcode::kDot, sub, add2));
+ HloInstruction::CreateDot(elem_shape2, sub, add2, dot_dnums));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
@@ -133,8 +136,11 @@ class MultiOutputFusionTest : public HloTestBase {
HloInstruction* reshape =
builder.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(F32, {size, 1}), add));
- HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateBinary(
- ShapeUtil::MakeShape(F32, {1}), HloOpcode::kDot, sub, reshape));
+ DotDimensionNumbers dot_dnums;
+ dot_dnums.add_lhs_contracting_dimensions(0);
+ dot_dnums.add_rhs_contracting_dimensions(0);
+ HloInstruction* dot = builder.AddInstruction(HloInstruction::CreateDot(
+ ShapeUtil::MakeShape(F32, {1}), sub, reshape, dot_dnums));
auto computation = hlo_module->AddEntryComputation(builder.Build(dot));
if (manual_fusion) {
diff --git a/tensorflow/compiler/xla/tests/reduce_window_test.cc b/tensorflow/compiler/xla/tests/reduce_window_test.cc
index 0601a1466b..aa035f0ba5 100644
--- a/tensorflow/compiler/xla/tests/reduce_window_test.cc
+++ b/tensorflow/compiler/xla/tests/reduce_window_test.cc
@@ -962,68 +962,114 @@ struct R1ReduceWindowTestData {
int64 base_bounds[1];
int64 window_bounds[1];
int64 strides[1];
- Padding padding;
+ int64 pad_low[1];
+ int64 pad_high[1];
Reducer reducer;
} kR1TestCases[] = {
{/*base_bounds=*/{1}, /*window_bounds=*/{1},
/*strides=*/{1},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({1}, {1}, {1}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{3}, /*window_bounds=*/{3},
/*strides=*/{1},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({3}, {3}, {1}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{3}, /*window_bounds=*/{2},
/*strides=*/{1},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{5}, /*window_bounds=*/{1},
/*strides=*/{1},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
+ /*pad_low=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({5}, {1}, {1}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kMax},
{/*base_bounds=*/{16}, /*window_bounds=*/{4},
/*strides=*/{4},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kMax},
+ /*pad_low=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({16}, {4}, {4}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kMax},
{/*base_bounds=*/{16}, /*window_bounds=*/{4},
/*strides=*/{3},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].first},
+ /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
- {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{30},
+ {/*base_bounds=*/{128 * 2},
+ /*window_bounds=*/{30},
/*strides=*/{27},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
-
- {/*base_bounds=*/{128 * 17}, /*window_bounds=*/{7},
+ /*pad_low=*/
+ {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].first},
+ /*pad_high=*/
+ {xla::MakePadding({128 * 2}, {30}, {27}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{128 * 17},
+ /*window_bounds=*/{7},
/*strides=*/{64},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
-
- {/*base_bounds=*/{128 * 2}, /*window_bounds=*/{32},
+ /*pad_low=*/
+ {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].first},
+ /*pad_high=*/
+ {xla::MakePadding({128 * 17}, {7}, {64}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{128 * 2},
+ /*window_bounds=*/{32},
/*strides=*/{56},
- /*padding=*/Padding::kValid, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/
+ {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].first},
+ /*pad_high=*/
+ {xla::MakePadding({128 * 2}, {32}, {56}, Padding::kValid)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{3}, /*window_bounds=*/{2},
/*strides=*/{1},
- /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].first},
+ /*pad_high=*/{xla::MakePadding({3}, {2}, {1}, Padding::kSame)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{5}, /*window_bounds=*/{3},
/*strides=*/{2},
- /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].first},
+ /*pad_high=*/{xla::MakePadding({5}, {3}, {2}, Padding::kSame)[0].second},
+ /*reducer=*/Reducer::kAdd},
{/*base_bounds=*/{16}, /*window_bounds=*/{4},
/*strides=*/{3},
- /*padding=*/Padding::kSame, /*reducer=*/Reducer::kAdd},
+ /*pad_low=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].first},
+ /*pad_high=*/{xla::MakePadding({16}, {4}, {3}, Padding::kSame)[0].second},
+ /*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{5}, /*window_bounds=*/{5},
+ /*strides=*/{1},
+ /*pad_low=*/{0},
+ /*pad_high=*/{5},
+ /*reducer=*/Reducer::kAdd},
+
+ {/*base_bounds=*/{5}, /*window_bounds=*/{5},
+ /*strides=*/{1},
+ /*pad_low=*/{5},
+ /*pad_high=*/{0},
+ /*reducer=*/Reducer::kAdd},
};
string R1ReduceWindowTestDataToString(
const ::testing::TestParamInfo<R1ReduceWindowTestData>& data) {
string str = tensorflow::strings::StrCat(
- "base_bounds_",
- tensorflow::str_util::Join(data.param.base_bounds, "x"), //
+ "base_bounds_", tensorflow::str_util::Join(data.param.base_bounds, "x"),
"__window_bounds_",
- tensorflow::str_util::Join(data.param.window_bounds, "x"), //
- "__strides_", tensorflow::str_util::Join(data.param.strides, "x"), //
- "__padding_", data.param.padding == Padding::kSame ? "same" : "valid", //
- "__reducer_", data.param.reducer == kAdd ? "add" : "max");
+ tensorflow::str_util::Join(data.param.window_bounds, "x"), "__strides_",
+ tensorflow::str_util::Join(data.param.strides, "x"), "__pad_low_",
+ tensorflow::str_util::Join(data.param.pad_low, "x"), "__pad_high_",
+ tensorflow::str_util::Join(data.param.pad_high, "x"), "__reducer_",
+ data.param.reducer == kAdd ? "add" : "max");
return str;
}
@@ -1044,15 +1090,18 @@ TEST_P(R1ReduceWindowTest, DoIt) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<GlobalData> input_arg,
client_->TransferToServer(*input_literal));
+ std::vector<std::pair<int64, int64>> padding(1);
+ padding[0] = {param.pad_low[0], param.pad_high[0]};
+
auto computation = param.reducer == kAdd
? CreateScalarAddComputation(F32, &b)
: CreateScalarMaxComputation(F32, &b);
- b.ReduceWindow(/*operand=*/
- b.Parameter(0, input_literal->shape(), "p0"),
- /*init_value=*/b.ConstantR0<float>(kInitValue),
- /*computation=*/computation,
- /*window_dimensions=*/param.window_bounds,
- /*window_strides=*/param.strides, /*padding=*/param.padding);
+ b.ReduceWindowWithGeneralPadding(
+ /*operand=*/b.Parameter(0, input_literal->shape(), "p0"),
+ /*init_value=*/b.ConstantR0<float>(kInitValue),
+ /*computation=*/computation,
+ /*window_dimensions=*/param.window_bounds,
+ /*window_strides=*/param.strides, /*padding=*/padding);
auto reduce_func = param.reducer == kAdd
? +[](float a, float b) { return a + b; }
@@ -1062,7 +1111,8 @@ TEST_P(R1ReduceWindowTest, DoIt) {
/*init=*/kInitValue,
/*reduce_func=*/reduce_func,
/*window=*/param.window_bounds,
- /*stride=*/param.strides, /*padding=*/param.padding);
+ /*stride=*/param.strides,
+ /*padding=*/padding);
ComputeAndCompareR1<float>(&b, tensorflow::gtl::ArraySlice<float>(*expected),
{input_arg.get()}, ErrorSpec(1e-3, 1e-3));
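The explicit pad_low/pad_high values above come from xla::MakePadding. The following is a minimal Python sketch of the per-dimension SAME/VALID arithmetic that helper is expected to perform — an illustration of the convention, not a binding of the C++ function:

```python
import math

def make_padding_1d(input_size, window, stride, same):
    """Low/high padding for one dimension under the SAME/VALID convention."""
    if not same:
        # VALID: no padding; the window only visits fully in-bounds positions.
        return (0, 0)
    # SAME: pad so the output has ceil(input_size / stride) elements.
    output_size = math.ceil(input_size / stride)
    pad_total = max((output_size - 1) * stride + window - input_size, 0)
    pad_low = pad_total // 2  # any odd remainder goes on the high side
    return (pad_low, pad_total - pad_low)

# Mirrors two of the kSame cases in the table above.
print(make_padding_1d(3, 2, 1, same=True))   # (0, 1)
print(make_padding_1d(16, 4, 3, same=True))  # (1, 2)
```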
diff --git a/tensorflow/compiler/xla/tests/slice_test.cc b/tensorflow/compiler/xla/tests/slice_test.cc
index c21124750a..4db566f784 100644
--- a/tensorflow/compiler/xla/tests/slice_test.cc
+++ b/tensorflow/compiler/xla/tests/slice_test.cc
@@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/tests/literal_test_util.h"
#include "tensorflow/compiler/xla/tests/test_macros.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/stringprintf.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
@@ -211,6 +212,13 @@ class SliceR1Test : public ClientLibraryTestBase,
}
};
+string SliceR1TestDataToString(const ::testing::TestParamInfo<R1Spec>& data) {
+ const R1Spec& spec = data.param;
+ return ::tensorflow::strings::Printf("%lld_%lld_%lld_%lld", spec.input_dim0,
+ spec.slice_start, spec.slice_limit,
+ spec.slice_stride);
+}
+
XLA_TEST_P(SliceR1Test, DoIt_F32) { Run<float>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_F64) { Run<double>(GetParam()); }
@@ -223,30 +231,66 @@ XLA_TEST_P(SliceR1Test, DoIt_U64) { Run<uint64>(GetParam()); }
XLA_TEST_P(SliceR1Test, DoIt_S64) { Run<int64>(GetParam()); }
-INSTANTIATE_TEST_CASE_P( //
- SliceR1TestInstantiation, //
- SliceR1Test, //
- ::testing::Values( //
- R1Spec{10, 0, 0, 1}, //
- R1Spec{10, 7, 7, 1}, //
- R1Spec{10, 2, 4, 1}, //
- R1Spec{10, 2, 4, 2}, //
- R1Spec{10, 0, 10, 1}, //
- R1Spec{1024, 1024 - 4, 1024, 1}, //
- R1Spec{4096, 7, 7 + 1024, 1}, //
- R1Spec{10, 0, 10, 2}, //
- R1Spec{10, 0, 10, 3}, //
- R1Spec{10, 0, 10, 4}, //
- R1Spec{10, 0, 10, 5}, //
- R1Spec{10, 0, 10, 10}, //
- R1Spec{500, 200, 400, 7}, //
- R1Spec{4096, 1, 4095, 3}, //
- R1Spec{2047, 1024 - 24, 1024 + 160, 31}, //
- R1Spec{2047, 1, 2046, 3 * 128}, //
- R1Spec{4096, 1024 + 3, 4095, 500}, //
- R1Spec{8192, 0, 8192, 1024 * 3 + 400} //
- ) //
+// Tests for R1 slice ops.
+// The format for each testcase is {input size, start, limit, stride}.
+// clang-format off
+INSTANTIATE_TEST_CASE_P(
+ SliceR1TestInstantiation,
+ SliceR1Test,
+ ::testing::Values(
+ R1Spec{10, 0, 0, 1},
+ R1Spec{10, 7, 7, 1},
+ R1Spec{10, 0, 5, 1},
+ R1Spec{10, 3, 5, 1},
+ R1Spec{10, 0, 10, 1},
+ R1Spec{1024, 0, 5, 1},
+ R1Spec{1024, 3, 5, 1},
+ R1Spec{1024 + 17, 0, 5, 1},
+ R1Spec{1024 + 17, 3, 5, 1},
+ R1Spec{1024 + 17, 1024, 1024 + 6, 1},
+ R1Spec{1024 + 17, 1024 + 1, 1024 + 6, 1},
+ R1Spec{1024, 1024 - 4, 1024, 1},
+ R1Spec{4 * 1024, 7, 7 + 1024, 1},
+ R1Spec{4 * 1024, 0, 4 * 1024, 1},
+ R1Spec{4 * 1024, 1, 4 * 1024 - 1, 1},
+ R1Spec{4 * 1024, 1024, 3 * 1024, 1},
+ R1Spec{4 * 1024, 1024 + 1, 3 * 1024 - 1, 1},
+ R1Spec{16 * 1024, 0, 5, 1},
+ R1Spec{16 * 1024, 3, 5, 1},
+ R1Spec{16 * 1024 + 17, 0, 5, 1},
+ R1Spec{16 * 1024 + 17, 3, 5, 1},
+ R1Spec{16 * 1024 + 17, 16 * 1024, 16 * 1024 + 6, 1},
+ R1Spec{16 * 1024 + 17, 16 * 1024 + 1, 16 * 1024 + 6, 1},
+ R1Spec{16 * 1024, 4 * 1024 - 17, 8 * 1024 - 18, 1},
+ R1Spec{64 * 1024, 0, 64 * 1024, 1},
+ R1Spec{64 * 1024, 1, 64 * 1024 - 1, 1},
+ R1Spec{64 * 1024, 1024, 63 * 1024, 1},
+ R1Spec{64 * 1024, 1024 + 1, 63 * 1024 - 1, 1},
+ R1Spec{64 * 1024, 32 * 1024, 33 * 1024, 1},
+ R1Spec{64 * 1024, 32 * 1024 + 1, 33 * 1024 - 1, 1},
+ R1Spec{64 * 1024, 32 * 1024 - 17, 36 * 1024 - 18, 1},
+// TODO(b/69425338): This uses too much memory on GPU.
+#ifndef XLA_TEST_BACKEND_GPU
+ R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024, 12 * 1024 * 1024, 1},
+ R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 + 1, 12 * 1024 * 1024 - 1, 1},
+ R1Spec{16 * 1024 * 1024, 4 * 1024 * 1024 - 1, 12 * 1024 * 1024 + 1, 1},
+#endif
+ R1Spec{10, 2, 4, 2},
+ R1Spec{10, 0, 10, 2},
+ R1Spec{10, 0, 10, 3},
+ R1Spec{10, 0, 10, 4},
+ R1Spec{10, 0, 10, 5},
+ R1Spec{10, 0, 10, 10},
+ R1Spec{500, 200, 400, 7},
+ R1Spec{4096, 1, 4095, 3},
+ R1Spec{2047, 1024 - 24, 1024 + 160, 31},
+ R1Spec{2047, 1, 2046, 3 * 128},
+ R1Spec{4096, 1024 + 3, 4095, 500},
+ R1Spec{8192, 0, 8192, 1024 * 3 + 400}
+ ),
+ SliceR1TestDataToString
);
+// clang-format on
struct R2Spec {
int64 input_dim0;
diff --git a/tensorflow/compiler/xla/tests/while_test.cc b/tensorflow/compiler/xla/tests/while_test.cc
index 49f673f5f0..f3f10517e3 100644
--- a/tensorflow/compiler/xla/tests/while_test.cc
+++ b/tensorflow/compiler/xla/tests/while_test.cc
@@ -357,8 +357,7 @@ TEST_F(WhileTest, WhileWithVectorResultIntoTuple) {
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
-// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result.
-TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) {
+TEST_F(WhileTest, WhileWithPermutationAndTupleResult) {
std::vector<Shape> shape_elements = {
ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
@@ -411,8 +410,7 @@ TEST_F(WhileTest, DISABLED_WhileWithPermutationAndTupleResult) {
ComputeAndCompareTuple(&builder, *expected, {}, ErrorSpec(0.0001));
}
-// TODO(b/63003356): 11-06-2017: fails on all back-ends with incorrect result.
-TEST_F(WhileTest, DISABLED_WhileWithPermutationAndVectorResult) {
+TEST_F(WhileTest, WhileWithPermutationAndVectorResult) {
std::vector<Shape> shape_elements = {
ShapeUtil::MakeShape(S32, {}), ShapeUtil::MakeShape(F32, {3}),
ShapeUtil::MakeShape(F32, {3}), ShapeUtil::MakeShape(F32, {3})};
diff --git a/tensorflow/compiler/xla/util.cc b/tensorflow/compiler/xla/util.cc
index e595df3052..fe5d29a6b6 100644
--- a/tensorflow/compiler/xla/util.cc
+++ b/tensorflow/compiler/xla/util.cc
@@ -191,9 +191,9 @@ std::vector<int64> ComposePermutations(tensorflow::gtl::ArraySlice<int64> p1,
return output;
}
-bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> p) {
- for (int64 i = 0; i < p.size(); ++i) {
- if (p[i] != i) {
+bool IsIdentityPermutation(tensorflow::gtl::ArraySlice<int64> permutation) {
+ for (int64 i = 0; i < permutation.size(); ++i) {
+ if (permutation[i] != i) {
return false;
}
}
diff --git a/tensorflow/compiler/xla/xla_data.proto b/tensorflow/compiler/xla/xla_data.proto
index 2ba1a2d904..6800c3d7fa 100644
--- a/tensorflow/compiler/xla/xla_data.proto
+++ b/tensorflow/compiler/xla/xla_data.proto
@@ -498,6 +498,23 @@ message CustomCallRequest {
Shape shape = 4;
}
+message DotDimensionNumbers {
+ // The dimension numbers that represent the 'lhs' contracting dimensions.
+ repeated int64 lhs_contracting_dimensions = 1;
+ // The dimension numbers that represent the 'rhs' contracting dimensions.
+ repeated int64 rhs_contracting_dimensions = 2;
+ // The dimension numbers that represent the 'lhs' batch dimensions.
+ repeated int64 lhs_batch_dimensions = 3;
+ // The dimension numbers that represent the 'rhs' batch dimensions.
+ repeated int64 rhs_batch_dimensions = 4;
+};
+
+message DotRequest {
+ ComputationDataHandle lhs = 2;
+ ComputationDataHandle rhs = 3;
+ DotDimensionNumbers dimension_numbers = 4;
+}
+
message MapRequest {
repeated ComputationDataHandle operands = 2;
ComputationHandle to_apply = 3;
@@ -732,9 +749,6 @@ enum BinaryOperation {
BINOP_LT = 9;
BINOP_NE = 10;
- // Dot product, matrix multiply.
- BINOP_DOT = 12;
-
// Element-wise maximum.
BINOP_MAX = 14;
@@ -885,6 +899,7 @@ message OpRequest {
ConvolveRequest convolve_request = 8;
CrossReplicaSumRequest cross_replica_sum_request = 9;
CustomCallRequest custom_call_request = 10;
+ DotRequest dot_request = 43;
DynamicSliceRequest dynamic_slice_request = 11;
DynamicUpdateSliceRequest dynamic_update_slice_request = 12;
GetTupleElementRequest get_tuple_element_request = 13;
@@ -914,7 +929,7 @@ message OpRequest {
BatchNormInferenceRequest batch_norm_inference_request = 38;
FftRequest fft_request = 41;
ConvertRequest bitcast_convert_request = 42;
- // Next: 43
+ // Next: 44
}
}
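The new DotDimensionNumbers/DotRequest messages can be exercised directly through the generated protobuf bindings; a small sketch, assuming the generated module is importable as tensorflow.compiler.xla.xla_data_pb2:

```python
# Assumed import path for the generated bindings; adjust to your build layout.
from tensorflow.compiler.xla import xla_data_pb2

# Batch matmul: contract lhs dim 2 with rhs dim 1, batching over dim 0.
dnums = xla_data_pb2.DotDimensionNumbers()
dnums.lhs_contracting_dimensions.append(2)
dnums.rhs_contracting_dimensions.append(1)
dnums.lhs_batch_dimensions.append(0)
dnums.rhs_batch_dimensions.append(0)

# The lhs/rhs handles would be filled in by the client library.
request = xla_data_pb2.DotRequest(dimension_numbers=dnums)
```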
diff --git a/tensorflow/contrib/android/README.md b/tensorflow/contrib/android/README.md
index f49e5857fe..c7c128bf14 100644
--- a/tensorflow/contrib/android/README.md
+++ b/tensorflow/contrib/android/README.md
@@ -15,9 +15,9 @@ For prebuilt libraries, see the
page for a recent build.
The TensorFlow Inference Interface is also available as a
-[JCenter package](https://bintray.com/google/tensorflow/tensorflow-android) and
-can be included quite simply in your android project with a couple of lines in
-the project's `build.gradle` file:
+[JCenter package](https://bintray.com/google/tensorflow/tensorflow)
+(see the tensorflow-android directory) and can be included quite simply in your
+Android project with a couple of lines in the project's `build.gradle` file:
```
allprojects {
diff --git a/tensorflow/contrib/android/cmake/CMakeLists.txt b/tensorflow/contrib/android/cmake/CMakeLists.txt
index aba356d616..a115d1610e 100644
--- a/tensorflow/contrib/android/cmake/CMakeLists.txt
+++ b/tensorflow/contrib/android/cmake/CMakeLists.txt
@@ -34,6 +34,8 @@ add_library(lib_tf STATIC IMPORTED )
set_target_properties(lib_tf PROPERTIES IMPORTED_LOCATION
${PREBUILT_DIR}/lib/libtensorflow-core.a)
# Change to compile flags should be replicated into bazel build file
+# TODO: Consider options other than -O2 for binary size.
+# e.g. -Os for gcc, and -Oz for clang.
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DIS_SLIM_BUILD \
-std=c++11 -fno-rtti -fno-exceptions \
-O2 -Wno-narrowing -fomit-frame-pointer \
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
index 6ed177e001..9e32bee505 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler.h
@@ -208,6 +208,8 @@ class ASBSQueue : public BatchScheduler<TaskType> {
// place any more tasks in this batch.
void ReleaseBatch(const ASBSBatch<TaskType>* batch);
+ size_t max_task_size() const override { return options_.max_batch_size; }
+
private:
std::shared_ptr<AdaptiveSharedBatchScheduler<TaskType>> scheduler_;
const QueueOptions options_;
diff --git a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
index a07cd6d834..e2aac54eeb 100644
--- a/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/adaptive_shared_batch_scheduler_test.cc
@@ -186,6 +186,7 @@ TEST(AdaptiveSharedBatchSchedulerTest, ObeysQueueOptions) {
queue_options.max_enqueued_batches = 2;
TF_ASSERT_OK(
scheduler->AddQueue(queue_options, queue_0_callback, &queue_0));
+ EXPECT_EQ(10, queue_0->max_task_size());
queue_options.max_batch_size = 0;
// Queue must have max_batch_size > 0.
EXPECT_FALSE(
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler.h b/tensorflow/contrib/batching/basic_batch_scheduler.h
index 9d3805fbaf..91065db249 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler.h
+++ b/tensorflow/contrib/batching/basic_batch_scheduler.h
@@ -192,6 +192,10 @@ class BasicBatchScheduler : public BatchScheduler<TaskType> {
size_t NumEnqueuedTasks() const override;
size_t SchedulingCapacity() const override;
+ size_t max_task_size() const override {
+ return shared_scheduler_queue_->max_task_size();
+ }
+
private:
explicit BasicBatchScheduler(
std::unique_ptr<BatchScheduler<TaskType>> shared_scheduler_queue);
diff --git a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
index e020301795..187823151c 100644
--- a/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/basic_batch_scheduler_test.cc
@@ -73,6 +73,7 @@ TEST(BasicBatchSchedulerTest, Basic) {
std::unique_ptr<BasicBatchScheduler<FakeTask>> scheduler;
TF_ASSERT_OK(
BasicBatchScheduler<FakeTask>::Create(options, callback, &scheduler));
+ EXPECT_EQ(10, scheduler->max_task_size());
EXPECT_EQ(0, scheduler->NumEnqueuedTasks());
EXPECT_EQ(3 * 10, scheduler->SchedulingCapacity());
TF_ASSERT_OK(ScheduleTask(3, scheduler.get()));
diff --git a/tensorflow/contrib/batching/batch_scheduler.h b/tensorflow/contrib/batching/batch_scheduler.h
index a5072f439a..e18cf6c350 100644
--- a/tensorflow/contrib/batching/batch_scheduler.h
+++ b/tensorflow/contrib/batching/batch_scheduler.h
@@ -178,6 +178,10 @@ class BatchScheduler {
// This method is useful for monitoring, or for guaranteeing a future slot in
// the schedule (but being mindful about the caveats listed above).
virtual size_t SchedulingCapacity() const = 0;
+
+ // Returns the maximum allowed size of tasks submitted to the scheduler. (This
+ // is typically equal to a configured maximum batch size.)
+ virtual size_t max_task_size() const = 0;
};
//////////
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler.h b/tensorflow/contrib/batching/shared_batch_scheduler.h
index 41a3f99137..1d2158062e 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler.h
+++ b/tensorflow/contrib/batching/shared_batch_scheduler.h
@@ -248,6 +248,9 @@ class Queue {
// BatchScheduler::SchedulingCapacity().
size_t SchedulingCapacity() const;
+ // Returns the maximum allowed size of tasks submitted to the queue.
+ size_t max_task_size() const { return options_.max_batch_size; }
+
// Called by a thread that is ready to process a batch, to request one from
// this queue. Either returns a batch that is ready to be processed, or
// nullptr if the queue declines to schedule a batch at this time. If it
@@ -338,6 +341,8 @@ class QueueHandle : public BatchScheduler<TaskType> {
size_t NumEnqueuedTasks() const override;
size_t SchedulingCapacity() const override;
+ size_t max_task_size() const override { return queue_->max_task_size(); }
+
private:
// The scheduler that owns 'queue_'.
std::shared_ptr<SharedBatchScheduler<TaskType>> scheduler_;
diff --git a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
index 3e924ae5f1..3ac79a8fdc 100644
--- a/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
+++ b/tensorflow/contrib/batching/shared_batch_scheduler_test.cc
@@ -429,6 +429,7 @@ TEST(SharedBatchSchedulerTest, ConstMethods) {
queue_options.max_enqueued_batches = max_enqueued_batches;
std::unique_ptr<BatchScheduler<FakeTask>> queue;
TF_ASSERT_OK(scheduler->AddQueue(queue_options, callback, &queue));
+ EXPECT_EQ(2, queue->max_task_size());
EXPECT_EQ(0, queue->NumEnqueuedTasks());
EXPECT_EQ(max_enqueued_batches * 2, queue->SchedulingCapacity());
diff --git a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc
index 0d46565a19..ccee9530b6 100644
--- a/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc
+++ b/tensorflow/contrib/boosted_trees/lib/utils/sparse_column_iterable.cc
@@ -97,7 +97,7 @@ class IndicesRowIterator
}
bool operator<(const IndicesRowIterator& other) const {
- return (row_idx_ < other.row_idx_);
+ return (row_idx_ < other.row_idx_);
}
bool operator==(const IndicesRowIterator& other) const {
diff --git a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
index 7e8e15e7d8..294e04002a 100644
--- a/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
+++ b/tensorflow/contrib/boosted_trees/python/ops/quantile_ops.py
@@ -45,6 +45,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
init_stamp_token,
epsilon,
num_quantiles,
+ max_elements=None,
name=None,
container=None):
"""Creates a QuantileAccumulator object.
@@ -53,6 +54,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
init_stamp_token: The initial value for the stamp token.
epsilon: Error bound on the quantile computation.
num_quantiles: Number of quantiles to produce from the final summary.
+      max_elements: Maximum number of elements that can be added to the
+        accumulator.
name: the name to save the accumulator under.
container: An optional `string`. Defaults to `""`
"""
@@ -67,6 +69,7 @@ class QuantileAccumulator(saver.BaseSaverBuilder.SaveableObject):
self._quantile_accumulator_handle,
init_stamp_token,
epsilon=epsilon,
+ max_elements=max_elements,
num_quantiles=num_quantiles)
is_initialized_op = gen_quantile_ops.quantile_accumulator_is_initialized(
self._quantile_accumulator_handle)
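A hedged usage sketch for the new max_elements argument, using only the constructor parameters visible above; the concrete values are illustrative:

```python
from tensorflow.contrib.boosted_trees.python.ops import quantile_ops

# Cap the accumulator at one million elements; epsilon bounds the error of
# the streaming quantile summary.
accumulator = quantile_ops.QuantileAccumulator(
    init_stamp_token=0,
    epsilon=0.01,
    num_quantiles=100,
    max_elements=1000000,
    name="example_accumulator")
```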
diff --git a/tensorflow/contrib/cmake/external/nsync.cmake b/tensorflow/contrib/cmake/external/nsync.cmake
index 155c91cb97..0508006047 100644
--- a/tensorflow/contrib/cmake/external/nsync.cmake
+++ b/tensorflow/contrib/cmake/external/nsync.cmake
@@ -16,7 +16,7 @@ include (ExternalProject)
set(nsync_INCLUDE_DIR ${CMAKE_CURRENT_BINARY_DIR}/external/nsync/public)
set(nsync_URL https://github.com/google/nsync)
-set(nsync_TAG 93815892dddafe9146a5f7e7042281d59d0f4323)
+set(nsync_TAG 8502189abfa44c249c01c2cad64e6ed660a9a668)
set(nsync_BUILD ${CMAKE_CURRENT_BINARY_DIR}/nsync/src/nsync)
set(nsync_INSTALL ${CMAKE_CURRENT_BINARY_DIR}/nsync/install)
diff --git a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
index 594c2492d4..aaae18a313 100644
--- a/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
+++ b/tensorflow/contrib/cmake/patches/nsync/CMakeLists.txt
@@ -158,12 +158,21 @@ if (NOT "${NSYNC_LANGUAGE}X" STREQUAL "c++11X")
elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "NetBSDX")
include_directories ("${PROJECT_SOURCE_DIR}/platform/netbsd")
set (NSYNC_POSIX ON)
+ set (NSYNC_OS_EXTRA_SRC
+ "platform/posix/src/nsync_semaphore_mutex.c"
+ )
elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "FreeBSDX")
include_directories ("${PROJECT_SOURCE_DIR}/platform/freebsd")
set (NSYNC_POSIX ON)
+ set (NSYNC_OS_EXTRA_SRC
+ "platform/posix/src/nsync_semaphore_mutex.c"
+ )
elseif ("${CMAKE_SYSTEM_NAME}X" STREQUAL "OpenBSDX")
include_directories ("${PROJECT_SOURCE_DIR}/platform/openbsd")
set (NSYNC_POSIX ON)
+ set (NSYNC_OS_EXTRA_SRC
+ "platform/posix/src/nsync_semaphore_mutex.c"
+ )
endif ()
endif ()
diff --git a/tensorflow/contrib/cmake/tf_core_cpu.cmake b/tensorflow/contrib/cmake/tf_core_cpu.cmake
index 5c01ca382f..e4213ea2a4 100644
--- a/tensorflow/contrib/cmake/tf_core_cpu.cmake
+++ b/tensorflow/contrib/cmake/tf_core_cpu.cmake
@@ -63,7 +63,7 @@ if (tensorflow_ENABLE_GPU)
file(GLOB_RECURSE tf_core_gpu_srcs
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu/*.cc"
"${tensorflow_source_dir}/tensorflow/core/platform/default/gpu/cupti_wrapper.cc"
- "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc"
+ "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc"
"${tensorflow_source_dir}/tensorflow/core/common_runtime/gpu_device_factory.cc"
"${tensorflow_source_dir}/tensorflow/core/grappler/devices.h"
"${tensorflow_source_dir}/tensorflow/core/grappler/devices.cc"
diff --git a/tensorflow/contrib/cmake/tf_core_framework.cmake b/tensorflow/contrib/cmake/tf_core_framework.cmake
index c607546f4a..5ec1a8d04f 100644
--- a/tensorflow/contrib/cmake/tf_core_framework.cmake
+++ b/tensorflow/contrib/cmake/tf_core_framework.cmake
@@ -211,7 +211,7 @@ if (NOT tensorflow_ENABLE_GPU)
list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_gpu_srcs})
else()
file(GLOB tf_core_platform_srcs_exclude
- "${tensorflow_source_dir}/tensorflow/core/platform/default/gpu_tracer.cc")
+ "${tensorflow_source_dir}/tensorflow/core/platform/default/device_tracer.cc")
list(REMOVE_ITEM tf_core_platform_srcs ${tf_core_platform_srcs_exclude})
endif()
diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake
index 0128946e45..819b6213ea 100755
--- a/tensorflow/contrib/cmake/tf_python.cmake
+++ b/tensorflow/contrib/cmake/tf_python.cmake
@@ -899,6 +899,8 @@ set (pywrap_tensorflow_internal_src
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_func.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/py_seq_tensor.cc"
+ "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.h"
+ "${tensorflow_source_dir}/tensorflow/python/lib/core/py_util.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.h"
"${tensorflow_source_dir}/tensorflow/python/lib/core/safe_ptr.cc"
"${tensorflow_source_dir}/tensorflow/python/lib/io/py_record_reader.h"
diff --git a/tensorflow/contrib/cmake/tf_tests.cmake b/tensorflow/contrib/cmake/tf_tests.cmake
index 18b71d1f9a..2e3ee2c96b 100644
--- a/tensorflow/contrib/cmake/tf_tests.cmake
+++ b/tensorflow/contrib/cmake/tf_tests.cmake
@@ -225,6 +225,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py"
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py"
+ "${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py"
# Float division by zero
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
# Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces.
diff --git a/tensorflow/contrib/copy_graph/python/util/copy_elements.py b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
index d060eda0a7..bae66ffd42 100644
--- a/tensorflow/contrib/copy_graph/python/util/copy_elements.py
+++ b/tensorflow/contrib/copy_graph/python/util/copy_elements.py
@@ -225,6 +225,7 @@ def copy_op_to_graph(org_instance, to_graph, variables,
new_original_op,
op_def)
#Use Graph's hidden methods to add the op
+ to_graph._add_op(new_op) # pylint: disable=protected-access
to_graph._record_op_seen_by_control_dependencies(new_op)
for device_function in reversed(to_graph._device_function_stack):
new_op._set_device(device_function(new_op))
diff --git a/tensorflow/contrib/data/BUILD b/tensorflow/contrib/data/BUILD
index f7d8a084d9..3b1c33063f 100644
--- a/tensorflow/contrib/data/BUILD
+++ b/tensorflow/contrib/data/BUILD
@@ -18,6 +18,7 @@ py_library(
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
"//tensorflow/contrib/data/python/ops:readers",
+ "//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/contrib/data/python/ops:transformation_ops",
"//tensorflow/python:util",
"//tensorflow/python/data/ops:iterator_ops",
diff --git a/tensorflow/contrib/data/__init__.py b/tensorflow/contrib/data/__init__.py
index 7c6244f22b..c9ad091bd4 100644
--- a/tensorflow/contrib/data/__init__.py
+++ b/tensorflow/contrib/data/__init__.py
@@ -66,6 +66,7 @@ from tensorflow.contrib.data.python.ops.readers import TextLineDataset
from tensorflow.contrib.data.python.ops.readers import TFRecordDataset
from tensorflow.contrib.data.python.ops.resampling import rejection_resample
from tensorflow.contrib.data.python.ops.scan_ops import scan
+from tensorflow.contrib.data.python.ops.shuffle_ops import shuffle_and_repeat
from tensorflow.python.data.ops.iterator_ops import Iterator
# pylint: enable=unused-import
diff --git a/tensorflow/contrib/data/python/kernel_tests/BUILD b/tensorflow/contrib/data/python/kernel_tests/BUILD
index 1d4817fa26..4112de31c1 100644
--- a/tensorflow/contrib/data/python/kernel_tests/BUILD
+++ b/tensorflow/contrib/data/python/kernel_tests/BUILD
@@ -277,7 +277,7 @@ py_test(
py_test(
name = "map_dataset_op_test",
- size = "small",
+ size = "medium",
srcs = ["map_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
@@ -419,12 +419,14 @@ py_test(
py_test(
name = "shuffle_dataset_op_test",
- size = "small",
+ size = "medium",
srcs = ["shuffle_dataset_op_test.py"],
srcs_version = "PY2AND3",
deps = [
+ ":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
"//tensorflow/contrib/data/python/ops:iterator_ops",
+ "//tensorflow/contrib/data/python/ops:shuffle_ops",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
diff --git a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
index 6b5b53cc0f..ba1be0690f 100644
--- a/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
+++ b/tensorflow/contrib/data/python/kernel_tests/shuffle_dataset_op_test.py
@@ -22,8 +22,10 @@ import os
import numpy as np
+from tensorflow.contrib.data.python.kernel_tests import dataset_serialization_test_base
from tensorflow.contrib.data.python.ops import dataset_ops as contrib_dataset_ops
from tensorflow.contrib.data.python.ops import iterator_ops as contrib_iterator_ops
+from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.framework import constant_op
@@ -156,6 +158,13 @@ class ShuffleDatasetTest(test.TestCase):
for i in range(5):
self.assertEqual(10, counts[i])
+ def testSeedNoneSeed2NonNone(self):
+ with self.assertRaises(ValueError):
+ dataset_ops.ShuffleDataset(dataset_ops.Dataset.range(5),
+ buffer_size=1,
+ seed=None,
+ seed2=10)
+
class ShuffleDatasetSerializationTest(test.TestCase):
@@ -474,5 +483,76 @@ class ShuffleDatasetSerializationTest(test.TestCase):
self.assertEqual(expected_outputs_sorted, sorted(actual))
+class ShuffleAndRepeatTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, seed, count=5):
+ return dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=count, seed=seed))
+
+ def testCorrectOutput(self):
+ output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ self.assertSequenceEqual(
+ sorted(output), sorted(
+ np.array([range(20) for _ in range(5)]).flatten()))
+ for i in range(5):
+ self.assertSequenceEqual(sorted(output[i * 20:(i + 1) * 20]), range(20))
+
+ def testReshuffling(self):
+ # Check that the output orders of different epochs are indeed different.
+ output = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ for i in range(4):
+ epoch1 = output[i * 20:(i + 1) * 20]
+ epoch2 = output[(i + 1) * 20:(i + 2) * 20]
+ self.assertNotEqual(epoch1, epoch2)
+
+ def testSameOrderForSameSeeds(self):
+ output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ output2 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ self.assertEqual(output1, output2)
+
+ def testDifferentOrderForDifferentSeeds(self):
+ output1 = self.gen_outputs(lambda: self._build_ds(10), [], 100)
+ output2 = self.gen_outputs(lambda: self._build_ds(20), [], 100)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountNone(self):
+ output1 = self.gen_outputs(
+ lambda: self._build_ds(10, count=None), [], 100, verify_exhausted=False)
+ output2 = self.gen_outputs(
+ lambda: self._build_ds(20, count=None), [], 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testCountMinusOne(self):
+ output1 = self.gen_outputs(
+ lambda: self._build_ds(10, count=-1), [], 100, verify_exhausted=False)
+ output2 = self.gen_outputs(
+ lambda: self._build_ds(20, count=-1), [], 100, verify_exhausted=False)
+ self.assertNotEqual(output1, output2)
+ self.assertEqual(sorted(output1), sorted(output2))
+
+ def testInfiniteOutputs(self):
+ # Asserting that the iterator is exhausted after producing 100 items should
+ # fail.
+ with self.assertRaises(AssertionError):
+ self.gen_outputs(lambda: self._build_ds(10, count=None), [], 100)
+ with self.assertRaises(AssertionError):
+ self.gen_outputs(lambda: self._build_ds(10, count=-1), [], 100)
+
+
+class ShuffleAndRepeatSerializationTest(
+ dataset_serialization_test_base.DatasetSerializationTestBase):
+
+ def _build_ds(self, seed):
+ return dataset_ops.Dataset.range(20).apply(
+ shuffle_ops.shuffle_and_repeat(buffer_size=5, count=5, seed=seed))
+
+ def testCore(self):
+ self.run_core_tests(lambda: self._build_ds(10), lambda: self._build_ds(20),
+ 100)
+
+
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/contrib/data/python/ops/BUILD b/tensorflow/contrib/data/python/ops/BUILD
index 25ed58cdf5..1f35ee056b 100644
--- a/tensorflow/contrib/data/python/ops/BUILD
+++ b/tensorflow/contrib/data/python/ops/BUILD
@@ -41,6 +41,25 @@ py_library(
)
py_library(
+ name = "random_ops",
+ srcs = [
+ "random_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ "//tensorflow/python:constant_op",
+ "//tensorflow/python:dataset_ops_gen",
+ "//tensorflow/python:dtypes",
+ "//tensorflow/python:framework_ops",
+ "//tensorflow/python:random_seed",
+ "//tensorflow/python:tensor_shape",
+ "//tensorflow/python/data/ops:dataset_ops",
+ "//tensorflow/python/data/util:nest",
+ "//tensorflow/python/data/util:sparse",
+ ],
+)
+
+py_library(
name = "readers",
srcs = [
"readers.py",
@@ -63,6 +82,19 @@ py_library(
)
py_library(
+ name = "shuffle_ops",
+ srcs = [
+ "shuffle_ops.py",
+ ],
+ srcs_version = "PY2AND3",
+ deps = [
+ ":random_ops",
+ ":transformation_ops",
+ "//tensorflow/python/data/ops:dataset_ops",
+ ],
+)
+
+py_library(
name = "transformation_ops",
srcs = [
"batching.py",
diff --git a/tensorflow/contrib/data/python/ops/random_ops.py b/tensorflow/contrib/data/python/ops/random_ops.py
new file mode 100644
index 0000000000..7d727165fe
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/random_ops.py
@@ -0,0 +1,67 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Datasets for random number generators."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.python.data.ops import dataset_ops
+from tensorflow.python.data.util import nest
+from tensorflow.python.data.util import sparse
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.ops import gen_dataset_ops
+
+
+class RandomDataset(dataset_ops.Dataset):
+ """A `Dataset` of pseudorandom values."""
+
+ def __init__(self, seed=None):
+ """A `Dataset` of pseudorandom values."""
+ super(RandomDataset, self).__init__()
+ seed, seed2 = random_seed.get_seed(seed)
+ if seed is None:
+ self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed")
+ else:
+ self._seed = ops.convert_to_tensor(seed, dtype=dtypes.int64, name="seed")
+ if seed2 is None:
+ self._seed2 = constant_op.constant(0, dtype=dtypes.int64, name="seed2")
+ else:
+ self._seed2 = ops.convert_to_tensor(
+ seed2, dtype=dtypes.int64, name="seed2")
+
+ def _as_variant_tensor(self):
+ return gen_dataset_ops.random_dataset(
+ seed=self._seed,
+ seed2=self._seed2,
+ output_shapes=nest.flatten(
+ sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
+ output_types=nest.flatten(
+ sparse.as_dense_types(self.output_types, self.output_classes)))
+
+ @property
+ def output_classes(self):
+ return ops.Tensor
+
+ @property
+ def output_shapes(self):
+ return tensor_shape.scalar()
+
+ @property
+ def output_types(self):
+ return dtypes.int64
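RandomDataset yields an endless stream of scalar int64 values; shuffle_and_repeat (added below) consumes them in pairs as per-epoch shuffle seeds. A minimal sketch of standalone use, assuming the standard Dataset iterator API:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import random_ops

# Three pseudorandom int64 scalars from a seeded stream.
dataset = random_ops.RandomDataset(seed=42).take(3)
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(next_element))
```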
diff --git a/tensorflow/contrib/data/python/ops/shuffle_ops.py b/tensorflow/contrib/data/python/ops/shuffle_ops.py
new file mode 100644
index 0000000000..460732d65e
--- /dev/null
+++ b/tensorflow/contrib/data/python/ops/shuffle_ops.py
@@ -0,0 +1,69 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Experimental shuffle ops."""
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from tensorflow.contrib.data.python.ops import batching
+from tensorflow.contrib.data.python.ops import random_ops
+from tensorflow.python.data.ops import dataset_ops
+
+
+def shuffle_and_repeat(buffer_size, count=None, seed=None):
+ """Shuffles and repeats a Dataset returning a new permutation for each epoch.
+
+ `dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size, count))`
+
+ is equivalent to
+
+ `dataset.shuffle(buffer_size, reshuffle_each_iteration=True).repeat(count)`
+
+ The difference is that the latter dataset is not serializable. So,
+  if you need to checkpoint an input pipeline with reshuffling, you must use
+ this implementation.
+
+ Args:
+ buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+      maximum number of elements that will be buffered when prefetching.
+ count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ number of times the dataset should be repeated. The default behavior
+      (if `count` is `None` or `-1`) is for the dataset to be repeated
+ indefinitely.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ @{tf.set_random_seed} for behavior.
+
+ Returns:
+ A `Dataset` transformation function, which can be passed to
+ @{tf.contrib.data.Dataset.apply}.
+ """
+ def _apply_fn(dataset): # pylint: disable=missing-docstring
+ random_ds = random_ops.RandomDataset(seed).apply(
+ batching.batch_and_drop_remainder(2))
+    if count is not None and count != -1:
+ random_ds = random_ds.take(count)
+
+ def map_fn(seeds):
+ return dataset_ops.ShuffleDataset(
+ input_dataset=dataset,
+ buffer_size=buffer_size,
+ seed=seeds[0],
+ reshuffle_each_iteration=False,
+ seed2=seeds[1])
+
+ return random_ds.flat_map(map_fn)
+
+ return _apply_fn
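A short usage sketch of the new transformation, mirroring the equivalence stated in its docstring; the range, buffer size, and seed are arbitrary:

```python
import tensorflow as tf
from tensorflow.contrib.data.python.ops import shuffle_ops
from tensorflow.python.data.ops import dataset_ops

# Two epochs over range(5), each independently shuffled; unlike
# shuffle().repeat(), the resulting pipeline stays serializable.
dataset = dataset_ops.Dataset.range(5).apply(
    shuffle_ops.shuffle_and_repeat(buffer_size=5, count=2, seed=10))
next_element = dataset.make_one_shot_iterator().get_next()

with tf.Session() as sess:
    for _ in range(10):
        print(sess.run(next_element))
```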
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
index 38b3a23c2d..49451446b5 100644
--- a/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
+++ b/tensorflow/contrib/distributions/python/kernel_tests/bijectors/reshape_test.py
@@ -28,8 +28,19 @@ from tensorflow.python.ops.distributions.bijector_test_util import assert_biject
from tensorflow.python.platform import test
-class ReshapeBijectorTest(test.TestCase):
- """Tests correctness of the reshape transformation."""
+class _ReshapeBijectorTest(object):
+ """Base class for testing the reshape transformation.
+
+ Methods defined in this class call a method self.build_shapes() that
+ is implemented by subclasses defined below, returning respectively
+ ReshapeBijectorTestStatic: static shapes,
+ ReshapeBijectorTestDynamic: shape placeholders of known ndims, and
+ ReshapeBijectorTestDynamicNdims: shape placeholders of unspecified ndims,
+ so that each test in this base class is automatically run over all
+ three cases. The subclasses also implement assertRaisesError to test
+ for either Python exceptions (in the case of static shapes) or
+ TensorFlow op errors (dynamic shapes).
+ """
def setUp(self):
self._rng = np.random.RandomState(42)
@@ -40,9 +51,10 @@ class ReshapeBijectorTest(test.TestCase):
expected_y = np.reshape(expected_x, [4, 6])
with self.test_session() as sess:
+ shape_in, shape_out, feed_dict = self.build_shapes([3, 2], [6,])
bijector = Reshape(
- event_shape_out=[6,],
- event_shape_in=[3, 2],
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
validate_args=True)
(x_,
y_,
@@ -52,66 +64,23 @@ class ReshapeBijectorTest(test.TestCase):
bijector.forward(expected_x),
bijector.forward_log_det_jacobian(expected_x),
bijector.inverse_log_det_jacobian(expected_y),
- ))
+ ), feed_dict=feed_dict)
self.assertEqual("reshape", bijector.name)
self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
self.assertAllClose(0., fldj_, rtol=1e-6, atol=0)
self.assertAllClose(0., ildj_, rtol=1e-6, atol=0)
- def testEventShapeDynamicNdims(self):
- """Check forward/inverse shape methods with dynamic ndims."""
-
- shape_in = tensor_shape.TensorShape([6,])
- shape_in_ph = array_ops.placeholder(dtype=dtypes.int32)
-
- shape_out = tensor_shape.TensorShape([2, 3])
- shape_out_ph = array_ops.placeholder(dtype=dtypes.int32)
-
- bijector = Reshape(
- event_shape_out=shape_out_ph,
- event_shape_in=shape_in_ph, validate_args=True)
-
- # using the _tensor methods, we should always get a fully-specified
- # result since these are evaluated at graph runtime.
- with self.test_session() as sess:
- (shape_out_,
- shape_in_) = sess.run((
- bijector.forward_event_shape_tensor(shape_in),
- bijector.inverse_event_shape_tensor(shape_out),
- ), feed_dict={
- shape_in_ph: shape_in,
- shape_out_ph: shape_out,
- })
- self.assertAllEqual(shape_out, shape_out_)
- self.assertAllEqual(shape_in, shape_in_)
-
- def testEventShapeDynamic(self):
- """Check shape methods with static ndims but dynamic shape."""
-
- shape_in = tensor_shape.TensorShape([6,])
- shape_in_partial = tensor_shape.TensorShape([None,])
- shape_in_ph = array_ops.placeholder(
- shape=[1,], dtype=dtypes.int32)
-
- shape_out = tensor_shape.TensorShape([2, 3])
- shape_out_partial = tensor_shape.TensorShape([None, None])
- shape_out_ph = array_ops.placeholder(
- shape=[2,], dtype=dtypes.int32)
+ def testEventShapeTensor(self):
+ """Test event_shape_tensor methods when even ndims may be dynamic."""
+ shape_in_static = [2, 3]
+ shape_out_static = [6,]
+ shape_in, shape_out, feed_dict = self.build_shapes(shape_in_static,
+ shape_out_static)
bijector = Reshape(
- event_shape_out=shape_out_ph,
- event_shape_in=shape_in_ph,
- validate_args=True)
-
- # if event shapes are not statically available, should
- # return partially-specified TensorShapes.
- self.assertAllEqual(
- bijector.forward_event_shape(shape_in).as_list(),
- shape_out_partial.as_list())
- self.assertAllEqual(
- bijector.inverse_event_shape(shape_out).as_list(),
- shape_in_partial.as_list())
+ event_shape_out=shape_out,
+ event_shape_in=shape_in, validate_args=True)
# using the _tensor methods, we should always get a fully-specified
# result since these are evaluated at graph runtime.
@@ -120,42 +89,9 @@ class ReshapeBijectorTest(test.TestCase):
shape_in_) = sess.run((
bijector.forward_event_shape_tensor(shape_in),
bijector.inverse_event_shape_tensor(shape_out),
- ), feed_dict={
- shape_in_ph: shape_in,
- shape_out_ph: shape_out,
- })
- self.assertAllEqual(shape_out, shape_out_)
- self.assertAllEqual(shape_in, shape_in_)
-
- def testEventShapeStatic(self):
- """Check shape methods when shape is statically known."""
-
- shape_in = tensor_shape.TensorShape([6,])
- shape_out = tensor_shape.TensorShape([2, 3])
-
- bijector_static = Reshape(
- event_shape_out=shape_out,
- event_shape_in=shape_in,
- validate_args=True)
-
- # test that forward_ and inverse_event_shape do sensible things
- # when shapes are statically known.
- self.assertEqual(
- bijector_static.forward_event_shape(shape_in),
- shape_out)
- self.assertEqual(
- bijector_static.inverse_event_shape(shape_out),
- shape_in)
-
- with self.test_session() as sess:
- (shape_out_static_,
- shape_in_static_,
- ) = sess.run((
- bijector_static.forward_event_shape_tensor(shape_in),
- bijector_static.inverse_event_shape_tensor(shape_out),
- ))
- self.assertAllEqual(shape_out, shape_out_static_)
- self.assertAllEqual(shape_in, shape_in_static_)
+ ), feed_dict=feed_dict)
+ self.assertAllEqual(shape_out_static, shape_out_)
+ self.assertAllEqual(shape_in_static, shape_in_)
def testScalarReshape(self):
"""Test reshaping to and from a scalar shape ()."""
@@ -166,11 +102,11 @@ class ReshapeBijectorTest(test.TestCase):
expected_x_scalar = np.random.randn(1,)
expected_y_scalar = expected_x_scalar[0]
+ shape_in, shape_out, feed_dict = self.build_shapes([], [1,])
with self.test_session() as sess:
bijector = Reshape(
- event_shape_out=[],
- event_shape_in=[1,], validate_args=True)
-
+ event_shape_out=shape_in,
+ event_shape_in=shape_out, validate_args=True)
(x_,
y_,
x_scalar_,
@@ -180,53 +116,178 @@ class ReshapeBijectorTest(test.TestCase):
bijector.forward(expected_x),
bijector.inverse(expected_y_scalar),
bijector.forward(expected_x_scalar),
- ))
+ ), feed_dict=feed_dict)
self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
self.assertAllClose(expected_y_scalar, y_scalar_, rtol=1e-6, atol=0)
self.assertAllClose(expected_x_scalar, x_scalar_, rtol=1e-6, atol=0)
- def testRaisesOpError(self):
- x1 = np.random.randn(4, 2, 3)
- x2 = np.random.randn(4, 3, 2)
- x3 = np.random.randn(4, 5, 1, 1)
+ def testMultipleUnspecifiedDimensionsOpError(self):
with self.test_session() as sess:
- shape_in_ph = array_ops.placeholder(shape=[2,], dtype=dtypes.int32)
- shape_out_ph = array_ops.placeholder(shape=[3,], dtype=dtypes.int32)
+ shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [4, -1, -1,])
bijector = Reshape(
- event_shape_out=shape_out_ph,
- event_shape_in=shape_in_ph,
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
validate_args=True)
- with self.assertRaisesOpError(
+ with self.assertRaisesError(
+ "elements must have at most one `-1`."):
+ sess.run(bijector.forward_event_shape_tensor(shape_in),
+ feed_dict=feed_dict)
+
+ def testInvalidDimensionsOpError(self):
+
+ with self.test_session() as sess:
+
+ shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 2, -2,])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
+
+ with self.assertRaisesError(
+ "elements must be either positive integers or `-1`."):
+ sess.run(bijector.forward_event_shape_tensor(shape_in),
+ feed_dict=feed_dict)
+
+ def testValidButNonMatchingInputOpError(self):
+ x = np.random.randn(4, 3, 2)
+
+ with self.test_session() as sess:
+ shape_in, shape_out, feed_dict = self.build_shapes([2, 3], [1, 6, 1,])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
+
+ # Here we pass in a tensor (x) whose shape is compatible with
+ # the output shape, so tf.reshape will throw no error, but
+ # doesn't match the expected input shape.
+ with self.assertRaisesError(
"Input `event_shape` does not match `event_shape_in`."):
- sess.run(bijector.forward(x2),
- feed_dict={shape_out_ph: [1, 6, 1],
- shape_in_ph: [2, 3]})
+ sess.run(bijector.forward(x),
+ feed_dict=feed_dict)
- with self.assertRaisesOpError(
- "event_shape_out entries must be positive."):
- sess.run(bijector.forward(x1),
- feed_dict={shape_out_ph: [-1, -1, 6],
- shape_in_ph: [2, 3]})
+ def testValidButNonMatchingInputPartiallySpecifiedOpError(self):
+ x = np.random.randn(4, 3, 2)
+
+ with self.test_session() as sess:
+ shape_in, shape_out, feed_dict = self.build_shapes([2, -1], [1, 6, 1,])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
+
+ with self.assertRaisesError(
+ "Input `event_shape` does not match `event_shape_in`."):
+ sess.run(bijector.forward(x),
+ feed_dict=feed_dict)
+
+ def testInputOutputMismatchOpError(self):
+ x1 = np.random.randn(4, 2, 3)
+ x2 = np.random.randn(4, 1, 1, 5)
+
+ with self.test_session() as sess:
+ shape_in, shape_out, fd_mismatched = self.build_shapes([2, 3],
+ [1, 1, 5])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
# test that *all* methods check basic assertions
- fd_mismatched = {shape_out_ph: [1, 1, 5], shape_in_ph: [2, 3]}
- with self.assertRaisesOpError(
- "Input/output `event_size`s do not match."):
+ with self.assertRaisesError(
+ "Input to reshape is a tensor with"):
sess.run(bijector.forward(x1), feed_dict=fd_mismatched)
- with self.assertRaisesOpError(
- "Input/output `event_size`s do not match."):
- sess.run(bijector.inverse(x3), feed_dict=fd_mismatched)
- with self.assertRaisesOpError(
- "Input/output `event_size`s do not match."):
- sess.run(bijector.inverse_log_det_jacobian(x3),
- feed_dict=fd_mismatched)
- with self.assertRaisesOpError(
- "Input/output `event_size`s do not match."):
- sess.run(bijector.forward_log_det_jacobian(x1),
- feed_dict=fd_mismatched)
+ with self.assertRaisesError(
+ "Input to reshape is a tensor with"):
+ sess.run(bijector.inverse(x2), feed_dict=fd_mismatched)
+
+ def testOneShapePartiallySpecified(self):
+ expected_x = np.random.randn(4, 6)
+ expected_y = np.reshape(expected_x, [4, 2, 3])
+
+ with self.test_session() as sess:
+ # one of input/output shapes is partially specified
+ shape_in, shape_out, feed_dict = self.build_shapes([-1,], [2, 3])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
+ (x_,
+ y_,
+ ) = sess.run((
+ bijector.inverse(expected_y),
+ bijector.forward(expected_x),
+ ), feed_dict=feed_dict)
+ self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
+ self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
+
+ def testBothShapesPartiallySpecified(self):
+ expected_x = np.random.randn(4, 2, 3)
+ expected_y = np.reshape(expected_x, [4, 3, 2])
+ with self.test_session() as sess:
+ shape_in, shape_out, feed_dict = self.build_shapes([-1, 3], [-1, 2])
+ bijector = Reshape(
+ event_shape_out=shape_out,
+ event_shape_in=shape_in,
+ validate_args=True)
+ (x_,
+ y_,
+ ) = sess.run((
+ bijector.inverse(expected_y),
+ bijector.forward(expected_x),
+ ), feed_dict=feed_dict)
+ self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
+ self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
+
+ def testDefaultVectorShape(self):
+ expected_x = np.random.randn(4, 4)
+ expected_y = np.reshape(expected_x, [4, 2, 2])
+ with self.test_session() as sess:
+ _, shape_out, feed_dict = self.build_shapes([-1,], [-1, 2])
+ bijector = Reshape(shape_out,
+ validate_args=True)
+ (x_,
+ y_,
+ ) = sess.run((
+ bijector.inverse(expected_y),
+ bijector.forward(expected_x),
+ ), feed_dict=feed_dict)
+ self.assertAllClose(expected_y, y_, rtol=1e-6, atol=0)
+ self.assertAllClose(expected_x, x_, rtol=1e-6, atol=0)
+
+ def build_shapes(self, *args, **kwargs):
+ raise NotImplementedError("Subclass failed to implement `build_shapes`.")
+
+
+class ReshapeBijectorTestStatic(test.TestCase, _ReshapeBijectorTest):
+
+ def build_shapes(self, shape_in, shape_out):
+ shape_in_static = shape_in
+ shape_out_static = shape_out
+ feed_dict = {}
+ return shape_in_static, shape_out_static, feed_dict
+
+ def assertRaisesError(self, msg):
+ return self.assertRaisesRegexp(Exception, msg)
+
+ def testEventShape(self):
+ shape_in_static = tensor_shape.TensorShape([2, 3])
+ shape_out_static = tensor_shape.TensorShape([6,])
+ bijector = Reshape(
+ event_shape_out=shape_out_static,
+ event_shape_in=shape_in_static, validate_args=True)
+
+ # test that forward_ and inverse_event_shape do sensible things
+ # when shapes are statically known.
+ self.assertEqual(
+ bijector.forward_event_shape(shape_in_static),
+ shape_out_static)
+ self.assertEqual(
+ bijector.inverse_event_shape(shape_out_static),
+ shape_in_static)
def testBijectiveAndFinite(self):
x = np.random.randn(4, 2, 3)
@@ -238,5 +299,32 @@ class ReshapeBijectorTest(test.TestCase):
validate_args=True)
assert_bijective_and_finite(bijector, x, y, rtol=1e-6, atol=0)
+
+class ReshapeBijectorTestDynamic(test.TestCase, _ReshapeBijectorTest):
+
+ def build_shapes(self, shape_in, shape_out):
+ shape_in_ph = array_ops.placeholder(shape=(len(shape_in),),
+ dtype=dtypes.int32)
+ shape_out_ph = array_ops.placeholder(shape=(len(shape_out),),
+ dtype=dtypes.int32)
+ feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out}
+ return shape_in_ph, shape_out_ph, feed_dict
+
+ def assertRaisesError(self, msg):
+ return self.assertRaisesOpError(msg)
+
+
+class ReshapeBijectorTestDynamicNdims(test.TestCase, _ReshapeBijectorTest):
+
+ def build_shapes(self, shape_in, shape_out):
+ shape_in_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32)
+ shape_out_ph = array_ops.placeholder(shape=None, dtype=dtypes.int32)
+ feed_dict = {shape_in_ph: shape_in, shape_out_ph: shape_out}
+ return shape_in_ph, shape_out_ph, feed_dict
+
+ def assertRaisesError(self, msg):
+ return self.assertRaisesOpError(msg)
+
+
if __name__ == "__main__":
test.main()
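The three test classes above run the same assertions with shapes passed as static lists, as rank-1 placeholders, and as placeholders of unknown rank. A minimal standalone sketch of that feeding pattern (a hedged illustration, assuming the TF 1.x graph-mode API and the contrib `Reshape` bijector as patched here):

```python
import numpy as np
import tensorflow as tf

tfd = tf.contrib.distributions

# Static shapes: rank/size problems surface at graph-construction time.
static_bijector = tfd.bijectors.Reshape(
    event_shape_out=[2, 3], event_shape_in=[6])
y_static = static_bijector.forward(np.random.randn(4, 6))  # shape (4, 2, 3)

# Dynamic shapes: validation defers to runtime assertions.
shape_in_ph = tf.placeholder(shape=(1,), dtype=tf.int32)
shape_out_ph = tf.placeholder(shape=(2,), dtype=tf.int32)
dynamic_bijector = tfd.bijectors.Reshape(
    event_shape_out=shape_out_ph, event_shape_in=shape_in_ph,
    validate_args=True)

x = np.random.randn(4, 6).astype(np.float32)
with tf.Session() as sess:
  y = sess.run(dynamic_bijector.forward(x),
               feed_dict={shape_in_ph: [6], shape_out_ph: [2, 3]})
  assert y.shape == (4, 2, 3)
```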
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py
index b84502003a..0fe9f6aa78 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/absolute_value_impl.py
@@ -48,7 +48,9 @@ class AbsoluteValue(bijector.Bijector):
```python
- abs = ds.bijectors.AbsoluteValue()
+ tfd = tf.contrib.distributions
+
+ abs = tfd.bijectors.AbsoluteValue()
abs.forward([-1., 0., 1.])
==> [1., 0., 1.]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py
index ae14288393..f51c48d2dd 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/masked_autoregressive_impl.py
@@ -124,17 +124,17 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
#### Example Use
```python
- ds = tf.contrib.distributions
- bs = tf.contrib.distributions.bijectors
+ tfd = tf.contrib.distributions
+ tfb = tfd.bijectors
dims = 5
# A common choice for a normalizing flow is to use a Gaussian for the base
# distribution. (However, any continuous distribution would work.) E.g.,
- maf = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
- bijector=bs.MaskedAutoregressiveFlow(
- shift_and_log_scale_fn=bs.masked_autoregressive_default_template(
+ maf = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=0., scale=1.),
+ bijector=tfb.MaskedAutoregressiveFlow(
+ shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
hidden_layers=[512, 512])),
event_shape=[dims])
@@ -143,10 +143,10 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
maf.log_prob(0.) # Cheap; no `tf.while_loop` despite no Bijector caching.
# [1] also describes an "Inverse Autoregressive Flow", e.g.,
- iaf = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
- bijector=bs.Invert(bs.MaskedAutoregressiveFlow(
- shift_and_log_scale_fn=bs.masked_autoregressive_default_template(
+ iaf = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=0., scale=1.),
+ bijector=tfb.Invert(tfb.MaskedAutoregressiveFlow(
+ shift_and_log_scale_fn=tfb.masked_autoregressive_default_template(
hidden_layers=[512, 512]))),
event_shape=[dims])
@@ -158,10 +158,10 @@ class MaskedAutoregressiveFlow(bijector_lib.Bijector):
# poor choice. Here's an example of using a "shift only" version and with a
# different number/depth of hidden layers.
shift_only = True
- maf_no_scale_hidden2 = ds.TransformedDistribution(
- distribution=ds.Normal(loc=0., scale=1.),
- bijector=bs.MaskedAutoregressiveFlow(
- bs.masked_autoregressive_default_template(
+ maf_no_scale_hidden2 = tfd.TransformedDistribution(
+ distribution=tfd.Normal(loc=0., scale=1.),
+ bijector=tfb.MaskedAutoregressiveFlow(
+ tfb.masked_autoregressive_default_template(
hidden_layers=[32],
shift_only=shift_only),
is_constant_jacobian=shift_only),
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py
index b1d8f2f41b..8654cc39d0 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/permute_impl.py
@@ -40,9 +40,9 @@ class Permute(bijector_lib.Bijector):
"""Permutes the rightmost dimension of a `Tensor`.
```python
- bs = tf.contrib.distributions.bijectors
+ tfd = tf.contrib.distributions
- reverse = bs.Permute(permutation=[2, 1, 0])
+ reverse = tfd.bijectors.Permute(permutation=[2, 1, 0])
reverse.forward([-1., 0., 1.])
# ==> [1., 0., -1]
diff --git a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py
index 93682639aa..55eca06312 100644
--- a/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py
+++ b/tensorflow/contrib/distributions/python/ops/bijectors/reshape_impl.py
@@ -36,70 +36,77 @@ __all__ = [
]
+def _static_ndims_from_shape(shape):
+ return shape.shape.with_rank_at_least(1)[0].value
+
+
+def _ndims_from_shape(shape):
+ return array_ops.shape(shape)[0]
+
+
class Reshape(bijector_lib.Bijector):
"""Reshapes the `event_shape` of a `Tensor`.
The semantics generally follow those of `tf.reshape()`, with
a few differences:
- * The user must provide both the input and output shape, so that
- the transformation can be inverted.
- * The `Reshape` bijector automatically broadcasts over the leftmost
- dimensions of its input (`sample_shape` and `batch_shape`); only
- the rightmost `event_ndims_in` dimensions are reshaped. The
- number of dimensions to reshape is inferred from the provided
- `event_shape_in` (`event_ndims_in = len(event_shape_in)`).
- * The `Reshape` bijector does not currently support
- partially-specified shapes, i.e., those with a dimension
- implicitly specified by `-1`.
+
+ * The user must provide both the input and output shape, so that
+ the transformation can be inverted. If an input shape is not
+ specified, the default assumes a vector-shaped input, i.e.,
+ event_shape_in = (-1,).
+ * The `Reshape` bijector automatically broadcasts over the leftmost
+ dimensions of its input (`sample_shape` and `batch_shape`); only
+ the rightmost `event_ndims_in` dimensions are reshaped. The
+ number of dimensions to reshape is inferred from the provided
+ `event_shape_in` (`event_ndims_in = len(event_shape_in)`).
Example usage:
```python
- bs = tf.contrib.distributions.bijectors
+ tfd = tf.contrib.distributions
- reverse = bs.Reshape(event_shape_out=[1,2],
- event_shape_in=[2,])
+ r = tfd.bijectors.Reshape(event_shape_out=[1, -1])
- reverse.forward([1., 2.]) # shape [2,]
- # ==> [[1., 2.]] # shape [1,2]
+ r.forward([3., 4.]) # shape [2]
+ # ==> [[3., 4.]] # shape [1, 2]
- reverse.forward([[1., 2.], [3., 4.]]) # shape [2, 2]
- # ==> [[[1., 2.]], [[3., 4.]]] # shape [2, 1, 2]
+ r.forward([[1., 2.], [3., 4.]]) # shape [2, 2]
+ # ==> [[[1., 2.]],
+ # [[3., 4.]]] # shape [2, 1, 2]
- reverse.inverse([[1., 2.]]) # shape [1,2]
- # ==> [1., 2.] # shape [2,]
+ r.inverse([[3., 4.]]) # shape [1,2]
+ # ==> [3., 4.] # shape [2]
- reverse.forward_log_det_jacobian(any_value)
+ r.forward_log_det_jacobian(any_value)
# ==> 0.
- reverse.inverse_log_det_jacobian(any_value)
+ r.inverse_log_det_jacobian(any_value)
# ==> 0.
```
"""
- def __init__(self, event_shape_out, event_shape_in,
+ def __init__(self, event_shape_out, event_shape_in=(-1,),
validate_args=False, name=None):
"""Creates a `Reshape` bijector.
Args:
event_shape_out: An `int`-like vector-shaped `Tensor`
- representing the fully specified (no -1's) event shape of the
- transformed output.
- event_shape_in: An `int`-like vector-shaped `Tensor`
- representing the fully specified (no -1's) event shape of the
- input.
+ representing the event shape of the transformed output.
+ event_shape_in: An optional `int`-like vector-shaped `Tensor`
+ representing the event shape of the input. This is required in
+ order to define inverse operations; the default of (-1,)
+ assumes a vector-shaped input.
validate_args: Python `bool` indicating whether arguments should
be checked for correctness.
name: Python `str`, name given to ops managed by this object.
Raises:
TypeError: if either `event_shape_in` or `event_shape_out` has
- non-vector shape (`rank > 1`), or non-integer `dtype`.
- ValueError: if either `event_shape_in` or `event_shape_out`
- contains non-positive entries, or if their sizes do not match
- (`prod(event_shape_in)` != `prod(event_shape_out)`), or if
- their dimensionality(s) cannot be statically inferred.
+ non-integer `dtype`.
+ ValueError: if either of `event_shape_in` or `event_shape_out`
+ has non-vector shape (`rank > 1`), or if their sizes do not
+ match.
"""
with ops.name_scope(name, "reshape",
values=[event_shape_out, event_shape_in]):
@@ -111,105 +118,74 @@ class Reshape(bijector_lib.Bijector):
name="event_shape_in",
preferred_dtype=dtypes.int32)
- # check that input shapes are positive integers
assertions = []
- assertions += self._maybe_check_valid_shape(
- event_shape_out, "event_shape_out",
- validate_args=validate_args)
- assertions += self._maybe_check_valid_shape(
- event_shape_in, "event_shape_in", validate_args=validate_args)
-
- # check that prod(event_shape_in) = prod(event_shape_out)
- assertions += self._maybe_check_matching_sizes(
- event_shape_in, event_shape_out, validate_args=validate_args)
+ assertions.extend(self._maybe_check_valid_shape(
+ event_shape_out, validate_args))
+ assertions.extend(self._maybe_check_valid_shape(
+ event_shape_in, validate_args))
self._assertions = assertions
self._event_shape_in = event_shape_in
self._event_shape_out = event_shape_out
- self._event_shape_in_static = tensor_util.constant_value_as_shape(
- event_shape_in)
- self._event_shape_out_static = tensor_util.constant_value_as_shape(
- event_shape_out)
super(Reshape, self).__init__(is_constant_jacobian=True,
validate_args=validate_args,
name=name or "reshape")
- def _maybe_check_valid_shape(self, shape_tensor, label,
- validate_args=False):
- """Check that a shape Tensor is int-type and positive."""
-
- assertions = []
-
- if not shape_tensor.dtype.is_integer:
+ def _maybe_check_valid_shape(self, shape, validate_args):
+ """Check that a shape Tensor is int-type and otherwise sane."""
+ if not shape.dtype.is_integer:
raise TypeError("{} dtype ({}) should be `int`-like.".format(
- label, shape_tensor.dtype.name))
+ shape.op.name, shape.dtype.name))
- shape_rank = tensor_util.constant_value(array_ops.rank(shape_tensor))
- if shape_rank is not None and shape_rank > 1:
- raise ValueError("{} rank should be <= 1.".format(label))
+ assertions = []
- s = tensor_util.constant_value(shape_tensor)
- if s is not None:
- if (s <= 0).any():
- raise ValueError("{} entries must be positive, but found {}".format(
- label, s))
+ ndims = array_ops.rank(shape)
+ ndims_ = tensor_util.constant_value(ndims)
+ if ndims_ is not None and ndims_ > 1:
+ raise ValueError("`{}` rank ({}) should be <= 1.".format(
+ shape.op.name, ndims_))
elif validate_args:
- assertions.append(check_ops.assert_positive(
- shape_tensor, message="{} entries must be positive".format(label)))
-
- return assertions
-
- def _maybe_check_matching_sizes(self, event_shape_in, event_shape_out,
- validate_args=False):
- """Check that prod(event_shape_in)==prod(event_shape_out)."""
+ assertions.append(check_ops.assert_less_equal(
+ ndims, 1, message="`{}` rank should be <= 1.".format(shape.op.name)))
- def _get_size_from_shape(shape):
- """Computes size from a shape `Tensor`, statically if possible."""
- s = tensor_util.constant_value(shape)
- if s is not None:
- return [np.int32(np.prod(s))]*2
- return None, math_ops.reduce_prod(shape, name="size")
-
- # Ensure `event_shape_in` is compatible with `event_shape_out`.
- event_size_in_, event_size_in = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking
- event_shape_in)
- event_size_out_, event_size_out = _get_size_from_shape( # pylint: disable=unbalanced-tuple-unpacking
- event_shape_out)
-
- assertions = []
- if event_size_in_ is not None and event_size_out_ is not None:
- if event_size_in_ != event_size_out_:
+ shape_ = tensor_util.constant_value_as_shape(shape)
+ if shape_.is_fully_defined():
+ es = np.int32(shape_.as_list())
+ if sum(es == -1) > 1:
+ raise ValueError(
+ "`{}` must have at most one `-1` (given {})"
+ .format(shape.op.name, es))
+ if np.any(es < -1):
raise ValueError(
- "Input `event_size` ({}) does not match output `event_size` ({}).".
- format(event_size_in, event_size_out_))
+ "`{}` elements must be either positive integers or `-1`"
+ "(given {})."
+ .format(shape.op.name, es))
elif validate_args:
- assertions.append(check_ops.assert_equal(
- event_size_in, event_size_out,
- message="Input/output `event_size`s do not match."))
-
+ assertions.extend([
+ check_ops.assert_less_equal(
+ math_ops.reduce_sum(
+ math_ops.cast(math_ops.equal(shape, -1), dtypes.int32)),
+ 1,
+ message="`{}` elements must have at most one `-1`."
+ .format(shape.op.name)),
+ check_ops.assert_greater_equal(
+ shape, -1,
+ message="`{}` elements must be either positive integers or `-1`."
+ .format(shape.op.name)),
+ ])
return assertions
def _reshape_helper(self, x, event_shape_in, event_shape_out):
"""Reshape only the event_shape of an input `Tensor`."""
- def _get_rank_from_shape(shape):
- """Computes rank from a shape `Tensor`, statically if possible."""
- # Uses fact that rank is "shape of shape".
- ndims = shape.shape.with_rank_at_least(1)[0].value
- if ndims is not None:
- return ndims, ndims
- return None, array_ops.shape(shape)[0]
-
- event_ndims_in_, event_ndims_in = _get_rank_from_shape(event_shape_in)
+ event_ndims_in_ = _static_ndims_from_shape(event_shape_in)
+ event_ndims_in = _ndims_from_shape(event_shape_in)
+ x_ndims_, x_ndims = x.shape.ndims, array_ops.rank(x)
assertions = []
- # Ensure x.event_shape is compatible with event_shape_in.
- if x.shape.ndims is not None:
- x_ndims_, x_ndims = [x.shape.ndims]*2
- else:
- x_ndims_, x_ndims = None, array_ops.rank(x)
+ # Ensure x.event_shape is compatible with event_shape_in.
if (event_ndims_in_ is not None
and x_ndims_ is not None
and x.shape.with_rank_at_least(event_ndims_in_)[
@@ -223,13 +199,35 @@ class Reshape(bijector_lib.Bijector):
event_shape_in_ = tensor_util.constant_value(event_shape_in)
if x_event_shape_ is not None and event_shape_in_ is not None:
- if not np.equal(x_event_shape_, event_shape_in_).all():
+ # Compare the shape dimensions that are fully specified in the
+ # input (i.e., for which event_shape_in is not -1). If x_event_shape
+ # matches along all of these dimensions, it is compatible with
+ # the desired input shape and any further mismatches (i.e.,
+ # incompatibility with the desired *output* shape) will be
+ # caught inside of array_ops.reshape() below.
+ x_event_shape_specified_ = x_event_shape_[event_shape_in_ >= 0]
+ event_shape_in_specified_ = event_shape_in_[event_shape_in_ >= 0]
+ if not np.equal(x_event_shape_specified_,
+ event_shape_in_specified_).all():
raise ValueError(
- "Input `event_shape` ({}) does not match `event_shape_in` ({}).".
+ "Input `event_shape` does not match `event_shape_in` ({} vs {}).".
format(x_event_shape_, event_shape_in_))
elif self.validate_args:
+ # Similarly to the static case, we compare the shape dimensions
+ # that are fully specified in the input. We extract these
+ # dimensions using boolean_mask(), which requires that the mask
+ # have known ndims. We can assume that shape Tensors always have
+ # ndims==1 (this assumption is verified inside of
+ # _maybe_check_valid_shape), so the reshape operation is just a
+ # no-op that formally encodes this fact to make boolean_mask()
+ # happy.
+ event_shape_mask = array_ops.reshape(event_shape_in >= 0, [-1])
+ x_event_shape_specified = array_ops.boolean_mask(x_event_shape,
+ event_shape_mask)
+ event_shape_in_specified = array_ops.boolean_mask(event_shape_in,
+ event_shape_mask)
assertions.append(check_ops.assert_equal(
- x_event_shape, event_shape_in,
+ x_event_shape_specified, event_shape_in_specified,
message="Input `event_shape` does not match `event_shape_in`."))
if assertions:
@@ -243,8 +241,19 @@ class Reshape(bijector_lib.Bijector):
sample_and_batch_shape = sample_and_batch_shape[
:(ndims - math_ops.abs(event_ndims_in))]
- new_shape = array_ops.concat(
- [sample_and_batch_shape, event_shape_out], axis=0)
+ if (event_ndims_in_ is not None
+ and x_ndims_ is not None
+ and event_ndims_in_ == x_ndims_):
+ # Hack to allow forward/inverse_event_shape to do shape
+ # inference by calling this helper method with a dummy Tensor of
+ # shape event_shape_in. In this special case,
+ # sample_and_batch_shape will be empty so we can preserve static
+ # shape information by avoiding the concat operation below
+ # (which would be a no-op).
+ new_shape = event_shape_out
+ else:
+ new_shape = array_ops.concat(
+ [sample_and_batch_shape, event_shape_out], axis=0)
return array_ops.reshape(x, new_shape)
@@ -269,29 +278,37 @@ class Reshape(bijector_lib.Bijector):
return constant_op.constant(0., dtype=x.dtype)
def _forward_event_shape(self, input_shape):
- self._event_shape_in_static.assert_is_compatible_with(input_shape)
- return self._event_shape_out_static
+ # NOTE: this method and the other *_event_shape* methods
+ # compute shape by explicit transformation of a dummy
+ # variable. This approach is not generally recommended because it
+ # bloats the graph and could in general trigger side effects.
+ #
+ # In this particular case of the Reshape bijector, the
+ # forward and inverse transforms have no side effects, and we
+ # believe the reduction in code complexity from delegating the
+ # heavy lifting to tf.reshape() is worth the added graph ops.
+ # However, you should think hard before implementing this approach
+ # in other Bijectors; it is strongly preferred to compute
+ # shapes explicitly whenever it's feasible to do so.
+ with ops.control_dependencies(self._assertions):
+ dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape)
+ dummy_reshaped = self.forward(dummy)
+ return dummy_reshaped.shape
def _inverse_event_shape(self, output_shape):
- self._event_shape_out_static.assert_is_compatible_with(output_shape)
- return self._event_shape_in_static
+ with ops.control_dependencies(self._assertions):
+ dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape)
+ dummy_reshaped = self.inverse(dummy)
+ return dummy_reshaped.shape
def _forward_event_shape_tensor(self, input_shape):
- input_assertions = self._maybe_check_valid_shape(
- input_shape, "input event shape", validate_args=self.validate_args)
- input_assertions += self._maybe_check_matching_sizes(
- input_shape, self._event_shape_out,
- validate_args=self.validate_args)
-
- return control_flow_ops.with_dependencies(
- input_assertions + self._assertions, self._event_shape_out)
+ with ops.control_dependencies(self._assertions):
+ dummy = array_ops.zeros(dtype=dtypes.float32, shape=input_shape)
+ dummy_reshaped = self.forward(dummy)
+ return array_ops.shape(dummy_reshaped)
def _inverse_event_shape_tensor(self, output_shape):
-
- output_assertions = self._maybe_check_valid_shape(
- output_shape, "output event shape", validate_args=self.validate_args)
- output_assertions += self._maybe_check_matching_sizes(
- output_shape, self._event_shape_in, validate_args=self.validate_args)
-
- return control_flow_ops.with_dependencies(
- output_assertions + self._assertions, self._event_shape_in)
+ with ops.control_dependencies(self._assertions):
+ dummy = array_ops.zeros(dtype=dtypes.float32, shape=output_shape)
+ dummy_reshaped = self.inverse(dummy)
+ return array_ops.shape(dummy_reshaped)
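Taken together, the changes above let `tf.reshape()` resolve a single `-1` per event shape, and let the `*_event_shape*` methods piggyback on the dummy-tensor trick for inference. A rough sketch of the resulting behavior (illustrative values, assuming the patched contrib API):

```python
import tensorflow as tf

tfd = tf.contrib.distributions

# At most one `-1` per shape is now accepted; tf.reshape() resolves it.
r = tfd.bijectors.Reshape(event_shape_out=[-1, 2], event_shape_in=[2, 3])

r.forward([[1., 2., 3.], [4., 5., 6.]])  # shape [2, 3] -> shape [3, 2]

# Static shape inference rides on the dummy-tensor path added above.
r.forward_event_shape(tf.TensorShape([2, 3]))  # ==> TensorShape([3, 2])
r.inverse_event_shape(tf.TensorShape([3, 2]))  # ==> TensorShape([2, 3])
```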
diff --git a/tensorflow/contrib/distributions/python/ops/cauchy.py b/tensorflow/contrib/distributions/python/ops/cauchy.py
index 8d59c1abfb..6f5d724a2a 100644
--- a/tensorflow/contrib/distributions/python/ops/cauchy.py
+++ b/tensorflow/contrib/distributions/python/ops/cauchy.py
@@ -43,16 +43,17 @@ class Cauchy(distribution.Distribution):
The probability density function (pdf) is,
```none
- pdf(x; loc, scale) = 1 / (pi * scale * (1 + ((x - loc) / scale)**2))
+ pdf(x; loc, scale) = 1 / (pi scale (1 + z**2))
+ z = (x - loc) / scale
```
where `loc` is the location, and `scale` is the scale.
The Cauchy distribution is a member of the [location-scale family](
https://en.wikipedia.org/wiki/Location-scale_family), i.e.
+ `Y ~ Cauchy(loc, scale)` is equivalent to,
```none
X ~ Cauchy(loc=0, scale=1)
- Y ~ Cauchy(loc=loc, scale=scale)
Y = loc + scale * X
```
@@ -61,14 +62,16 @@ class Cauchy(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ tfd = tf.contrib.distributions
+
# Define a single scalar Cauchy distribution.
- dist = Cauchy(loc=0., scale=3.)
+ dist = tfd.Cauchy(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Cauchy distributions.
- dist = Cauchy(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Cauchy(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -76,18 +79,17 @@ class Cauchy(distribution.Distribution):
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
- ```
-
- Arguments are broadcast when possible.
- ```python
+ # Arguments are broadcast when possible.
# Define a batch of two scalar valued Cauchy distributions.
# Both have median 1, but different scales.
- dist = tf.contrib.distributions.Cauchy(loc=1., scale=[11, 22.])
+ dist = tfd.Cauchy(loc=1., scale=[11, 22.])
+
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
- dist.prob(3.0)
+ dist.prob(3.)
```
+
"""
def __init__(self,
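The location-scale identity in the docstring can be checked numerically: transforming standard Cauchy samples by `loc + scale * X` matches the density of `Cauchy(loc, scale)` under the usual change of variables. A small sketch (hedged; assumes graph mode and the contrib `Cauchy` touched here):

```python
import tensorflow as tf

tfd = tf.contrib.distributions

loc, scale = 1., 2.
standard = tfd.Cauchy(loc=0., scale=1.)
shifted = tfd.Cauchy(loc=loc, scale=scale)

x = standard.sample([1000], seed=42)
y = loc + scale * x  # distributed as Cauchy(loc=1., scale=2.)

# Change of variables: p_shifted(y) == p_standard((y - loc) / scale) / scale
with tf.Session() as sess:
  a, b = sess.run([shifted.prob(y),
                   standard.prob((y - loc) / scale) / scale])
  assert abs(a - b).max() < 1e-5
```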
diff --git a/tensorflow/contrib/distributions/python/ops/deterministic.py b/tensorflow/contrib/distributions/python/ops/deterministic.py
index 850d08d1bd..8049522e9f 100644
--- a/tensorflow/contrib/distributions/python/ops/deterministic.py
+++ b/tensorflow/contrib/distributions/python/ops/deterministic.py
@@ -290,8 +290,10 @@ class VectorDeterministic(_BaseDeterministic):
#### Examples
```python
+ tfd = tf.contrib.distributions
+
# Initialize a single VectorDeterministic supported at [0., 2.] in R^2.
- constant = tf.contrib.distributions.Deterministic([0., 2.])
+ constant = tfd.VectorDeterministic([0., 2.])
constant.prob([0., 2.])
==> 1.
constant.prob([0., 3.])
@@ -299,7 +301,7 @@ class VectorDeterministic(_BaseDeterministic):
# Initialize a [3] batch of constants on R^2.
loc = [[0., 1.], [2., 3.], [4., 5.]]
- constant = constant_lib.VectorDeterministic(loc)
+ constant = tfd.VectorDeterministic(loc)
constant.prob([[0., 1.], [1.9, 3.], [3.99, 5.]])
==> [1., 0., 0.]
```
diff --git a/tensorflow/contrib/distributions/python/ops/gumbel.py b/tensorflow/contrib/distributions/python/ops/gumbel.py
index ba8d3c639b..d0efaefb8e 100644
--- a/tensorflow/contrib/distributions/python/ops/gumbel.py
+++ b/tensorflow/contrib/distributions/python/ops/gumbel.py
@@ -62,15 +62,17 @@ class _Gumbel(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ tfd = tf.contrib.distributions
+
# Define a single scalar Gumbel distribution.
- dist = tf.contrib.distributions.Gumbel(loc=0., scale=3.)
+ dist = tfd.Gumbel(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Gumbels.
# The first has mean 1 and scale 11, the second 2 and 22.
- dist = tf.contrib.distributions.Gumbel(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Gumbel(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -85,7 +87,7 @@ class _Gumbel(distribution.Distribution):
```python
# Define a batch of two scalar valued Logistics.
# Both have mean 1, but different scales.
- dist = tf.contrib.distributions.Gumbel(loc=1., scale=[11, 22.])
+ dist = tfd.Gumbel(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/independent.py b/tensorflow/contrib/distributions/python/ops/independent.py
index 6a74ca9a0a..cbce005013 100644
--- a/tensorflow/contrib/distributions/python/ops/independent.py
+++ b/tensorflow/contrib/distributions/python/ops/independent.py
@@ -68,11 +68,11 @@ class Independent(distribution_lib.Distribution):
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Make independent distribution from a 2-batch Normal.
- ind = ds.Independent(
- distribution=ds.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
+ ind = tfd.Independent(
+ distribution=tfd.Normal(loc=[-1., 1], scale=[0.1, 0.5]),
reinterpreted_batch_ndims=1)
# All batch dims have been "absorbed" into event dims.
@@ -80,8 +80,8 @@ class Independent(distribution_lib.Distribution):
ind.event_shape # ==> [2]
# Make independent distribution from a 2-batch bivariate Normal.
- ind = ds.Independent(
- distribution=ds.MultivariateNormalDiag(
+ ind = tfd.Independent(
+ distribution=tfd.MultivariateNormalDiag(
loc=[[-1., 1], [1, -1]],
scale_identity_multiplier=[1., 0.5]),
reinterpreted_batch_ndims=1)
diff --git a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
index 956dee38a3..ee4d86867d 100644
--- a/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
+++ b/tensorflow/contrib/distributions/python/ops/inverse_gamma.py
@@ -88,8 +88,9 @@ class InverseGamma(distribution.Distribution):
#### Examples
```python
- dist = InverseGamma(concentration=3.0, rate=2.0)
- dist2 = InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
+ tfd = tf.contrib.distributions
+ dist = tfd.InverseGamma(concentration=3.0, rate=2.0)
+ dist2 = tfd.InverseGamma(concentration=[3.0, 4.0], rate=[2.0, 3.0])
```
"""
diff --git a/tensorflow/contrib/distributions/python/ops/logistic.py b/tensorflow/contrib/distributions/python/ops/logistic.py
index 48794a4882..473677f8d9 100644
--- a/tensorflow/contrib/distributions/python/ops/logistic.py
+++ b/tensorflow/contrib/distributions/python/ops/logistic.py
@@ -60,15 +60,17 @@ class Logistic(distribution.Distribution):
Examples of initialization of one or a batch of distributions.
```python
+ tfd = tf.contrib.distributions
+
# Define a single scalar Logistic distribution.
- dist = tf.contrib.distributions.Logistic(loc=0., scale=3.)
+ dist = tfd.Logistic(loc=0., scale=3.)
# Evaluate the cdf at 1, returning a scalar.
dist.cdf(1.)
# Define a batch of two scalar valued Logistics.
# The first has mean 1 and scale 11, the second 2 and 22.
- dist = tf.contrib.distributions.Logistic(loc=[1, 2.], scale=[11, 22.])
+ dist = tfd.Logistic(loc=[1, 2.], scale=[11, 22.])
# Evaluate the pdf of the first distribution on 0, and the second on 1.5,
# returning a length two tensor.
@@ -76,14 +78,11 @@ class Logistic(distribution.Distribution):
# Get 3 samples, returning a 3 x 2 tensor.
dist.sample([3])
- ```
- Arguments are broadcast when possible.
-
- ```python
+ # Arguments are broadcast when possible.
# Define a batch of two scalar valued Logistics.
# Both have mean 1, but different scales.
- dist = tf.contrib.distributions.Logistic(loc=1., scale=[11, 22.])
+ dist = tfd.Logistic(loc=1., scale=[11, 22.])
# Evaluate the pdf of both distributions on the same point, 3.0,
# returning a length 2 tensor.
diff --git a/tensorflow/contrib/distributions/python/ops/mixture.py b/tensorflow/contrib/distributions/python/ops/mixture.py
index e676931d91..f2d492f548 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture.py
@@ -49,13 +49,13 @@ class Mixture(distribution.Distribution):
```python
# Create a mixture of two Gaussians:
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
mix = 0.3
- bimix_gauss = ds.Mixture(
- cat=ds.Categorical(probs=[mix, 1.-mix]),
+ bimix_gauss = tfd.Mixture(
+ cat=tfd.Categorical(probs=[mix, 1.-mix]),
components=[
- ds.Normal(loc=-1., scale=0.1),
- ds.Normal(loc=+1., scale=0.5),
+ tfd.Normal(loc=-1., scale=0.1),
+ tfd.Normal(loc=+1., scale=0.5),
])
# Plot the PDF.
diff --git a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
index 5558ef0f25..5448918a50 100644
--- a/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
+++ b/tensorflow/contrib/distributions/python/ops/mixture_same_family.py
@@ -43,15 +43,14 @@ class MixtureSameFamily(distribution.Distribution):
#### Examples
```python
- import matplotlib.pyplot as plt
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
### Create a mixture of two scalar Gaussians:
- gm = ds.MixtureSameFamily(
- mixture_distribution=ds.Categorical(
+ gm = tfd.MixtureSameFamily(
+ mixture_distribution=tfd.Categorical(
probs=[0.3, 0.7]),
- components_distribution=ds.Normal(
+ components_distribution=tfd.Normal(
loc=[-1., 1], # One for each component.
scale=[0.1, 0.5])) # And same here.
@@ -63,14 +62,15 @@ class MixtureSameFamily(distribution.Distribution):
# Plot PDF.
x = np.linspace(-2., 3., int(1e4), dtype=np.float32)
+ import matplotlib.pyplot as plt
plt.plot(x, gm.prob(x).eval());
### Create a mixture of two Bivariate Gaussians:
- gm = ds.MixtureSameFamily(
- mixture_distribution=ds.Categorical(
+ gm = tfd.MixtureSameFamily(
+ mixture_distribution=tfd.Categorical(
probs=[0.3, 0.7]),
- components_distribution=ds.MultivariateNormalDiag(
+ components_distribution=tfd.MultivariateNormalDiag(
loc=[[-1., 1], # component 1
[1, -1]], # component 2
scale_identity_multiplier=[.3, .6]))
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag.py b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
index 163cf75d99..e862552880 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag.py
@@ -84,10 +84,10 @@ class MultivariateNormalDiag(
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 2-variate Gaussian.
- mvn = ds.MultivariateNormalDiag(
+ mvn = tfd.MultivariateNormalDiag(
loc=[1., -1],
scale_diag=[1, 2.])
@@ -101,7 +101,7 @@ class MultivariateNormalDiag(
mvn.prob([-1., 0]).eval() # shape: []
# Initialize a 3-batch, 2-variate scaled-identity Gaussian.
- mvn = ds.MultivariateNormalDiag(
+ mvn = tfd.MultivariateNormalDiag(
loc=[1., -1],
scale_identity_multiplier=[1, 2., 3])
@@ -119,7 +119,7 @@ class MultivariateNormalDiag(
mvn.prob([-1., 0]).eval() # shape: [3]
# Initialize a 2-batch of 3-variate Gaussians.
- mvn = ds.MultivariateNormalDiag(
+ mvn = tfd.MultivariateNormalDiag(
loc=[[1., 2, 3],
[11, 22, 33]], # shape: [2, 3]
scale_diag=[[1., 2, 3],
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
index 040bc23072..413e88f03a 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_diag_plus_low_rank.py
@@ -86,7 +86,7 @@ class MultivariateNormalDiagPlusLowRank(
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate Gaussian with covariance `cov = S @ S.T`,
# `S = diag(d) + U @ diag(m) @ U.T`. The perturbation, `U @ diag(m) @ U.T`, is
@@ -97,7 +97,7 @@ class MultivariateNormalDiagPlusLowRank(
[-1, 1],
[2, -0.5]] # shape: [3, 2]
m = [4., 5] # shape: [2]
- mvn = ds.MultivariateNormalDiagPlusLowRank(
+ mvn = tfd.MultivariateNormalDiagPlusLowRank(
loc=mu,
scale_diag=d,
scale_perturb_factor=U,
@@ -118,7 +118,7 @@ class MultivariateNormalDiagPlusLowRank(
m = [[0.1, 0.2],
[0.4, 0.5]] # shape: [b, r] = [2, 2]
- mvn = ds.MultivariateNormalDiagPlusLowRank(
+ mvn = tfd.MultivariateNormalDiagPlusLowRank(
loc=mu,
scale_perturb_factor=U,
scale_perturb_diag=m)
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
index f9952b2069..8e69dadfb4 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_full_covariance.py
@@ -73,14 +73,14 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
cov = [[ 0.36, 0.12, 0.06],
[ 0.12, 0.29, -0.13],
[ 0.06, -0.13, 0.26]]
- mvn = ds.MultivariateNormalFullCovariance(
+ mvn = tfd.MultivariateNormalFullCovariance(
loc=mu,
covariance_matrix=cov)
@@ -100,7 +100,7 @@ class MultivariateNormalFullCovariance(mvn_tril.MultivariateNormalTriL):
mu = [[1., 2, 3],
[11, 22, 33]] # shape: [2, 3]
covariance_matrix = ... # shape: [2, 3, 3], symmetric, positive definite.
- mvn = ds.MultivariateNormalFullCovariance(
+ mvn = tfd.MultivariateNormalFullCovariance(
loc=mu,
covariance_matrix=covariance_matrix)
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
index 300bdd5f60..a739979289 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_linear_operator.py
@@ -90,8 +90,7 @@ class MultivariateNormalLinearOperator(
#### Examples
```python
- ds = tf.contrib.distributions
- la = tf.linalg
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
@@ -103,9 +102,9 @@ class MultivariateNormalLinearOperator(
# [ 0.2, 0.5, 0. ],
# [ 0.1, -0.3, 0.4]])
- mvn = ds.MultivariateNormalLinearOperator(
+ mvn = tfd.MultivariateNormalLinearOperator(
loc=mu,
- scale=la.LinearOperatorLowerTriangular(scale))
+ scale=tf.linalg.LinearOperatorLowerTriangular(scale))
# Covariance agrees with cholesky(cov) parameterization.
mvn.covariance().eval()
@@ -122,9 +121,9 @@ class MultivariateNormalLinearOperator(
scale_diag = [[1., 2, 3],
[0.5, 1, 1.5]] # shape: [2, 3]
- mvn = ds.MultivariateNormalLinearOperator(
+ mvn = tfd.MultivariateNormalLinearOperator(
loc=mu,
- scale=la.LinearOperatorDiag(scale_diag))
+ scale=tf.linalg.LinearOperatorDiag(scale_diag))
# Compute the pdf of two `R^3` observations; return a length-2 vector.
x = [[-0.9, 0, 0.1],
diff --git a/tensorflow/contrib/distributions/python/ops/mvn_tril.py b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
index 260dcc18f5..6c7dc4ca7a 100644
--- a/tensorflow/contrib/distributions/python/ops/mvn_tril.py
+++ b/tensorflow/contrib/distributions/python/ops/mvn_tril.py
@@ -76,12 +76,13 @@ class MultivariateNormalTriL(
```
Trainable (batch) lower-triangular matrices can be created with
- `ds.matrix_diag_transform()` and/or `ds.fill_triangular()`
+ `tf.contrib.distributions.matrix_diag_transform()` and/or
+ `tf.contrib.distributions.fill_triangular()`
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate Gaussian.
mu = [1., 2, 3]
@@ -92,7 +93,7 @@ class MultivariateNormalTriL(
# ==> [[ 0.6, 0. , 0. ],
# [ 0.2, 0.5, 0. ],
# [ 0.1, -0.3, 0.4]])
- mvn = ds.MultivariateNormalTriL(
+ mvn = tfd.MultivariateNormalTriL(
loc=mu,
scale_tril=scale)
@@ -112,7 +113,7 @@ class MultivariateNormalTriL(
mu = [[1., 2, 3],
[11, 22, 33]] # shape: [2, 3]
tril = ... # shape: [2, 3, 3], lower triangular, non-zero diagonal.
- mvn = ds.MultivariateNormalTriL(
+ mvn = tfd.MultivariateNormalTriL(
loc=mu,
scale_tril=tril)
@@ -124,9 +125,9 @@ class MultivariateNormalTriL(
# Instantiate a "learnable" MVN.
dims = 4
with tf.variable_scope("model"):
- mvn = ds.MultivariateNormalTriL(
+ mvn = tfd.MultivariateNormalTriL(
loc=tf.get_variable(shape=[dims], dtype=tf.float32, name="mu"),
- scale_tril=ds.fill_triangular(
+ scale_tril=tfd.fill_triangular(
tf.get_variable(shape=[dims * (dims + 1) / 2],
dtype=tf.float32, name="chol_Sigma")))
```
diff --git a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
index e1118ed431..2701c36fb5 100644
--- a/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
+++ b/tensorflow/contrib/distributions/python/ops/poisson_lognormal.py
@@ -107,10 +107,11 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
+
# Create two batches of PoissonLogNormalQuadratureCompounds, one with
+ # prior `loc = 0.` and another with `loc = -0.5`. In both cases `scale = 1.`
- pln = ds.PoissonLogNormalQuadratureCompound(
+ pln = tfd.PoissonLogNormalQuadratureCompound(
loc=[0., -0.5],
scale=1.,
quadrature_grid_and_probs=(
diff --git a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
index b05f15771a..c4b8f055b7 100644
--- a/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
+++ b/tensorflow/contrib/distributions/python/ops/sinh_arcsinh.py
@@ -115,7 +115,7 @@ class SinhArcsinh(transformed_distribution.TransformedDistribution):
tailweight: Tailweight parameter. Default is `1.0` (unchanged tailweight)
distribution: `tf.Distribution`-like instance. Distribution that is
transformed to produce this distribution.
- Default is `ds.Normal(0., 1.)`.
+ Default is `tf.distributions.Normal(0., 1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
index 92043d6a08..904724af42 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_diffeomixture.py
@@ -188,8 +188,7 @@ class VectorDiffeomixture(distribution_lib.Distribution):
#### Examples
```python
- ds = tf.contrib.distributions
- la = tf.linalg
+ tfd = tf.contrib.distributions
# Create two batches of VectorDiffeomixtures, one with mix_loc=[0.] and
# another with mix_loc=[1]. In both cases, `K=2` and the affine
@@ -197,20 +196,20 @@ class VectorDiffeomixture(distribution_lib.Distribution):
# k=0: loc=zeros(dims) scale=LinearOperatorScaledIdentity
# k=1: loc=[2.]*dims scale=LinOpDiag
dims = 5
- vdm = ds.VectorDiffeomixture(
+ vdm = tfd.VectorDiffeomixture(
mix_loc=[[0.], [1]],
mix_scale=[1.],
- distribution=ds.Normal(loc=0., scale=1.),
+ distribution=tfd.Normal(loc=0., scale=1.),
loc=[
None, # Equivalent to `np.zeros(dims, dtype=np.float32)`.
np.float32([2.]*dims),
],
scale=[
- la.LinearOperatorScaledIdentity(
+ tf.linalg.LinearOperatorScaledIdentity(
num_rows=dims,
multiplier=np.float32(1.1),
is_positive_definite=True),
- la.LinearOperatorDiag(
+ tf.linalg.LinearOperatorDiag(
diag=np.linspace(2.5, 3.5, dims, dtype=np.float32),
is_positive_definite=True),
],
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
index 356d78b67a..526fe2d39a 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_diag.py
@@ -89,14 +89,13 @@ class VectorExponentialDiag(
#### Examples
```python
- ds = tf.contrib.distributions
- la = tf.linalg
+ tfd = tf.contrib.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
# The first component has pdf exp{-x}, the second 0.5 exp{-x / 2}
- vex = ds.VectorExponentialDiag(scale_diag=[1., 2.])
+ vex = tfd.VectorExponentialDiag(scale_diag=[1., 2.])
# Compute the pdf of an `R^2` observation; return a scalar.
vex.prob([3., 4.]).eval() # shape: []
@@ -107,7 +106,7 @@ class VectorExponentialDiag(
scale_diag = [[1., 2, 3],
[0.5, 1, 1.5]] # shape: [2, 3]
- vex = ds.VectorExponentialDiag(loc, scale_diag)
+ vex = tfd.VectorExponentialDiag(loc, scale_diag)
# Compute the pdf of two `R^3` observations; return a length-2 vector.
x = [[1.9, 2.2, 3.1],
diff --git a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
index b313a851b3..9d5fd9ac41 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_exponential_linear_operator.py
@@ -107,16 +107,15 @@ class VectorExponentialLinearOperator(
#### Examples
```python
- ds = tf.contrib.distributions
- la = tf.linalg
+ tfd = tf.contrib.distributions
# Initialize a single 2-variate VectorExponential, supported on
# {(x, y) in R^2 : x > 0, y > 0}.
mat = [[1.0, 0.1],
[0.1, 1.0]]
- vex = ds.VectorExponentialLinearOperator(
- scale=la.LinearOperatorFullMatrix(mat))
+ vex = tfd.VectorExponentialLinearOperator(
+ scale=tf.linalg.LinearOperatorFullMatrix(mat))
# Compute the pdf of an `R^2` observation; return a scalar.
vex.prob([1., 2.]).eval() # shape: []
@@ -127,9 +126,9 @@ class VectorExponentialLinearOperator(
scale_diag = [[1., 2, 3],
[0.5, 1, 1.5]] # shape: [2, 3]
- vex = ds.VectorExponentialLinearOperator(
+ vex = tfd.VectorExponentialLinearOperator(
loc=mu,
- scale=la.LinearOperatorDiag(scale_diag))
+ scale=tf.linalg.LinearOperatorDiag(scale_diag))
# Compute the pdf of two `R^3` observations; return a length-2 vector.
x = [[1.9, 2.2, 3.1],
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
index 0e3867809a..8dd983b750 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_diag.py
@@ -101,10 +101,10 @@ class VectorLaplaceDiag(
#### Examples
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 2-variate VectorLaplace.
- vla = ds.VectorLaplaceDiag(
+ vla = tfd.VectorLaplaceDiag(
loc=[1., -1],
scale_diag=[1, 2.])
@@ -118,7 +118,7 @@ class VectorLaplaceDiag(
vla.prob([-1., 0]).eval() # shape: []
# Initialize a 3-batch, 2-variate scaled-identity VectorLaplace.
- vla = ds.VectorLaplaceDiag(
+ vla = tfd.VectorLaplaceDiag(
loc=[1., -1],
scale_identity_multiplier=[1, 2., 3])
@@ -136,7 +136,7 @@ class VectorLaplaceDiag(
vla.prob([-1., 0]).eval() # shape: [3]
# Initialize a 2-batch of 3-variate VectorLaplace's.
- vla = ds.VectorLaplaceDiag(
+ vla = tfd.VectorLaplaceDiag(
loc=[[1., 2, 3],
[11, 22, 33]], # shape: [2, 3]
scale_diag=[[1., 2, 3],
diff --git a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
index c7abdbb4ca..ec485c95c1 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_laplace_linear_operator.py
@@ -109,8 +109,7 @@ class VectorLaplaceLinearOperator(
#### Examples
```python
- ds = tf.contrib.distributions
- la = tf.linalg
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate VectorLaplace with some desired covariance.
mu = [1., 2, 3]
@@ -124,9 +123,9 @@ class VectorLaplaceLinearOperator(
# [ 0.1, -0.3, 0.4]])
# Divide scale by sqrt(2) so that the final covariance will be what we want.
- vla = ds.VectorLaplaceLinearOperator(
+ vla = tfd.VectorLaplaceLinearOperator(
loc=mu,
- scale=la.LinearOperatorLowerTriangular(scale / tf.sqrt(2)))
+ scale=tf.linalg.LinearOperatorLowerTriangular(scale / tf.sqrt(2.)))
# Covariance agrees with cholesky(cov) parameterization.
vla.covariance().eval()
@@ -143,9 +142,9 @@ class VectorLaplaceLinearOperator(
scale_diag = [[1., 2, 3],
[0.5, 1, 1.5]] # shape: [2, 3]
- vla = ds.VectorLaplaceLinearOperator(
+ vla = tfd.VectorLaplaceLinearOperator(
loc=mu,
- scale=la.LinearOperatorDiag(scale_diag))
+ scale=tf.linalg.LinearOperatorDiag(scale_diag))
# Compute the pdf of two `R^3` observations; return a length-2 vector.
x = [[-0.9, 0, 0.1],
diff --git a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
index 544a871070..e1ccf11645 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_sinh_arcsinh_diag.py
@@ -143,7 +143,7 @@ class VectorSinhArcsinhDiag(transformed_distribution.TransformedDistribution):
broadcastable with `event_shape`.
distribution: `tf.Distribution`-like instance. Distribution from which `k`
iid samples are used as input to transformation `F`. Default is
- `ds.Normal(0., 1.)`.
+ `tf.distributions.Normal(loc=0., scale=1.)`.
Must be a scalar-batch, scalar-event distribution. Typically
`distribution.reparameterization_type = FULLY_REPARAMETERIZED` or it is
a function of non-trainable parameters. WARNING: If you backprop through
diff --git a/tensorflow/contrib/distributions/python/ops/vector_student_t.py b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
index 29d41ab81c..8c67647a61 100644
--- a/tensorflow/contrib/distributions/python/ops/vector_student_t.py
+++ b/tensorflow/contrib/distributions/python/ops/vector_student_t.py
@@ -91,14 +91,14 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
Extra leading dimensions, if provided, allow for batches.
```python
- ds = tf.contrib.distributions
+ tfd = tf.contrib.distributions
# Initialize a single 3-variate vector Student's t-distribution.
mu = [1., 2, 3]
chol = [[1., 0, 0.],
[1, 3, 0],
[1, 2, 3]]
- vt = ds.VectorStudentT(df=2, loc=mu, scale_tril=chol)
+ vt = tfd.VectorStudentT(df=2, loc=mu, scale_tril=chol)
# Evaluate this on an observation in R^3, returning a scalar.
vt.prob([-1., 0, 1])
@@ -107,7 +107,7 @@ class _VectorStudentT(transformed_distribution.TransformedDistribution):
mu = [[1., 2, 3],
[11, 22, 33]]
chol = ... # shape 2 x 3 x 3, lower triangular, positive diagonal.
- vt = ds.VectorStudentT(loc=mu, scale_tril=chol)
+ vt = tfd.VectorStudentT(loc=mu, scale_tril=chol)
# Evaluate this on a two observations, each in R^3, returning a length two
# tensor.
diff --git a/tensorflow/contrib/eager/python/BUILD b/tensorflow/contrib/eager/python/BUILD
index bf2e883bc5..55d768044b 100644
--- a/tensorflow/contrib/eager/python/BUILD
+++ b/tensorflow/contrib/eager/python/BUILD
@@ -232,6 +232,7 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":network",
+ "//tensorflow/contrib/layers:layers_py",
"//tensorflow/python:constant_op",
"//tensorflow/python:errors",
"//tensorflow/python:framework_test_lib",
diff --git a/tensorflow/contrib/eager/python/network.py b/tensorflow/contrib/eager/python/network.py
index 0388aaa849..e3c13cbd2e 100644
--- a/tensorflow/contrib/eager/python/network.py
+++ b/tensorflow/contrib/eager/python/network.py
@@ -451,8 +451,30 @@ class Network(base.Layer):
"at https://github.com/tensorflow/tensorflow/issues/new if this is "
"important to you")
+ def add_loss(self, losses, inputs=None):
+ raise RuntimeError(
+ "add_loss is not supported in Network class yet. Please file an issue "
+ "at https://github.com/tensorflow/tensorflow/issues/new if this is "
+ "important to you")
+
+ @property
+ def losses(self):
+ """Gather losses from `Layer`s in the `Network`.
+
+ Note that when executing eagerly, `Layer.losses` evaluates
+ regularizers. When using graph execution, variable regularization ops have
+ already been created and are simply returned here.
+
+ Returns:
+ A list of tensors.
+ """
+ layer_losses = []
+ for layer in self.layers:
+ layer_losses.extend(layer.losses)
+ return layer_losses
+
# TODO(allenl): Support other Layer methods needed for graph mode, such as for
- # losses and updates
+ # updates
class Sequential(Network):
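The new `losses` property simply concatenates `layer.losses` across tracked layers; the regularizer test in `network_test.py` below exercises it. A bare-bones usage sketch (hypothetical layer configuration, assuming the contrib eager `Network` as patched):

```python
import tensorflow as tf
from tensorflow.contrib.eager.python import network
from tensorflow.contrib.layers.python.layers import regularizers
from tensorflow.python.layers import core


class TinyNet(network.Network):

  def __init__(self):
    super(TinyNet, self).__init__()
    self.l1 = self.track_layer(core.Dense(
        1, kernel_regularizer=regularizers.l1_regularizer(2.0)))

  def call(self, x):
    return self.l1(x)


net = TinyNet()
net(tf.constant([[1.]]))  # builds variables (and, eagerly, the losses)
reg_losses = net.losses   # one L1 regularization tensor, from `l1`
```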
diff --git a/tensorflow/contrib/eager/python/network_test.py b/tensorflow/contrib/eager/python/network_test.py
index e7835a63e6..3eb4f5f8b3 100644
--- a/tensorflow/contrib/eager/python/network_test.py
+++ b/tensorflow/contrib/eager/python/network_test.py
@@ -19,6 +19,7 @@ from __future__ import print_function
import gc
from tensorflow.contrib.eager.python import network
+from tensorflow.contrib.layers.python.layers import regularizers
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.eager import test
@@ -45,6 +46,22 @@ class MyNetwork(network.Network):
return self.l1(x)
+class RegularizedNetwork(network.Network):
+
+ def __init__(self):
+ super(RegularizedNetwork, self).__init__()
+ self.l1 = self.track_layer(core.Dense(
+ 1,
+ bias_regularizer=regularizers.l1_regularizer(2.0),
+ kernel_regularizer=regularizers.l1_regularizer(2.0)))
+ self.l2 = self.track_layer(core.Dense(
+ 1,
+ bias_regularizer=regularizers.l1_regularizer(2.0)))
+
+ def call(self, values):
+ return self.l2(self.l1(values))
+
+
class NetworkTest(test.TestCase):
def _save_modify_load_network_built(self, net, global_step=None):
@@ -485,6 +502,18 @@ class NetworkTest(test.TestCase):
checked_ops=checked_ops)
@test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
+ def testVariableRegularizers(self):
+ net = RegularizedNetwork()
+ net(constant_op.constant([[1.]]))
+ self.evaluate(net.variables[0].assign([[2.]]))
+ self.evaluate(net.variables[1].assign([3.]))
+ self.evaluate(net.variables[2].assign([[-2.]]))
+ self.evaluate(net.variables[3].assign([4.]))
+ self.assertAllEqual([4., 6., 8.], self.evaluate(net.losses))
+ self.evaluate(net.variables[3].assign([5.]))
+ self.assertAllEqual([4., 6., 10.], self.evaluate(net.losses))
+
+ @test_util.run_in_graph_and_eager_modes(assert_no_eager_garbage=True)
def testDuplicateNameError(self):
one = constant_op.constant([[1.]])
net = MyNetwork(name="foo")
diff --git a/tensorflow/contrib/estimator/BUILD b/tensorflow/contrib/estimator/BUILD
index 8395e2db5e..706a174efb 100644
--- a/tensorflow/contrib/estimator/BUILD
+++ b/tensorflow/contrib/estimator/BUILD
@@ -93,6 +93,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_pip",
+ "notap", # b/62204861
"notsan",
],
deps = [
@@ -346,7 +347,7 @@ py_library(
cuda_py_test(
name = "replicate_model_fn_test",
- size = "small",
+ size = "medium",
srcs = ["python/estimator/replicate_model_fn_test.py"],
additional_deps = [
"//tensorflow/python/estimator",
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
index d9c83aa865..f5154231da 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn.py
@@ -42,10 +42,49 @@ from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import tf_logging
+from tensorflow.python.training import device_setter as device_setter_lib
from tensorflow.python.training import training_util
-def replicate_model_fn(model_fn, optimizer_fn, devices=None):
+class Mode(object):
+ """Modes for variables replication used for forcing a particular mode.
+
+ Forcing a mode is meant for performance experimentation purposes rather than
+ for general use cases.
+ """
+
+ AUTO = 0
+ """Use internal heuristics for choosing the best Mode value.
+
+ This mode is supposed to be the most appropriate in most cases given what
+ is known about the system.
+ """
+ # TODO(isaprykin): Query system configuration to choose modes other than
+ # `SHARED_LOCAL_PARAMETER_SERVER`, even though it is often appropriate.
+
+ SHARED_LOCAL_PARAMETER_SERVER = 2
+ """Variables are placed on a single device and shared across all devices.
+
+ Two ways to achieve this replication over available GPUs are supported:
+ 1) If exactly 1 GPU is detected, then variables and operations are placed
+ onto the GPU.
+ 2) If more than 1 GPU is detected, then variables are going to be placed on
+ the CPU. Replicas of operations are placed on each individual GPU.
+ """
+
+ SHARED_ROUND_ROBIN = 3
+ """Variables are placed on all devices in a round-robin fashion.
+
+ Every subsequent variable is placed on the next device. There is only one
+ copy of each variable that is shared across all devices.
+ """
+
+ # TODO(isaprykin): Implement `REPLICATED_ALL_REDUCE`.
+ REPLICATED_ALL_REDUCE = 4
+ """Variables are mirrored on all devices."""
+
+
+def replicate_model_fn(model_fn, optimizer_fn, devices=None, mode=Mode.AUTO):
"""Replicate `Estimator.model_fn` over GPUs within a single host.
The given `model_fn` specifies a single forward pass of a model. To replicate
@@ -58,14 +97,11 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None):
optimizer.
If `devices` are `None`, then all available GPUs are going to be used for
- replication. If no GPUs are available, then the model is going to be
- placed on the CPU.
+ replication: `devices=[<all available GPUs>]`. If no GPUs are available,
+ then the model is going to be placed on the CPU: `devices=['/device:CPU:0']`.
- Two modes of local replication over available GPUs are supported:
- 1) If exactly 1 GPU is detected, then variables and operations are placed
- onto GPU.
- 2) If more than 1 GPU is detected, then variables are going to be placed on
- the CPU. Replicas of operations are placed on each individual GPU.
+ Variables are placed onto `devices` according to the given `mode`. The
+ operations of each tower are going to be copied onto each device.
Here is an example of how one might use their `model_fn` to run over GPUs:
```python
@@ -127,6 +163,8 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None):
argument can be used to replicate only on a subset of available GPUs.
If `None`, then all available GPUs are going to be used for replication.
If no GPUs are available, then the model is going to be placed on the CPU.
+ mode: An optional argument that specifies the replication method used for
+ distributing variables across devices.
Returns:
A replicated version of the supplied `model_fn`. Returned function that
@@ -137,16 +175,21 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None):
devices = _get_local_devices('GPU') or _get_local_devices('CPU')
is_a_single_gpu_case = len(devices) == 1 and 'GPU' in devices[0]
- local_ps_device = '/{}:0'.format('GPU' if is_a_single_gpu_case else 'CPU')
+ consolidation_device = '/{}:0'.format('GPU'
+ if is_a_single_gpu_case else 'CPU')
+
+ ps_devices = [consolidation_device]
+ if mode == Mode.SHARED_ROUND_ROBIN:
+ ps_devices = devices
- tf_logging.info('Replicating the `model_fn` across {}. Local parameter '
- 'server device is going to be {}.'.format(
- devices, local_ps_device))
+ tf_logging.info('Replicating the `model_fn` across {}. Variables are going '
+ 'to be placed on {}. Consolidation device is going to be {}.'
+ .format(devices, ps_devices, consolidation_device))
def replicated_model_fn(features, labels, mode, params=None, config=None):
"""Replicated version of `model_fn` to be used instead."""
feature_shards, label_shards = _split_batch(
- features, labels, len(devices), device=local_ps_device)
+ features, labels, len(devices), device=consolidation_device)
tower_specs = _get_loss_towers(
model_fn=model_fn,
mode=mode,
@@ -155,17 +198,17 @@ def replicate_model_fn(model_fn, optimizer_fn, devices=None):
params=params,
config=config,
devices=devices,
- local_ps_device=local_ps_device)
+ local_ps_devices=ps_devices)
if mode == model_fn_lib.ModeKeys.TRAIN:
train_op = _minimize_towers(tower_specs,
_call_optimizer_fn(optimizer_fn, params))
return _train_spec(
- tower_specs, train_op, aggregation_device=local_ps_device)
+ tower_specs, train_op, aggregation_device=consolidation_device)
elif mode == model_fn_lib.ModeKeys.EVAL:
- return _eval_spec(tower_specs, aggregation_device=local_ps_device)
+ return _eval_spec(tower_specs, aggregation_device=consolidation_device)
elif mode == model_fn_lib.ModeKeys.PREDICT:
- return _predict_spec(tower_specs, aggregation_device=local_ps_device)
+ return _predict_spec(tower_specs, aggregation_device=consolidation_device)
return replicated_model_fn
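
For orientation, here is a minimal sketch (not part of the patch) of how a caller might pass the new `mode` argument. Only `replicate_model_fn` and `Mode` come from this change; the model, loss, and optimizer below are hypothetical placeholders.

```python
# A minimal sketch, assuming the `replicate_model_fn` and `Mode` above.
import tensorflow as tf
from tensorflow.contrib.estimator.python.estimator import replicate_model_fn


def model_fn(features, labels, mode, params):
  del params  # Unused in this sketch.
  logits = tf.layers.dense(features, units=3)
  loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
  # `replicate_model_fn` builds the real train_op from `optimizer_fn`,
  # so a placeholder no_op is enough here.
  return tf.estimator.EstimatorSpec(
      mode=mode,
      predictions={'logits': logits},
      loss=loss,
      train_op=tf.no_op())


def optimizer_fn():
  return tf.train.GradientDescentOptimizer(learning_rate=0.1)


# Variables are round-robined across the listed devices; one tower per device.
replicated = replicate_model_fn.replicate_model_fn(
    model_fn,
    optimizer_fn,
    devices=['/gpu:0', '/gpu:1'],
    mode=replicate_model_fn.Mode.SHARED_ROUND_ROBIN)
estimator = tf.estimator.Estimator(model_fn=replicated)
```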
@@ -222,7 +265,7 @@ def _get_loss_towers(model_fn,
params,
config,
devices,
- local_ps_device,
+ local_ps_devices,
name_scope_pattern=_DEFAULT_NAME_SCOPE_PATTERN):
"""Replicate the loss computation across devices."""
tower_specs = []
@@ -234,15 +277,22 @@ def _get_loss_towers(model_fn,
if 'config' in model_fn_args:
optional_params['config'] = copy.deepcopy(config)
+ # pylint: disable=protected-access
+ round_robin_strategy = device_setter_lib._RoundRobinStrategy(
+ num_tasks=len(local_ps_devices))
+ # pylint: enable=protected-access
+
for i, device in enumerate(devices):
is_the_first_tower = (i == 0)
device_setter = _local_device_setter(
- worker_device=device, ps_device=local_ps_device)
+ worker_device=device,
+ ps_devices=local_ps_devices,
+ ps_strategy=round_robin_strategy)
- # We would like to preserve the names of the variables and ops that a user
- # might be relying on. Names with prefix are going to resolve to variables
- # and ops of the first tower.
+ # We would like to preserve the names of the variables and ops that the user
+ # might be relying on. Names without a prefix are going to resolve to
+ # variables and ops of the first tower.
name_scope = name_scope_pattern
if is_the_first_tower:
name_scope = ''
@@ -263,7 +313,7 @@ def _get_loss_towers(model_fn,
return tower_specs
-def _local_device_setter(ps_device, worker_device):
+def _local_device_setter(worker_device, ps_devices, ps_strategy):
"""A device setter that puts distributes Var/Ops to PS/workers."""
ps_ops = ['Variable', 'VariableV2', 'VarHandleOp']
@@ -273,7 +323,7 @@ def _local_device_setter(ps_device, worker_device):
node_def = op if isinstance(op, node_def_pb2.NodeDef) else op.node_def
if node_def.op in ps_ops:
ps_device_spec = framework_device.DeviceSpec.from_string(
- '{}'.format(ps_device))
+ '{}'.format(ps_devices[ps_strategy(op)]))
ps_device_spec.merge_from(current_device)
return ps_device_spec.to_string()
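
To make the strategy plumbing above concrete, here is a small standalone sketch of the round-robin assignment that `_RoundRobinStrategy` performs. The `RoundRobin` class below is illustrative only, not the TensorFlow implementation.

```python
# Standalone sketch of the round-robin placement wired up above.
class RoundRobin(object):
  """Returns 0, 1, ..., num_tasks - 1 cyclically, one index per call."""

  def __init__(self, num_tasks):
    self._num_tasks = num_tasks
    self._next_index = 0

  def __call__(self, op):
    index = self._next_index
    self._next_index = (self._next_index + 1) % self._num_tasks
    return index


ps_devices = ['/device:GPU:0', '/device:GPU:1', '/device:GPU:3']
strategy = RoundRobin(num_tasks=len(ps_devices))
for name in ['a', 'b', 'c', 'd']:
  print(name, '->', ps_devices[strategy(name)])
# a -> /device:GPU:0, b -> /device:GPU:1, c -> /device:GPU:3, d -> /device:GPU:0
```

This matches the placement asserted in `test_variables_are_round_robined_correctly` below: each successive variable goes to the next parameter-server device, while ops stay on the worker device.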
diff --git a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
index ffe69f89b4..662021853d 100644
--- a/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
+++ b/tensorflow/contrib/estimator/python/estimator/replicate_model_fn_test.py
@@ -49,15 +49,29 @@ from tensorflow.python.platform import gfile
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary.writer import writer_cache
+from tensorflow.python.training import device_setter
from tensorflow.python.training import gradient_descent
+# TODO(isaprykin): Parametrize all the tests on replicate_model_fn.Mode when
+# it's supported.
class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
- def test_complete_flow(self):
+ def test_complete_flow_with_mode_auto(self):
+ return self._complete_flow_with_mode(replicate_model_fn.Mode.AUTO)
+
+ def test_complete_flow_with_mode_local_ps_server(self):
+ return self._complete_flow_with_mode(
+ replicate_model_fn.Mode.SHARED_LOCAL_PARAMETER_SERVER)
+
+ def test_complete_flow_with_mode_round_robin(self):
+ return self._complete_flow_with_mode(
+ replicate_model_fn.Mode.SHARED_ROUND_ROBIN)
+
+ def _complete_flow_with_mode(self, mode):
n_classes = 3
input_dimension = 2
batch_size = 12
@@ -109,7 +123,8 @@ class DNNClassifierIntegrationTest(test_util.TensorFlowTestCase):
model_fn=replicate_model_fn.replicate_model_fn(
estimator.model_fn,
optimizer_fn,
- devices=['/gpu:0', '/gpu:1', '/gpu:2']),
+ devices=['/gpu:0', '/gpu:1', '/gpu:2'],
+ mode=mode),
model_dir=estimator.model_dir,
config=estimator.config,
params=estimator.params)
@@ -359,7 +374,7 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
params=None,
config=None,
devices=['/gpu:0', '/gpu:1'],
- local_ps_device='/gpu:0',
+ local_ps_devices=['/gpu:0'],
name_scope_pattern='test_tower_{}')
session.run(variables.global_variables_initializer())
@@ -382,6 +397,54 @@ class GetLossTowersTest(test_util.TensorFlowTestCase):
c = variable_scope.get_variable('c', dtype=dtypes.float64)
self.assertEqual(0.25, session.run(c))
+ def test_variables_are_round_robined_correctly(self):
+ """Test that creates multiple variables and tests round-robin placement."""
+
+ def model_fn(mode, features, labels, params):
+ del params
+ for variable_name in ['a', 'b', 'c', 'd']:
+ c = variable_scope.get_variable(
+ variable_name,
+ initializer=constant_op.constant(0.25, dtype=dtypes.float64),
+ dtype=dtypes.float64)
+
+ predictions = math_ops.add(np.array([0.1, 0.2, 0.3, features[0]]), c)
+ labels = np.array([0.1, 0.2, 0.3, labels[0]])
+ loss = losses.absolute_difference(
+ labels=labels,
+ predictions=predictions,
+ reduction=losses.Reduction.SUM)
+ return model_fn_lib.EstimatorSpec(
+ mode=mode, loss=math_ops.reduce_sum(loss))
+
+ with self.test_session() as session:
+ tower_specs = replicate_model_fn._get_loss_towers(
+ model_fn,
+ mode=None,
+ features=[[0.6], [1.6], [2.6]],
+ labels=[[0.6], [0.6], [2.6]],
+ params=None,
+ config=None,
+ devices=['/gpu:0', '/gpu:1', '/gpu:3'],
+ local_ps_devices=['/gpu:0', '/gpu:1', '/gpu:3'],
+ name_scope_pattern='test_tower_{}')
+ session.run(variables.global_variables_initializer())
+
+ self.assertEqual(len(tower_specs), 3)
+ self.assertEqual('/device:GPU:0', tower_specs[0].loss.device)
+ self.assertEqual('/device:GPU:1', tower_specs[1].loss.device)
+ self.assertEqual('/device:GPU:3', tower_specs[2].loss.device)
+
+ with variable_scope.variable_scope('', reuse=True):
+ a = variable_scope.get_variable('a', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', a.device)
+ b = variable_scope.get_variable('b', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:1', b.device)
+ c = variable_scope.get_variable('c', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:3', c.device)
+ d = variable_scope.get_variable('d', dtype=dtypes.float64)
+ self.assertEqual('/device:GPU:0', d.device)
+
class SplitBatchTest(test_util.TensorFlowTestCase):
@@ -604,7 +667,7 @@ class PredictSpecTest(test_util.TensorFlowTestCase):
params=None,
config=None,
devices=['/gpu:0', '/gpu:1'],
- local_ps_device='/gpu:0',
+ local_ps_devices=['/gpu:0'],
)
session.run(variables.global_variables_initializer())
@@ -850,25 +913,66 @@ class GetLocalDevicesTest(test_util.TensorFlowTestCase):
class LocalDeviceSetterTest(test_util.TensorFlowTestCase):
def test_vars_are_on_ps_but_ops_are_on_workers(self):
+ ps_devices = ['/device:GPU:3']
+ round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
+
+ local_device_setter = replicate_model_fn._local_device_setter(
+ ps_devices=ps_devices,
+ ps_strategy=round_robin,
+ worker_device='/device:GPU:2')
+
+ with ops_lib.device(local_device_setter):
+ a = variables.Variable(0.01)
+ self.assertEqual('/device:GPU:3', a.device)
+
+ b = variables.Variable(0.02)
+ self.assertEqual('/device:GPU:3', b.device)
+
+ c = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:3', c.device)
+
+ a_op = array_ops.concat(a, axis=0)
+ self.assertEqual('/device:GPU:2', a_op.device)
+
+ b_op = array_ops.concat(b, axis=0)
+ self.assertEqual('/device:GPU:2', b_op.device)
+
+ def test_round_robin_placement(self):
+ ps_devices = [
+ '/device:GPU:0', '/device:GPU:1', '/device:GPU:3', '/device:GPU:4'
+ ]
+ round_robin = device_setter._RoundRobinStrategy(num_tasks=len(ps_devices))
+
local_device_setter = replicate_model_fn._local_device_setter(
- ps_device='/device:GPU:3', worker_device='/device:GPU:2')
+ ps_devices=ps_devices,
+ ps_strategy=round_robin,
+ worker_device='/device:GPU:2')
with ops_lib.device(local_device_setter):
- c = variables.Variable(0.01)
+ a = variables.Variable(0.01)
+ self.assertEqual('/device:GPU:0', a.device)
+
+ b = variables.Variable(0.02)
+ self.assertEqual('/device:GPU:1', b.device)
+
+ c = variables.Variable(0.03)
self.assertEqual('/device:GPU:3', c.device)
- cc = variables.Variable(0.02)
- self.assertEqual('/device:GPU:3', cc.device)
+ a_op = array_ops.concat(a, axis=0)
+ self.assertEqual('/device:GPU:2', a_op.device)
+
+ b_op = array_ops.concat(b, axis=0)
+ self.assertEqual('/device:GPU:2', b_op.device)
- ccc = variables.Variable(0.03)
- self.assertEqual('/device:GPU:3', ccc.device)
+ c = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:4', c.device)
+
+ d = variables.Variable(0.03)
+ self.assertEqual('/device:GPU:0', d.device)
c_op = array_ops.concat(c, axis=0)
self.assertEqual('/device:GPU:2', c_op.device)
- cc_op = array_ops.concat(cc, axis=0)
- self.assertEqual('/device:GPU:2', cc_op.device)
-
class ComputeSumWithDevicePlacementTest(test_util.TensorFlowTestCase):
diff --git a/tensorflow/contrib/ffmpeg/BUILD b/tensorflow/contrib/ffmpeg/BUILD
index dc5a04a0b1..eccce99071 100644
--- a/tensorflow/contrib/ffmpeg/BUILD
+++ b/tensorflow/contrib/ffmpeg/BUILD
@@ -155,7 +155,10 @@ tf_py_test(
data = [
":test_data",
],
- tags = ["manual"],
+ tags = [
+ "manual",
+ "notap",
+ ],
)
py_library(
diff --git a/tensorflow/contrib/ffmpeg/__init__.py b/tensorflow/contrib/ffmpeg/__init__.py
index 871dff7bbe..daba965a98 100644
--- a/tensorflow/contrib/ffmpeg/__init__.py
+++ b/tensorflow/contrib/ffmpeg/__init__.py
@@ -26,6 +26,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio
+from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video
from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio
-from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video
diff --git a/tensorflow/contrib/ffmpeg/decode_video_op_test.py b/tensorflow/contrib/ffmpeg/decode_video_op_test.py
index 4d1fac4ef8..b43b6b8919 100644
--- a/tensorflow/contrib/ffmpeg/decode_video_op_test.py
+++ b/tensorflow/contrib/ffmpeg/decode_video_op_test.py
@@ -20,11 +20,9 @@ from __future__ import print_function
import os.path
-import six
+import six # pylint: disable=unused-import
from tensorflow.contrib import ffmpeg
-from tensorflow.python.framework import dtypes
-from tensorflow.python.ops import array_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
@@ -32,7 +30,8 @@ from tensorflow.python.platform import test
class DecodeVideoOpTest(test.TestCase):
- def _loadFileAndTest(self, filename, width, height, frames, bmp_filename, index):
+ def _loadFileAndTest(self, filename, width, height, frames, bmp_filename,
+ index):
"""Loads an video file and validates the output tensor.
Args:
@@ -40,6 +39,8 @@ class DecodeVideoOpTest(test.TestCase):
width: The width of the video.
height: The height of the video.
frames: The frames of the video.
+ bmp_filename: The filename of the reference bmp file.
+ index: Index of the frame to compare against the bmp image.
"""
with self.test_session():
path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
@@ -48,7 +49,7 @@ class DecodeVideoOpTest(test.TestCase):
contents = f.read()
bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
- bmp_filename)
+ bmp_filename)
with open(bmp_path, 'rb') as f:
bmp_contents = f.read()
@@ -58,7 +59,7 @@ class DecodeVideoOpTest(test.TestCase):
video_op = ffmpeg.decode_video(contents)
video = video_op.eval()
self.assertEqual(video.shape, (frames, height, width, 3))
- self.assertAllEqual(video[index,:,:,:], image)
+ self.assertAllEqual(video[index, :, :, :], image)
def testMp4(self):
self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99)
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
index 201774e1d0..1245f515fe 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib.cc
@@ -220,7 +220,8 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count,
Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
uint32* frames) {
string data;
- ReadFileToString(Env::Default(), filename, &data);
+ TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data))
+ << "Could not read FFmpeg file: " << filename;
bool in_output = false;
bool in_mapping = false;
uint32 frames_value = 0;
@@ -377,7 +378,7 @@ Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data,
open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600);
if (fd < 0) {
const int error = errno;
- LOG(ERROR) << "FFmpeg stderr file coule not be created: "
+ LOG(ERROR) << "FFmpeg stderr file could not be created: "
<< strerror(error);
::_exit(error);
}
diff --git a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc
index 39e7e90ccc..36fc71794b 100644
--- a/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc
+++ b/tensorflow/contrib/ffmpeg/default/ffmpeg_lib_utility_test.cc
@@ -23,6 +23,7 @@
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/core/threadpool.h"
+#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
diff --git a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
index 78ead471d2..08b5a6ea48 100644
--- a/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
+++ b/tensorflow/contrib/ffmpeg/ffmpeg_ops.py
@@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py
+from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py
from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
-from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py
from tensorflow.contrib.util import loader
diff --git a/tensorflow/contrib/framework/python/framework/graph_util.py b/tensorflow/contrib/framework/python/framework/graph_util.py
index 6d5cde5c9e..a18ff2320d 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util.py
@@ -150,5 +150,5 @@ def get_placeholders(graph):
# The return value (a Tensor) of placeholder() is the
# first output of this operation in fact.
operations = graph.get_operations()
- result = [i.outputs[0] for i in operations if i.type == 'Placeholder']
+ result = [i.outputs[0] for i in operations if i.type == "Placeholder"]
return result
diff --git a/tensorflow/contrib/framework/python/framework/graph_util_test.py b/tensorflow/contrib/framework/python/framework/graph_util_test.py
index 0722fafc13..b8a6d109e1 100644
--- a/tensorflow/contrib/framework/python/framework/graph_util_test.py
+++ b/tensorflow/contrib/framework/python/framework/graph_util_test.py
@@ -90,8 +90,9 @@ class GetPlaceholdersTest(test.TestCase):
with ops.Graph().as_default() as g:
placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)]
results = graph_util.get_placeholders(g)
- self.assertEqual(sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access
- sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access
+ self.assertEqual(
+ sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access
+ sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access
if __name__ == '__main__':
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
index 88306094ab..5fec69ea43 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
@@ -493,6 +493,8 @@ void LaunchFusedConv2DBiasActivationOp<GPUDevice, T, BiasType, ScaleType>::
{{conv_input_rows, conv_input_cols}},
output_depth,
{{filter_rows, filter_cols}},
+ // TODO(yangzihao): Add support for arbitrary dilations for fused conv.
+ {{1, 1}}, // dilation_rows, dilation_cols
{{row_stride, col_stride}},
{{padding_rows, padding_cols}},
conv_input->dtype(),
diff --git a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
index dc43af1158..fa7a3c03aa 100644
--- a/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
+++ b/tensorflow/contrib/fused_conv/kernels/fused_conv_ops_gpu.h
@@ -30,11 +30,12 @@ class FusedConvParameters : public ConvParameters {
public:
FusedConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
int64 out_depths, const SpatialArray& filter,
- const SpatialArray& stride, const SpatialArray& padding,
- DataType dtype, int device_id, bool has_side_input,
+ const SpatialArray& dilation, const SpatialArray& stride,
+ const SpatialArray& padding, DataType dtype,
+ int device_id, bool has_side_input,
ActivationMode activation_mode)
- : ConvParameters(batch, in_depths, in, out_depths, filter, stride,
- padding, dtype, device_id),
+ : ConvParameters(batch, in_depths, in, out_depths, filter, dilation,
+ stride, padding, dtype, device_id),
activation_mode_(activation_mode),
has_side_input_(has_side_input) {
hash_code_ = Hash64Combine(hash_code_, has_side_input);
diff --git a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
index 887ebc5a6c..6a56237f67 100644
--- a/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
+++ b/tensorflow/contrib/fused_conv/ops/fused_conv2d_bias_activation_op.cc
@@ -52,6 +52,7 @@ REGISTER_OP("FusedConv2DBiasActivation")
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Attr("filter_format: {'HWIO', 'OIHW', 'OIHW_VECT_I'} = 'HWIO'")
.Attr("activation_mode: {'Relu'} = 'Relu'")
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](shape_inference::InferenceContext* c) {
using shape_inference::ShapeHandle;
using shape_inference::DimensionHandle;
@@ -151,6 +152,11 @@ REGISTER_OP("FusedConv2DBiasActivation")
kernel_height, kernel_width, input_channels % 4 ]`
activation_mode: The activation applied to the output.
Currently must be "Relu".
+ dilations: 1-D tensor of length 4. The dilation factor for each dimension
+ of `input`. If set to k > 1, there will be k-1 skipped cells between
+ each filter element on that dimension. The dimension order is determined
+ by the value of `data_format`, see above for details. Dilations in the
+ batch and depth dimensions must be 1.
)doc");
} // namespace tensorflow
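
As a quick sanity check of the `dilations` semantics documented above (k - 1 skipped cells between filter taps), the effective filter extent and the VALID-padding output size follow from standard dilated-convolution arithmetic. The helpers below are hypothetical, not part of the op:

```python
# Worked example of the dilation semantics documented above.
def effective_filter_size(k, dilation):
  """A k-tap filter with dilation - 1 skipped cells between taps."""
  return k + (k - 1) * (dilation - 1)


def valid_output_size(input_size, k, dilation, stride):
  """Output extent along one dimension under VALID padding."""
  return (input_size - effective_filter_size(k, dilation)) // stride + 1


assert effective_filter_size(3, 2) == 5   # A dilated 3x3 tap covers 5 cells.
assert valid_output_size(32, 3, 2, 1) == 28
```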
diff --git a/tensorflow/contrib/graph_editor/transform.py b/tensorflow/contrib/graph_editor/transform.py
index 2a97a79070..14ac529665 100644
--- a/tensorflow/contrib/graph_editor/transform.py
+++ b/tensorflow/contrib/graph_editor/transform.py
@@ -173,6 +173,9 @@ def copy_op_handler(info, op, copy_shape=True):
if op._original_op:
op_._original_op = op._original_op
+ # Add op to the graph
+ info.graph_._add_op(op_)
+
return op_, op_.outputs
diff --git a/tensorflow/contrib/kfac/python/ops/fisher_factors.py b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
index fbc192f1dc..6c1dd0ae40 100644
--- a/tensorflow/contrib/kfac/python/ops/fisher_factors.py
+++ b/tensorflow/contrib/kfac/python/ops/fisher_factors.py
@@ -580,6 +580,9 @@ class ConvDiagonalFactor(DiagonalFactor):
# the target entry of _outputs_grads changes with idx.)
with _maybe_colocate_with(inputs, self._colocate_cov_ops_with_inputs):
filter_height, filter_width, _, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms of
+ # memory use.
patches = array_ops.extract_image_patches(
inputs,
ksizes=[1, filter_height, filter_width, 1],
@@ -739,6 +742,9 @@ class ConvInputKroneckerFactor(InverseProvidingFactor):
# TODO(jamesmartens): factor this patches stuff out into a utility function
with _maybe_colocate_with(self._inputs, self._colocate_cov_ops_with_inputs):
filter_height, filter_width, in_channels, _ = self._filter_shape
+
+ # TODO(b/64144716): there is potential here for a big savings in terms of
+ # memory use.
patches = array_ops.extract_image_patches(
self._inputs,
ksizes=[1, filter_height, filter_width, 1],
diff --git a/tensorflow/contrib/layers/python/layers/feature_column.py b/tensorflow/contrib/layers/python/layers/feature_column.py
index 226d933d85..092d418c3f 100644
--- a/tensorflow/contrib/layers/python/layers/feature_column.py
+++ b/tensorflow/contrib/layers/python/layers/feature_column.py
@@ -521,7 +521,7 @@ def sparse_column_with_integerized_feature(column_name,
Args:
column_name: A string defining sparse column name.
- bucket_size: An int that is > 1. The number of buckets. It should be bigger
+ bucket_size: An int that is >= 1. The number of buckets. It should be bigger
than the maximum feature value. In other words, features in this column should be an
int64 in range [0, bucket_size)
combiner: A string specifying how to reduce if the sparse column is
@@ -539,7 +539,7 @@ def sparse_column_with_integerized_feature(column_name,
An integerized _SparseColumn definition.
Raises:
- ValueError: bucket_size is not greater than 1.
+ ValueError: bucket_size is less than 1.
ValueError: dtype is not integer.
"""
return _SparseColumnIntegerized(
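
For context, a hedged usage sketch of this feature column; the column name and bucket count below are made up for illustration:

```python
# Hypothetical usage: features are already integers in [0, 7), so no
# hashing or vocabulary lookup is needed.
import tensorflow as tf

day_of_week = tf.contrib.layers.sparse_column_with_integerized_feature(
    column_name='day_of_week', bucket_size=7)
```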
diff --git a/tensorflow/contrib/layers/python/layers/layers.py b/tensorflow/contrib/layers/python/layers/layers.py
index 6cd586a5f0..6569b7ec9a 100644
--- a/tensorflow/contrib/layers/python/layers/layers.py
+++ b/tensorflow/contrib/layers/python/layers/layers.py
@@ -2561,7 +2561,10 @@ def separable_convolution2d(
regularizer=weights_regularizer,
trainable=trainable,
collections=weights_collections)
- strides = [1, 1, stride_h, stride_w] if data_format.startswith('NC') else [1, stride_h, stride_w, 1]
+ strides = ([1, 1, stride_h, stride_w]
+            if data_format.startswith('NC')
+            else [1, stride_h, stride_w, 1])
outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding,
rate=utils.two_element_tuple(rate),
diff --git a/tensorflow/contrib/layers/python/layers/layers_test.py b/tensorflow/contrib/layers/python/layers/layers_test.py
index a05e464a26..ae64b75d93 100644
--- a/tensorflow/contrib/layers/python/layers/layers_test.py
+++ b/tensorflow/contrib/layers/python/layers/layers_test.py
@@ -3332,11 +3332,18 @@ class SeparableConv2dTest(test.TestCase):
batch, height, width = 4, 10, 12
kernel_dim, stride = 3, 2
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
- output = layers_lib.separable_conv2d(images, num_outputs=num_filters, kernel_size=[kernel_dim, kernel_dim],
- depth_multiplier=2, stride=stride, padding='VALID', data_format='NCHW')
- self.assertListEqual(
- output.get_shape().as_list(), [batch, correct_output_filters,
- (height - kernel_dim + 1) // stride, (width - kernel_dim + 1) // stride])
+ output = layers_lib.separable_conv2d(
+ images,
+ num_outputs=num_filters,
+ kernel_size=[kernel_dim, kernel_dim],
+ depth_multiplier=2,
+ stride=stride,
+ padding='VALID',
+ data_format='NCHW')
+ self.assertListEqual(output.get_shape().as_list(), [
+ batch, correct_output_filters, (height - kernel_dim + 1) // stride,
+ (width - kernel_dim + 1) // stride
+ ])
class ScaleGradientTests(test.TestCase):
diff --git a/tensorflow/contrib/learn/BUILD b/tensorflow/contrib/learn/BUILD
index 94920db574..26bbcab307 100644
--- a/tensorflow/contrib/learn/BUILD
+++ b/tensorflow/contrib/learn/BUILD
@@ -461,6 +461,7 @@ py_test(
size = "medium",
srcs = ["python/learn/estimators/state_saving_rnn_estimator_test.py"],
srcs_version = "PY2AND3",
+ tags = ["noasan"],
deps = [
":learn",
"//tensorflow/contrib/layers:layers_py",
diff --git a/tensorflow/contrib/lite/build_ios_universal_lib.sh b/tensorflow/contrib/lite/build_ios_universal_lib.sh
index e0f2ef768b..cbc96e6edd 100755
--- a/tensorflow/contrib/lite/build_ios_universal_lib.sh
+++ b/tensorflow/contrib/lite/build_ios_universal_lib.sh
@@ -1,4 +1,19 @@
#!/bin/bash -x
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
set -e
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8
diff --git a/tensorflow/contrib/lite/download_dependencies.sh b/tensorflow/contrib/lite/download_dependencies.sh
index 571d857be7..7fce1ba346 100755
--- a/tensorflow/contrib/lite/download_dependencies.sh
+++ b/tensorflow/contrib/lite/download_dependencies.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h
index 75b1f1da38..94046d9728 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h
+++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.h
@@ -14,8 +14,8 @@
#import <UIKit/UIKit.h>
-@interface AppDelegate : UIResponder <UIApplicationDelegate>
+@interface AppDelegate : UIResponder<UIApplicationDelegate>
-@property (strong, nonatomic) UIWindow *window;
+@property(strong, nonatomic) UIWindow *window;
@end
diff --git a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm
index 1e808eb976..d1215fa0bf 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/AppDelegate.mm
@@ -22,8 +22,7 @@
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
UITabBarController *bar = [[UITabBarController alloc] init];
- [bar setViewControllers:
- @[[[RunModelViewController alloc] init]]];
+ [bar setViewControllers:@[ [[RunModelViewController alloc] init] ]];
bar.selectedIndex = 0;
self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]];
self.window.rootViewController = bar;
@@ -31,14 +30,19 @@
return YES;
}
-- (void)applicationWillResignActive:(UIApplication *)application {}
+- (void)applicationWillResignActive:(UIApplication *)application {
+}
-- (void)applicationDidEnterBackground:(UIApplication *)application {}
+- (void)applicationDidEnterBackground:(UIApplication *)application {
+}
-- (void)applicationWillEnterForeground:(UIApplication *)application {}
+- (void)applicationWillEnterForeground:(UIApplication *)application {
+}
-- (void)applicationDidBecomeActive:(UIApplication *)application {}
+- (void)applicationDidBecomeActive:(UIApplication *)application {
+}
-- (void)applicationWillTerminate:(UIApplication *)application {}
+- (void)applicationWillTerminate:(UIApplication *)application {
+}
@end
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h
index 4e1a83ccf5..a4b358b4eb 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.h
@@ -18,7 +18,7 @@
- (IBAction)getUrl:(id)sender;
-@property (weak, nonatomic) IBOutlet UITextView *urlContentTextView;
-@property (weak, nonatomic) IBOutlet UITextField *urlTextField;
+@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView;
+@property(weak, nonatomic) IBOutlet UITextField *urlTextField;
@end
diff --git a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
index 965d830105..0dafb1f61e 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/RunModelViewController.mm
@@ -14,10 +14,10 @@
#import "RunModelViewController.h"
-#include <fstream>
-#include <iostream>
#include <pthread.h>
#include <unistd.h>
+#include <fstream>
+#include <iostream>
#include <queue>
#include <sstream>
#include <string>
@@ -30,7 +30,11 @@
#include "ios_image_load.h"
#define LOG(x) std::cerr
-#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); }
+#define CHECK(x) \
+ if (!(x)) { \
+ LOG(ERROR) << #x << "failed"; \
+ exit(1); \
+ }
NSString* RunInferenceOnImage();
@@ -49,15 +53,12 @@ NSString* RunInferenceOnImage();
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
-static void GetTopN(
- const float* prediction,
- const int prediction_size,
- const int num_results, const float threshold,
- std::vector<std::pair<float, int> >* top_results) {
+static void GetTopN(const float* prediction, const int prediction_size, const int num_results,
+ const float threshold, std::vector<std::pair<float, int> >* top_results) {
// Will contain top N results in ascending order.
- std::priority_queue<std::pair<float, int>,
- std::vector<std::pair<float, int> >,
- std::greater<std::pair<float, int> > > top_result_pq;
+ std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int> >,
+ std::greater<std::pair<float, int> > >
+ top_result_pq;
const long count = prediction_size;
for (int i = 0; i < count; ++i) {
@@ -88,8 +89,8 @@ static void GetTopN(
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
- LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "."
- << [extension UTF8String] << "' in bundle.";
+ LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String]
+ << "' in bundle.";
}
return file_path;
}
@@ -102,7 +103,8 @@ NSString* RunInferenceOnImage() {
NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite");
- std::unique_ptr<tflite::FlatBufferModel> model(tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
+ std::unique_ptr<tflite::FlatBufferModel> model(
+ tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
if (!model) {
LOG(FATAL) << "Failed to mmap model " << graph;
}
@@ -143,7 +145,7 @@ NSString* RunInferenceOnImage() {
std::ifstream t;
t.open([labels_path UTF8String]);
std::string line;
- while(t){
+ while (t) {
std::getline(t, line);
label_strings.push_back(line);
}
@@ -154,7 +156,8 @@ NSString* RunInferenceOnImage() {
int image_width;
int image_height;
int image_channels;
- std::vector<uint8_t> image_data = LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
+ std::vector<uint8_t> image_data =
+ LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
const int wanted_width = 224;
const int wanted_height = 224;
const int wanted_channels = 3;
@@ -212,8 +215,7 @@ NSString* RunInferenceOnImage() {
std::string predictions = ss.str();
NSString* result = @"";
- result = [NSString stringWithFormat: @"%@ - %s", result,
- predictions.c_str()];
-
+ result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()];
+
return result;
}
diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
index 7287d0d63d..98934ce41d 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
+++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.h
@@ -17,9 +17,7 @@
#include <vector>
-std::vector<uint8_t> LoadImageFromFile(const char* file_name,
- int* out_width,
- int* out_height,
- int* out_channels);
+std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width,
+ int* out_height, int* out_channels);
#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
diff --git a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm
index 789522d2a9..cb0fe1a765 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/ios_image_load.mm
@@ -14,17 +14,16 @@
#include "ios_image_load.h"
-#include <stdlib.h>
-#include <string.h>
#include <assert.h>
#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
#import <CoreImage/CoreImage.h>
#import <ImageIO/ImageIO.h>
-std::vector<uint8_t> LoadImageFromFile(const char* file_name,
- int* out_width, int* out_height,
- int* out_channels) {
+std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width, int* out_height,
+ int* out_channels) {
FILE* file_handle = fopen(file_name, "rb");
fseek(file_handle, 0, SEEK_END);
const size_t bytes_in_file = ftell(file_handle);
@@ -32,11 +31,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name,
std::vector<uint8_t> file_data(bytes_in_file);
fread(file_data.data(), 1, bytes_in_file, file_handle);
fclose(file_handle);
- CFDataRef file_data_ref = CFDataCreateWithBytesNoCopy(NULL, file_data.data(),
- bytes_in_file,
- kCFAllocatorNull);
- CGDataProviderRef image_provider =
- CGDataProviderCreateWithCFData(file_data_ref);
+
+ CFDataRef file_data_ref =
+ CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull);
+ CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref);
const char* suffix = strrchr(file_name, '.');
if (!suffix || suffix == file_name) {
@@ -44,12 +42,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name,
}
CGImageRef image;
if (strcasecmp(suffix, ".png") == 0) {
- image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true,
- kCGRenderingIntentDefault);
- } else if ((strcasecmp(suffix, ".jpg") == 0) ||
- (strcasecmp(suffix, ".jpeg") == 0)) {
- image = CGImageCreateWithJPEGDataProvider(image_provider, NULL, true,
- kCGRenderingIntentDefault);
+ image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault);
+ } else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) {
+ image =
+ CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault);
} else {
CFRelease(image_provider);
CFRelease(file_data_ref);
@@ -68,9 +64,10 @@ std::vector<uint8_t> LoadImageFromFile(const char* file_name,
const int bytes_in_image = (bytes_per_row * height);
std::vector<uint8_t> result(bytes_in_image);
const int bits_per_component = 8;
- CGContextRef context = CGBitmapContextCreate(result.data(), width, height,
- bits_per_component, bytes_per_row, color_space,
- kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
+
+ CGContextRef context =
+ CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row,
+ color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
CGColorSpaceRelease(color_space);
CGContextDrawImage(context, CGRectMake(0, 0, width, height), image);
CGContextRelease(context);
diff --git a/tensorflow/contrib/lite/examples/ios/simple/main.mm b/tensorflow/contrib/lite/examples/ios/simple/main.mm
index d70550a730..05cb55ddd7 100644
--- a/tensorflow/contrib/lite/examples/ios/simple/main.mm
+++ b/tensorflow/contrib/lite/examples/ios/simple/main.mm
@@ -14,7 +14,7 @@
#import <UIKit/UIKit.h>
-int main(int argc, char * argv[]) {
+int main(int argc, char *argv[]) {
@autoreleasepool {
NSString *delegateClassName = @"AppDelegate";
return UIApplicationMain(argc, argv, nil, delegateClassName);
diff --git a/tensorflow/contrib/lite/ios_makefile.inc b/tensorflow/contrib/lite/ios_makefile.inc
index bcff7ed988..345ed26212 100644
--- a/tensorflow/contrib/lite/ios_makefile.inc
+++ b/tensorflow/contrib/lite/ios_makefile.inc
@@ -1,47 +1,31 @@
-# Settings for iOS.
-ifeq ($(TARGET), IOS)
- BUILD_FOR_IOS_SIMULATOR := false
- ifeq ($(IOS_ARCH), x86_64)
- BUILD_FOR_IOS_SIMULATOR := true
- endif
- ifeq ($(IOS_ARCH), i386)
- BUILD_FOR_IOS_SIMULATOR := true
- endif
- ifeq ($(BUILD_FOR_IOS_SIMULATOR), true)
- IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \
- --show-sdk-platform-path)
- IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \
- --show-sdk-path)
- else
- IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path)
- IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path)
- endif
- IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version)
- MIN_SDK_VERSION := 9.0
- # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64.
- IOS_ARCH := x86_64
- CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
- -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
- -fembed-bitcode \
- -Wno-c++11-narrowing \
- -mno-thumb \
- -fno-exceptions \
- -isysroot \
- ${IPHONEOS_SYSROOT} \
- -arch $(IOS_ARCH) \
- -O3
- CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
- -fembed-bitcode \
- -mno-thumb \
- -isysroot \
- ${IPHONEOS_SYSROOT} \
- -arch $(IOS_ARCH) \
- -O3
- LDFLAGS := -fembed-bitcode \
- -miphoneos-version-min=${MIN_SDK_VERSION} \
- -arch $(IOS_ARCH)
- OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/
- LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/
- BINDIR := $(BINDIR)ios_$(IOS_ARCH)/
- DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/
-endif
+# Settings for iOS.
+ifeq ($(TARGET), IOS)
+  BUILD_FOR_IOS_SIMULATOR := false
+  ifeq ($(IOS_ARCH), x86_64)
+    BUILD_FOR_IOS_SIMULATOR := true
+  endif
+  ifeq ($(IOS_ARCH), i386)
+    BUILD_FOR_IOS_SIMULATOR := true
+  endif
+  ifeq ($(BUILD_FOR_IOS_SIMULATOR), true)
+    IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator \
+      --show-sdk-platform-path)
+    IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator \
+      --show-sdk-path)
+  else
+    IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path)
+    IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path)
+  endif
+  IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version)
+  MIN_SDK_VERSION := 9.0
+  # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64.
+  IOS_ARCH := x86_64
+  CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
+    -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
+    -fembed-bitcode \
+    -Wno-c++11-narrowing \
+    -mno-thumb \
+    -fno-exceptions \
+    -isysroot ${IPHONEOS_SYSROOT} \
+    -arch $(IOS_ARCH) \
+    -O3
+  CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) \
+    -fembed-bitcode \
+    -mno-thumb \
+    -isysroot ${IPHONEOS_SYSROOT} \
+    -arch $(IOS_ARCH) \
+    -O3
+  LDFLAGS := -fembed-bitcode \
+    -miphoneos-version-min=${MIN_SDK_VERSION} \
+    -arch $(IOS_ARCH)
+  OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/
+  LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/
+  BINDIR := $(BINDIR)ios_$(IOS_ARCH)/
+  DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/
+endif
diff --git a/tensorflow/contrib/lite/java/demo/README.md b/tensorflow/contrib/lite/java/demo/README.md
index 71b633c577..5d13a798e2 100644
--- a/tensorflow/contrib/lite/java/demo/README.md
+++ b/tensorflow/contrib/lite/java/demo/README.md
@@ -8,7 +8,12 @@
It's easiest with Android Studio.
- You'll need at least SDK version 23.
+ - Make sure to install the latest version of Bazel. Some distributions
+ ship with Bazel 0.5.4, which is too old.
- Bazel requires Android Build Tools `26.0.1` or higher.
+ - **Bazel is incompatible with NDK revisions 15 and above,** with revision
+ 16 being a compile-breaking change. [Download an older version manually
+ instead of using the SDK Manager.](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android#install-bazel-and-android-prerequisites)
- You also need to install the Android Support Repository, available
through Android Studio under `Android SDK Manager -> SDK Tools ->
Android Support Repository`.
@@ -19,7 +24,8 @@
- Make sure the `api_level` in `WORKSPACE` is set to an SDK version that
you have installed.
- By default, Android Studio will install the SDK to `~/Android/Sdk` and
- the NDK to `~/Android/Sdk/ndk-bundle`.
+ the NDK to `~/Android/Sdk/ndk-bundle` (but the NDK should be a manual
+ download until Bazel supports NDK 16; see the bullet points under (1)).
2. Build the app with Bazel. The demo needs C++11:
diff --git a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc
index 30d89a1354..bf95b313f3 100644
--- a/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc
+++ b/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
-// Unit test for speech TERSE AM model using TFLite Ops.
+// Unit test for speech ASR AM model using TFLite Ops.
#include <string.h>
@@ -45,10 +45,10 @@ constexpr int kLstmLayer5OutputStateTensor = 103;
constexpr int kLstmLayer5CellStateTensor = 104;
constexpr int kModelOutputTensor = 109;
-TEST(SpeechTerseAm, RandomIOTest) {
+TEST(SpeechAsrAm, RandomIOTest) {
// Read the model.
string tflite_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_am_model.tflite");
+ file::JoinPath(TestDataPath(), "speech_asr_am_model.tflite");
auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
CHECK(model) << "Failed to mmap model " << tflite_file_path;
@@ -62,13 +62,13 @@ TEST(SpeechTerseAm, RandomIOTest) {
// Load the input frames.
Frames input_frames;
const string input_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_am_model_in.csv");
+ file::JoinPath(TestDataPath(), "speech_asr_am_model_in.csv");
ReadFrames(input_file_path, &input_frames);
// Load the golden output results.
Frames output_frames;
const string output_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_am_model_out.csv");
+ file::JoinPath(TestDataPath(), "speech_asr_am_model_out.csv");
ReadFrames(output_file_path, &output_frames);
const int speech_batch_size =
diff --git a/tensorflow/contrib/lite/models/speech_terse_lm_model_test.cc b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc
index 04c54ffb22..53f2b66da4 100644
--- a/tensorflow/contrib/lite/models/speech_terse_lm_model_test.cc
+++ b/tensorflow/contrib/lite/models/speech_asr_lm_model_test.cc
@@ -59,10 +59,10 @@ static void ClearLstmStates(Interpreter* interpreter) {
interpreter->tensor(kLstmLayer3CellStateTensor)->bytes);
}
-TEST(SpeechTerseLm, EndToEndTest) {
+TEST(SpeechAsrLm, EndToEndTest) {
// Read the model.
string tflite_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_lm_model.tflite");
+ file::JoinPath(TestDataPath(), "speech_asr_lm_model.tflite");
auto model = FlatBufferModel::BuildFromFile(tflite_file_path.c_str());
CHECK(model) << "Failed to mmap model " << tflite_file_path;
@@ -76,13 +76,13 @@ TEST(SpeechTerseLm, EndToEndTest) {
// Load the input frames.
Frames input_frames;
const string input_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_lm_model_in.csv");
+ file::JoinPath(TestDataPath(), "speech_asr_lm_model_in.csv");
ReadFrames(input_file_path, &input_frames);
// Load the golden output results.
Frames output_frames;
const string output_file_path =
- file::JoinPath(TestDataPath(), "speech_terse_lm_model_out.csv");
+ file::JoinPath(TestDataPath(), "speech_asr_lm_model_out.csv");
ReadFrames(output_file_path, &output_frames);
CHECK_EQ(interpreter->tensor(kModelInput1Tensor)->dims->size, 1);
diff --git a/tensorflow/contrib/lite/models/testdata/g3doc/README.md b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
index c9630c00db..46b24248f0 100644
--- a/tensorflow/contrib/lite/models/testdata/g3doc/README.md
+++ b/tensorflow/contrib/lite/models/testdata/g3doc/README.md
@@ -86,25 +86,34 @@ same input.
### Models:
-[Speech hotword model (Svdf rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite)
+[Speech hotword model (Svdf
+rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite)
-[Speech hotword model (Svdf rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite)
+[Speech hotword model (Svdf
+rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite)
-[Speaker-id model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite)
+[Speaker-id
+model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite)
-[TTS model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite)
+[TTS
+model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite)
-[ASR AM model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite)
+[ASR AM
+model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite)
### Test benches
-[Speech hotword model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc)
+[Speech hotword model
+test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc)
-[Speaker-id model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc)
+[Speaker-id model
+test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc)
-[TTS model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc)
+[TTS model
+test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc)
-[ASR AM model test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc)
+[ASR AM model
+test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_asr_am_model_test.cc)
## Android Support
The models have been tested on Android phones, using the following tests:
diff --git a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
index b78e958e7f..bdb5e01538 100644
--- a/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
+++ b/tensorflow/contrib/lite/nnapi/NeuralNetworksShim.h
@@ -1454,9 +1454,9 @@ inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
* {@link ANeuralNetworksExecution_setOutputFromMemory} and
* {@link ANeuralNetworksExecution_setOperandValue}.
*
- * To build a model that can accommodate inputs of various sizes, as you may want
- * to do for a CNN, set the size of the dimensions that will vary at run time to
- * 0. If you do so, provide the full dimensions when calling
+ * To build a model that can accommodate inputs of various sizes, as you may
+ * want to do for a CNN, set the size of the dimensions that will vary at run
+ * time to 0. If you do so, provide the full dimensions when calling
* {@link ANeuralNetworksExecution_setInput} or {@link
* ANeuralNetworksExecution_setInputFromMemory}.
*
diff --git a/tensorflow/contrib/lite/python/lite.py b/tensorflow/contrib/lite/python/lite.py
index 0fd70f842b..982ea90f2b 100644
--- a/tensorflow/contrib/lite/python/lite.py
+++ b/tensorflow/contrib/lite/python/lite.py
@@ -50,7 +50,7 @@ GRAPHVIZ_DOT = _toco_flags_pb2.GRAPHVIZ_DOT
# to protect against crashes. However, it breaks some dependent targets because
# it forces us to depend on an external py_binary. The experimental API doesn't
# have that drawback.
-EXPERIMENTAL_USE_TOCO_API_DIRECTLY = True
+EXPERIMENTAL_USE_TOCO_API_DIRECTLY = False
# Find the toco_from_protos binary using the resource loader if using from
# bazel, otherwise we are in a pip where console_scripts already has
diff --git a/tensorflow/contrib/lite/tools/benchmark_model.cc b/tensorflow/contrib/lite/tools/benchmark_model.cc
index f80949b23e..6ae3ab5729 100644
--- a/tensorflow/contrib/lite/tools/benchmark_model.cc
+++ b/tensorflow/contrib/lite/tools/benchmark_model.cc
@@ -31,7 +31,12 @@ void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
#endif
#define LOG(x) std::cerr
-#define CHECK(x) if (!(x)) { LOG(ERROR) << #x << "failed"; exit(1); }
+
+#define CHECK(x) \
+ if (!(x)) { \
+ LOG(ERROR) << #x << "failed"; \
+ exit(1); \
+ }
namespace tensorflow {
namespace benchmark_tflite_model {
diff --git a/tensorflow/contrib/lite/tools/mutable_op_resolver.h b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
index 8206a5481d..be60cf476d 100644
--- a/tensorflow/contrib/lite/tools/mutable_op_resolver.h
+++ b/tensorflow/contrib/lite/tools/mutable_op_resolver.h
@@ -20,15 +20,14 @@ limitations under the License.
#include "tensorflow/contrib/lite/model.h"
// Needed to resolve unordered_set hash on older compilers.
-namespace std
-{
-template<>
- struct hash<tflite::BuiltinOperator> {
- size_t operator()(const tflite::BuiltinOperator &op) const {
- return std::hash<int>()(op);
- }
- };
-}
+namespace std {
+template <>
+struct hash<tflite::BuiltinOperator> {
+ size_t operator()(const tflite::BuiltinOperator& op) const {
+ return std::hash<int>()(op);
+ }
+};
+} // namespace std
namespace tflite {
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager.cc b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
index 31a35b0d53..913935b382 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager.cc
@@ -258,9 +258,37 @@ NcclManager::Communicator* NcclManager::GetCommunicator(
devices[i] = collective->participants[i]->gpu_device_id;
}
+ int device_count = num_devices;
+#if NCCL_MAJOR >= 2
+ // NCCL2 prevents InitAll for more communicators than devices (but doesn't
+ // check that device ids are unique). Work around it by initializing each
+ // rank individually.
+ cudaGetDeviceCount(&device_count);
+#endif
std::vector<ncclComm_t> nccl_comms(num_devices);
- auto result = ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
- CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
+ if (num_devices <= device_count) {
+ auto result =
+ ncclCommInitAll(nccl_comms.data(), num_devices, devices.data());
+ CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
+ } else {
+ int savedDevice = 0;
+ CHECK_EQ(cudaGetDevice(&savedDevice), cudaSuccess);
+ ncclUniqueId commId;
+ ncclGetUniqueId(&commId);
+#if NCCL_MAJOR >= 2
+ CHECK_EQ(ncclGroupStart(), ncclSuccess);
+#endif
+ for (int rank = 0; rank < num_devices; ++rank) {
+ cudaSetDevice(devices[rank]);
+ auto result =
+ ncclCommInitRank(nccl_comms.data() + rank, num_devices, commId, rank);
+ CHECK_EQ(result, ncclSuccess) << ncclGetErrorString(result);
+ }
+#if NCCL_MAJOR >= 2
+ CHECK_EQ(ncclGroupEnd(), ncclSuccess);
+#endif
+ cudaSetDevice(savedDevice);
+ }
for (int rank = 0; rank < num_devices; ++rank) {
members[rank].nccl_comm = nccl_comms[rank];
}
diff --git a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
index 505c4b0d71..abafe4b407 100644
--- a/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
+++ b/tensorflow/contrib/nccl/kernels/nccl_manager_test.cc
@@ -30,6 +30,8 @@ namespace tensorflow {
static std::vector<BaseGPUDevice*> GetGPUDevices() {
std::vector<Device*> devices;
SessionOptions session_options;
+ session_options.config.mutable_gpu_options()
+ ->set_per_process_gpu_memory_fraction(0.1);
session_options.env = Env::Default();
Status s = DeviceFactory::GetFactory(DEVICE_GPU)
->AddDevices(session_options, "", &devices);
diff --git a/tensorflow/contrib/summary/summary.py b/tensorflow/contrib/summary/summary.py
index f783179f61..9e6af5232f 100644
--- a/tensorflow/contrib/summary/summary.py
+++ b/tensorflow/contrib/summary/summary.py
@@ -31,6 +31,7 @@ from tensorflow.contrib.summary.summary_ops import audio
from tensorflow.contrib.summary.summary_ops import create_summary_db_writer
from tensorflow.contrib.summary.summary_ops import create_summary_file_writer
from tensorflow.contrib.summary.summary_ops import eval_dir
+from tensorflow.contrib.summary.summary_ops import flush
from tensorflow.contrib.summary.summary_ops import generic
from tensorflow.contrib.summary.summary_ops import graph
from tensorflow.contrib.summary.summary_ops import histogram
diff --git a/tensorflow/contrib/summary/summary_ops.py b/tensorflow/contrib/summary/summary_ops.py
index 8e37987cb7..de6f2cd79f 100644
--- a/tensorflow/contrib/summary/summary_ops.py
+++ b/tensorflow/contrib/summary/summary_ops.py
@@ -516,6 +516,27 @@ def import_event(tensor, name=None):
context.context().summary_writer_resource, tensor, name=name)
+def flush(writer=None, name=None):
+ """Forces summary writer to send any buffered data to storage.
+
+ This operation blocks until that finishes.
+
+ Args:
+ writer: The @{tf.contrib.summary.SummaryWriter} resource to flush.
+ The thread default will be used if this parameter is None.
+ If there is no thread default either, a @{tf.no_op} is returned.
+ name: A name for the operation (optional).
+
+ Returns:
+ The created @{tf.Operation}.
+ """
+ if writer is None:
+ writer = context.context().summary_writer_resource
+ if writer is None:
+ return control_flow_ops.no_op()
+ return gen_summary_ops.flush_summary_writer(writer, name=name)
+
+
def eval_dir(model_dir, name=None):
"""Construct a logdir for an eval summary writer."""
return os.path.join(model_dir, "eval" if not name else "eval_" + name)
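
A short usage sketch for the new `flush()`, assuming eager execution; it mirrors the `testFlush` case added in the test file below:

```python
# Minimal sketch: force buffered summary events to disk.
import tempfile
from tensorflow.contrib.summary import summary_ops

logdir = tempfile.mkdtemp()
writer = summary_ops.create_summary_file_writer(logdir, max_queue=999999)
with writer.as_default(), summary_ops.always_record_summaries():
  summary_ops.scalar('loss', 0.5, step=1)  # Buffered, not yet on disk.
  summary_ops.flush()                      # Blocks until events are written.
```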
diff --git a/tensorflow/contrib/summary/summary_ops_test.py b/tensorflow/contrib/summary/summary_ops_test.py
index d20300c858..54433deb28 100644
--- a/tensorflow/contrib/summary/summary_ops_test.py
+++ b/tensorflow/contrib/summary/summary_ops_test.py
@@ -108,6 +108,33 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(len(events), 2)
self.assertEqual(events[1].summary.value[0].tag, 'scalar')
+ def testMaxQueue(self):
+ logs = tempfile.mkdtemp()
+ with summary_ops.create_summary_file_writer(
+ logs, max_queue=2, flush_millis=999999,
+ name='lol').as_default(), summary_ops.always_record_summaries():
+ get_total = lambda: len(summary_test_util.events_from_logdir(logs))
+ # Note: First tf.Event is always file_version.
+ self.assertEqual(1, get_total())
+ summary_ops.scalar('scalar', 2.0, step=1)
+ self.assertEqual(1, get_total())
+ summary_ops.scalar('scalar', 2.0, step=2)
+ self.assertEqual(3, get_total())
+
+ def testFlush(self):
+ logs = tempfile.mkdtemp()
+ with summary_ops.create_summary_file_writer(
+ logs, max_queue=999999, flush_millis=999999,
+ name='lol').as_default(), summary_ops.always_record_summaries():
+ get_total = lambda: len(summary_test_util.events_from_logdir(logs))
+ # Note: First tf.Event is always file_version.
+ self.assertEqual(1, get_total())
+ summary_ops.scalar('scalar', 2.0, step=1)
+ summary_ops.scalar('scalar', 2.0, step=2)
+ self.assertEqual(1, get_total())
+ summary_ops.flush()
+ self.assertEqual(3, get_total())
+
class DbTest(summary_test_util.SummaryDbTest):
diff --git a/tensorflow/contrib/summary/summary_test_util.py b/tensorflow/contrib/summary/summary_test_util.py
index 94767c8df2..915820e05b 100644
--- a/tensorflow/contrib/summary/summary_test_util.py
+++ b/tensorflow/contrib/summary/summary_test_util.py
@@ -83,7 +83,7 @@ def events_from_logdir(logdir):
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
- assert len(files) == 1, "Found not exactly one file in logdir: %s" % files
+ assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
return events_from_file(os.path.join(logdir, files[0]))
diff --git a/tensorflow/contrib/tpu/BUILD b/tensorflow/contrib/tpu/BUILD
index f542d94139..a34c7f91f2 100644
--- a/tensorflow/contrib/tpu/BUILD
+++ b/tensorflow/contrib/tpu/BUILD
@@ -32,21 +32,6 @@ cc_library(
)
py_library(
- name = "tpu_test_util",
- srcs = ["python/tpu/test_util.py"],
- srcs_version = "PY2AND3",
- deps = [
- ":tpu_lib",
- ":tpu_py",
- "//tensorflow/python:errors",
- "//tensorflow/python:framework_ops",
- "//tensorflow/python:framework_test_lib",
- "//tensorflow/python:session",
- "//tensorflow/python:variables",
- ],
-)
-
-py_library(
name = "tpu_estimator",
srcs = [
"python/tpu/tpu_config.py",
diff --git a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
index cbbd19800e..d389050e67 100644
--- a/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
+++ b/tensorflow/contrib/tpu/ops/cross_replica_ops.cc
@@ -22,7 +22,7 @@ namespace tensorflow {
REGISTER_OP("CrossReplicaSum")
.Input("input: T")
.Output("output: T")
- .Attr("T: {float}")
+ .Attr("T: {bfloat16, float}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
An Op to sum inputs across replicated TPU instances. Each
diff --git a/tensorflow/contrib/tpu/python/tpu/test_util.py b/tensorflow/contrib/tpu/python/tpu/test_util.py
deleted file mode 100644
index a5d4ff9722..0000000000
--- a/tensorflow/contrib/tpu/python/tpu/test_util.py
+++ /dev/null
@@ -1,296 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ===================================================================
-"""Utilities to ease testing on TPU devices."""
-
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import os.path
-import pickle
-import tempfile
-
-import numpy as np
-
-from tensorflow.contrib.tpu.python.tpu import tpu
-from tensorflow.contrib.tpu.python.tpu import tpu_config
-from tensorflow.contrib.tpu.python.tpu import tpu_estimator
-from tensorflow.core.protobuf import config_pb2
-from tensorflow.python.client import session as tf_session
-from tensorflow.python.estimator import model_fn as model_fn_lib
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import ops
-from tensorflow.python.framework import random_seed
-from tensorflow.python.framework import test_util
-from tensorflow.python.ops import gen_array_ops
-from tensorflow.python.ops import variables
-from tensorflow.python.platform import gfile
-from tensorflow.python.platform import tf_logging as logging
-from tensorflow.python.training import saver as tf_saver
-
-
-def has_tpu():
- """Check if a TPU device is available.
-
- Device enumeration via `device_lib` currently fails for TPU systems.
- (http://b/68333779). To work around this, we determine the existence of a
- TPU by a successful call to `initialize_system`.
-
- Returns:
- boolean, True if a TPU device is available, otherwise False.
- """
-
- def _check():
- with tf_session.Session() as sess:
- sess.run(tpu.initialize_system())
- sess.run(tpu.shutdown_system())
-
- try:
- _check()
- return True
- except errors.OpError as _:
- return False
-
-
-def _available_devices():
- devices = ["cpu"]
- if not test_util.gpu_device_name():
- devices.append("gpu")
-
- if has_tpu():
- devices.append("tpu")
-
- return tuple(devices)
-
-
-def copy_dir(src, tgt):
- """Copy src to tgt."""
- gfile.MakeDirs(tgt)
- seen_dirs = set()
- for dirname, _, files in gfile.Walk(src):
- for f in files:
- src_f = os.path.join(dirname, f)
- tgt_f = src_f.replace(src, tgt)
- tgt_d = os.path.dirname(tgt_f)
- if tgt_d not in seen_dirs:
- gfile.MkDir(tgt_d)
- seen_dirs.add(tgt_d)
- gfile.Copy(src_f, tgt_f, overwrite=True)
-
-
-def compare_model(model_fn,
- input_fn,
- params,
- master="local",
- temp_dir=None,
- num_shards=2,
- tolerance=1e-4):
- """Compare the results of running `model_fn` on the TPU and CPU."""
- if not temp_dir:
- temp_dir = tempfile.mkdtemp()
-
- cpu_model_dir = "%s/cpu-model" % temp_dir
- tpu_model_dir = "%s/tpu-model" % temp_dir
- initial_model_dir = "%s/initial-model" % temp_dir
-
- logging.info("Checkpoints and weights will be written to %s", temp_dir)
-
- num_steps = 1
-
- def _model_adapter(features, labels, mode, params):
- """Run users model function with random seeds fixed to known values."""
- random_seed.set_random_seed(0)
- np.random.seed(0)
- return model_fn(features, labels, mode, params)
-
- def _input_adapter(params):
- random_seed.set_random_seed(0)
- np.random.seed(0)
- return input_fn(params)
-
- def _make_run_config(model_dir):
- return tpu_config.RunConfig(
- master=master,
- model_dir=model_dir,
- save_checkpoints_secs=10000,
- session_config=config_pb2.ConfigProto(
- allow_soft_placement=True, log_device_placement=False),
- tpu_config=tpu_config.TPUConfig(
- iterations_per_loop=num_steps,
- num_shards=num_shards,
- ),
- )
-
- def _make_estimator(use_tpu, model_dir):
- return tpu_estimator.TPUEstimator(
- model_fn=_model_adapter,
- use_tpu=use_tpu,
- config=_make_run_config(model_dir),
- train_batch_size=num_shards,
- params=dict(params, use_tpu=use_tpu),
- )
-
- def _extract_weights(checkpoint):
- """Extract model weights from the given checkpoint file."""
- weights = {}
- graph = ops.Graph()
- with graph.as_default():
- features, labels = _input_adapter(dict(params, batch_size=num_shards))
- model_fn(
- features, labels,
- params=dict(params, use_tpu=False),
- mode=model_fn_lib.ModeKeys.TRAIN)
- saver = tf_saver.Saver()
- with tf_session.Session(graph=graph) as sess:
- saver.restore(sess, checkpoint)
- all_vars = []
- all_vars.extend(graph.get_collection(ops.GraphKeys.GLOBAL_VARIABLES))
- all_vars.extend(graph.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES))
- all_vars.extend(graph.get_collection(ops.GraphKeys.MODEL_VARIABLES))
-
- for var in all_vars:
- weights[var.name] = sess.run(var)
- return weights
-
- def _run_step(use_tpu, model_dir):
- """Create an estimator and run a single step on the given device."""
- tf_session.Session.reset(target=master)
-
- logging.info("Running step. TPU=%d. model_dir=%s", use_tpu, model_dir)
- est = _make_estimator(use_tpu=use_tpu, model_dir=model_dir)
- est.train(input_fn=_input_adapter, steps=num_steps)
- weights = _extract_weights(est.latest_checkpoint())
- with gfile.Open(os.path.join(temp_dir, "tpu-%d.weights" % use_tpu),
- "wb") as f:
- f.write(pickle.dumps(weights))
- return weights
-
- # initialize models to the same weights by running a single step on the CPU
- _run_step(use_tpu=False, model_dir=initial_model_dir)
-
- copy_dir(initial_model_dir, cpu_model_dir)
- copy_dir(initial_model_dir, tpu_model_dir)
-
- cpu_weights = _run_step(use_tpu=False, model_dir=cpu_model_dir)
- tpu_weights = _run_step(use_tpu=True, model_dir=tpu_model_dir)
-
- bad_weights = False
- for k in cpu_weights:
- if k not in tpu_weights:
- raise KeyError("Missing weight %s from TPU checkpoint.", k)
-
- if not np.allclose(
- cpu_weights[k], tpu_weights[k], rtol=tolerance, atol=tolerance):
- bad_weights = True
- logging.error("Weights for layer %s have diverged.", k)
-
- if bad_weights:
- raise ValueError("Some weights have diverged. Output pickle files have "
- "been written to %s for inspection." % temp_dir)
-
-
-class TPUTestCase(test_util.TensorFlowTestCase):
- """Adds helpers for testing on TPU devices to `TensorFlowTestCase`.
-
- Example usage:
-
- ```
- def model_fn(features):
- return tf.reduce_sum(features * 2)
-
- class ModelTests(test_util.TPUTestCase):
- def test_sum(self):
- v = np.random.randn(10, 10).astype("float32")
- self.assert_device_output(model_fn, [v], (v*2).sum(),
- devices=("cpu", "tpu"))
- ```
- """
-
- def __init__(self, methodName="runTest"): # pylint: disable=invalid-name
- super(TPUTestCase, self).__init__(methodName)
- self._available_devices = _available_devices()
-
- def run_on_device(self, model_fn, model_inputs, device):
- """Runs `model_fn` on the given device.
-
- Raises an exception if no such device is available. `model_fn` should
- return one or more tensors as a list or tuple.
-
- Args:
- model_fn: Function returning one or more tensors.
- model_inputs: An iterable of Numpy arrays or scalars.
- These will be passed as arguments to `model_fn`.
- device: Device to run on. One of ("tpu", "gpu", "cpu").
-
- Returns:
- Output from the model function.
- """
-
- def _make_placeholders():
- return dict([(gen_array_ops.placeholder_with_default(v, v.shape), v)
- for v in model_inputs])
-
- if device == "tpu":
- with self.test_session(graph=ops.Graph()) as sess:
- placeholders = _make_placeholders()
- tpu_computation = tpu.rewrite(model_fn, placeholders.keys())
- sess.run(tpu.initialize_system())
- sess.run(variables.global_variables_initializer())
- result = sess.run(tpu_computation, placeholders)
- sess.run(tpu.shutdown_system())
- # TODO(b/36891278): supports non-flat returns lists in tpu.rewrite().
- if len(result) == 1:
- return result[0]
- return result
- elif device == "gpu":
- with self.test_session(graph=ops.Graph(), use_gpu=True) as sess:
- placeholders = _make_placeholders()
- sess.run(variables.global_variables_initializer())
- return sess.run(model_fn(placeholders.keys()), placeholders)
- elif device == "cpu":
- # TODO(power) -- will this interact poorly with cached GPU sessions?
- with self.test_session(graph=ops.Graph(), use_gpu=False) as sess:
- placeholders = _make_placeholders()
- sess.run(variables.global_variables_initializer())
- return sess.run(model_fn(placeholders.keys()), placeholders)
-
- def _compare_values(self, actual_outputs, expected_outputs):
- if isinstance(expected_outputs, (list, tuple)):
- for a, b in zip(actual_outputs, expected_outputs):
- self.assertAllCloseAccordingToType(a, b)
- else:
- self.assertAllCloseAccordingToType(actual_outputs, expected_outputs)
-
- def assert_device_output(self,
- model_fn,
- model_inputs,
- expected_outputs,
- devices=("cpu", "gpu", "tpu")):
- """Run `model_fn` on the given devices.
-
- Results are compared via `assertAllCloseAccordingToType`.
-
- Args:
- model_fn: Function returning one or more tensors
- model_inputs: Numpy arrays or scalars passed as arguments to model_fn
- expected_outputs: Numpy arrays or scalars to compare against.
- devices: Set of devices to run on. If a device is not available, tests
- will be skipped for that device.
- """
- devices = set(devices).intersection(self._available_devices)
-
- for device in devices:
- device_out = self.run_on_device(model_fn, model_inputs, device=device)
- self._compare_values(device_out, expected_outputs)
diff --git a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
index fe17664d7f..84a4208be3 100644
--- a/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
+++ b/tensorflow/contrib/tpu/python/tpu/tpu_estimator.py
@@ -514,6 +514,7 @@ class _InfeedThreadController(_InfeedOutfeedThreadBaseController):
exc_info=1
)
time.sleep(120)
+ logging.error('Closing the failed session.')
session.close()
def join(self):
diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD
index bd7617fa96..5bcb87d2d1 100644
--- a/tensorflow/core/BUILD
+++ b/tensorflow/core/BUILD
@@ -1016,7 +1016,7 @@ filegroup(
cc_library(
name = "android_tensorflow_lib_lite",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + if_not_android_mips_and_mips64(["-Os"]),
+ copts = tf_copts(android_optimization_level_override = None),
linkopts = ["-lz"],
tags = [
"manual",
@@ -1106,8 +1106,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib_selective_registration",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + [
- "-Os",
+ copts = tf_copts(android_optimization_level_override = None) + [
"-DSUPPORT_SELECTIVE_REGISTRATION",
],
tags = [
@@ -1129,8 +1128,7 @@ cc_library(
cc_library(
name = "android_tensorflow_lib_selective_registration_nortti",
srcs = if_android(["//tensorflow/core:android_srcs"]),
- copts = tf_copts() + tf_opts_nortti_if_android() + [
- "-Os",
+ copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_android() + [
"-DSUPPORT_SELECTIVE_REGISTRATION",
],
tags = [
@@ -1210,7 +1208,7 @@ cc_library(
"framework/tensor_testutil.h",
"util/reporter.h",
],
- copts = tf_copts() + ["-Os"],
+ copts = tf_copts(android_optimization_level_override = None),
tags = [
"manual",
"notap",
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
index 6522ce976f..070d6adb97 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2D.pbtxt
@@ -26,7 +26,7 @@ END
description: <<END
1-D tensor of length 4. The stride of the sliding window for each
dimension of `input`. The dimension order is determined by the value of
- `data_format`, see below for details.
+`data_format`, see below for details.
END
}
attr {
@@ -45,6 +45,16 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 2-D convolution given 4-D `input` and `filter` tensors."
description: <<END
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
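To make the "k-1 skipped cells" concrete: a dilation of d stretches a k-tap filter to cover (k-1)*d+1 input cells, which is the arithmetic the updated shape inference (`GetWindowedOutputSizeFromDimsV2`, further down in this change) applies per spatial dimension. A minimal plain-Python sketch; function names are ours:

```python
def effective_filter_size(k, dilation):
  # A k-tap filter with dilation d covers (k - 1) * d + 1 input cells,
  # since d - 1 cells are skipped between consecutive taps.
  return (k - 1) * dilation + 1

def valid_output_size(input_size, k, dilation, stride):
  # VALID padding keeps only windows that fit entirely inside the input.
  return (input_size - effective_filter_size(k, dilation)) // stride + 1

# A 5-wide input and 3-tap filter: dilation 2 leaves room for a single
# window, matching the Conv2DDilatedShapeTest expectations further down.
assert valid_output_size(5, 3, 2, 1) == 1
assert valid_output_size(5, 3, 1, 1) == 3
```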
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
index 4ea3374dbb..ff2d9d71db 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropFilter.pbtxt
@@ -53,5 +53,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
index 4420073e38..2de38b4263 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv2DBackpropInput.pbtxt
@@ -52,5 +52,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, in_channels, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
index 8f3cd4493c..d26564097e 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3D.pbtxt
@@ -36,6 +36,16 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 3-D convolution given 5-D `input` and `filter` tensors."
description: <<END
In signal processing, cross-correlation is a measure of similarity of
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
index 6f9b917237..937c9c8ead 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropFilterV2.pbtxt
@@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes the gradients of 3-D convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
index 19aba156d5..414e418dc5 100644
--- a/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_Conv3DBackpropInputV2.pbtxt
@@ -45,5 +45,15 @@ Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 5. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes the gradients of 3-D convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
index cc10ebe923..3c313f7be6 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNative.pbtxt
@@ -23,6 +23,16 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors."
description: <<END
Given an input tensor of shape `[batch, in_height, in_width, in_channels]`
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
index 9126be2afa..e66aa3b707 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropFilter.pbtxt
@@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of depthwise convolution with respect to the filter."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
index f1d16858db..f501ad21b3 100644
--- a/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DepthwiseConv2dNativeBackpropInput.pbtxt
@@ -56,5 +56,15 @@ Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each filter
+element on that dimension. The dimension order is determined by the value of
+`data_format`, see above for details. Dilations in the batch and depth
+dimensions must be 1.
+END
+ }
summary: "Computes the gradients of depthwise convolution with respect to the input."
}
diff --git a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
index 00e96c8a15..dfaa531cbc 100644
--- a/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_DeserializeSparse.pbtxt
@@ -14,4 +14,47 @@ The `dtype` of the serialized `SparseTensor` objects.
END
}
summary: "Deserialize `SparseTensor` objects."
+ description: <<END
+The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
+the last dimension stores serialized `SparseTensor` objects and the other N
+dimensions (N >= 0) correspond to a batch. The ranks of the original
+`SparseTensor` objects must all match. When the final `SparseTensor` is
+created, its rank is the rank of the incoming `SparseTensor` objects plus N;
+the sparse tensors have been concatenated along new dimensions, one for each
+batch.
+
+The output `SparseTensor` object's shape values for the original dimensions
+are the max across the input `SparseTensor` objects' shape values for the
+corresponding dimensions. The new dimensions match the size of the batch.
+
+The input `SparseTensor` objects' indices are assumed ordered in
+standard lexicographic order. If this is not the case, after this
+step run `SparseReorder` to restore index ordering.
+
+For example, if the serialized input is a `[2 x 3]` matrix representing two
+original `SparseTensor` objects:
+
+ index = [ 0]
+ [10]
+ [20]
+ values = [1, 2, 3]
+ shape = [50]
+
+and
+
+ index = [ 2]
+ [10]
+ values = [4, 5]
+ shape = [30]
+
+then the final deserialized `SparseTensor` will be:
+
+ index = [0 0]
+ [0 10]
+ [0 20]
+ [1 2]
+ [1 10]
+ values = [1, 2, 3, 4, 5]
+ shape = [2 50]
+END
}
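The shape rule stated above (new batch dimensions prepended, original dimensions taking the max across inputs) can be captured in a tiny plain-Python sketch; the helper name is ours:

```python
def deserialized_shape(batch_dims, input_shapes):
  # New batch dimensions are prepended; each original dimension takes
  # the max across the corresponding input SparseTensor shape values.
  original = [max(dims) for dims in zip(*input_shapes)]
  return list(batch_dims) + original

# The worked example above: shapes [50] and [30] batched into 2 rows.
assert deserialized_shape([2], [[50], [30]]) == [2, 50]
```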
diff --git a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
index b19bbeab12..d18bafdce9 100644
--- a/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_QuantizedConv2D.pbtxt
@@ -55,6 +55,16 @@ END
The type of padding algorithm to use.
END
}
+ attr {
+ name: "dilations"
+ description: <<END
+1-D tensor of length 4. The dilation factor for each dimension of
+`input`. If set to k > 1, there will be k-1 skipped cells between each
+filter element on that dimension. The dimension order is determined by the
+value of `data_format`, see above for details. Dilations in the batch and
+depth dimensions must be 1.
+END
+ }
summary: "Computes a 2D convolution given quantized 4D input and filter tensors."
description: <<END
The inputs are quantized tensors where the lowest value represents the real
diff --git a/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
new file mode 100644
index 0000000000..0466b40f85
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_RandomDataset.pbtxt
@@ -0,0 +1,18 @@
+op {
+ graph_op_name: "RandomDataset"
+ in_arg {
+ name: "seed"
+ description: <<END
+A scalar seed for the random number generator. If either seed or
+seed2 is set to be non-zero, the random number generator is seeded
+by the given seed. Otherwise, a random seed is used.
+END
+ }
+ in_arg {
+ name: "seed2"
+ description: <<END
+A second scalar seed to avoid seed collision.
+END
+ }
+ summary: "Creates a Dataset that returns pseudorandom numbers."
+}
diff --git a/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
new file mode 100644
index 0000000000..b07ee9fda9
--- /dev/null
+++ b/tensorflow/core/api_def/base_api/api_def_ResourceScatterNdUpdate.pbtxt
@@ -0,0 +1,69 @@
+op {
+ graph_op_name: "ResourceScatterNdUpdate"
+ in_arg {
+ name: "ref"
+ description: <<END
+A resource handle. Must be from a VarHandleOp.
+END
+ }
+ in_arg {
+ name: "indices"
+ description: <<END
+A Tensor. Must be one of the following types: int32, int64.
+A tensor of indices into ref.
+END
+ }
+ in_arg {
+ name: "updates"
+ description: <<END
+A Tensor. Must have the same type as ref. A tensor of updated
+values to add to ref.
+END
+ }
+ attr {
+ name: "use_locking"
+ description: <<END
+An optional bool. Defaults to True. If True, the assignment will
+be protected by a lock; otherwise the behavior is undefined,
+but may exhibit less contention.
+END
+ }
+ summary: "Applies sparse `updates` to individual values or slices within a given"
+ description: <<END
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor, containing indices into `ref`.
+It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements in a rank-1 tensor
+with 8 elements. In Python, that update would look like this:
+
+```python
+ ref = tfe.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(update))
+```
+
+The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+END
+}
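A NumPy reference sketch of the update semantics documented above, limited to the element-wise case (K = P) from the worked example; the helper name is ours, not a TensorFlow API:

```python
import numpy as np

def scatter_nd_update(ref, indices, updates):
  # Each row of `indices` selects one element (K = P) or slice (K < P)
  # of `ref` and overwrites it with the matching entry of `updates`.
  out = np.array(ref)
  for idx, upd in zip(indices, updates):
    out[tuple(idx)] = upd
  return out

result = scatter_nd_update([1, 2, 3, 4, 5, 6, 7, 8],
                           [[4], [3], [1], [7]], [9, 10, 11, 12])
assert result.tolist() == [1, 11, 3, 10, 9, 6, 7, 12]
```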
diff --git a/tensorflow/core/framework/bfloat16_test.cc b/tensorflow/core/framework/bfloat16_test.cc
index 6e45338751..17e6209f8e 100644
--- a/tensorflow/core/framework/bfloat16_test.cc
+++ b/tensorflow/core/framework/bfloat16_test.cc
@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/bfloat16.h"
+#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/lib/core/casts.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
@@ -104,6 +105,17 @@ TEST(Bfloat16Test, Conversion) {
}
}
+TEST(Bfloat16Test, Epsilon) {
+ EXPECT_LT(1.0f, static_cast<float>(bfloat16::epsilon() + bfloat16(1.0f)));
+ EXPECT_EQ(1.0f, static_cast<float>((bfloat16::epsilon() / bfloat16(2.0f)) +
+ bfloat16(1.0f)));
+}
+
+TEST(Bfloat16Test, Negate) {
+ EXPECT_EQ(-3.0f, static_cast<float>(-bfloat16(3.0f)));
+ EXPECT_EQ(4.5f, static_cast<float>(-bfloat16(-4.5f)));
+}
+
static void BM_FloatToBFloat16(int iters) {
testing::StopTiming();
static const int N = 32 << 20;
diff --git a/tensorflow/core/framework/common_shape_fns.cc b/tensorflow/core/framework/common_shape_fns.cc
index ea66863bed..036e3473b1 100644
--- a/tensorflow/core/framework/common_shape_fns.cc
+++ b/tensorflow/core/framework/common_shape_fns.cc
@@ -397,6 +397,15 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(
CheckFormatConstraintsOnShape(data_format, filter_shape, "filter", c));
+ std::vector<int32> dilations;
+ TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations));
+
+ if (dilations.size() != 4) {
+ return errors::InvalidArgument(
+ "Conv2D requires the dilation attribute to contain 4 values, but got: ",
+ dilations.size());
+ }
+
std::vector<int32> strides;
TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
@@ -410,6 +419,8 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
const int32 stride_rows = GetTensorDim(strides, data_format, 'H');
const int32 stride_cols = GetTensorDim(strides, data_format, 'W');
+ const int32 dilation_rows = GetTensorDim(dilations, data_format, 'H');
+ const int32 dilation_cols = GetTensorDim(dilations, data_format, 'W');
DimensionHandle batch_size_dim;
DimensionHandle input_depth_dim;
@@ -447,12 +458,12 @@ Status Conv2DShape(shape_inference::InferenceContext* c) {
TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
DimensionHandle output_rows, output_cols;
- TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[0],
- filter_rows_dim, stride_rows,
- padding, &output_rows));
- TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(c, input_spatial_dims[1],
- filter_cols_dim, stride_cols,
- padding, &output_cols));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
+ c, input_spatial_dims[0], filter_rows_dim, dilation_rows, stride_rows,
+ padding, &output_rows));
+ TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDimsV2(
+ c, input_spatial_dims[1], filter_cols_dim, dilation_cols, stride_cols,
+ padding, &output_cols));
ShapeHandle output_shape;
TF_RETURN_IF_ERROR(
@@ -1307,6 +1318,9 @@ Status ValidateSparseTensor(InferenceContext* c, ShapeHandle indices_shape,
Status ScatterNdUpdateShape(InferenceContext* c) {
ShapeHandle input_shape = c->input(0);
+ if (c->input_handle_shapes_and_types(0) != nullptr) {
+ input_shape = (*c->input_handle_shapes_and_types(0))[0].shape;
+ }
ShapeHandle indices_shape;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(1), 1, &indices_shape));
ShapeHandle updates_shape;
@@ -1361,7 +1375,9 @@ Status ScatterNdUpdateShape(InferenceContext* c) {
}
}
- c->set_output(0, input_shape);
+ if (c->input_handle_shapes_and_types(0) == nullptr) {
+ c->set_output(0, input_shape);
+ }
return Status::OK();
}
diff --git a/tensorflow/core/framework/common_shape_fns_test.cc b/tensorflow/core/framework/common_shape_fns_test.cc
index ec9746b2af..5f3e5ad457 100644
--- a/tensorflow/core/framework/common_shape_fns_test.cc
+++ b/tensorflow/core/framework/common_shape_fns_test.cc
@@ -423,6 +423,15 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
.Finalize(&op.node_def));
};
+ // Invalid rank for input
+ INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
+ // Invalid rank for filter
+ INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
+
+ // Invalid value for strides
+ set_op({{1, 1, 0, 1}}, "VALID", "NHWC", "HWIO");
+ INFER_ERROR("must be > 0", op, "[1,2,2,1];[1,1,1,1]");
+
// 1x1 filter
set_op({{1, 1, 1, 1}}, "VALID", "NHWC", "HWIO");
INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
@@ -443,11 +452,6 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
set_op({{1, 1, 2, 1}}, "VALID", "NHWC", "HWIO");
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
- // Invalid rank for input
- INFER_ERROR("must be rank 4", op, "[4,4];[2,1,1,1]");
- // Invalid rank for filter
- INFER_ERROR("must be rank 4", op, "[1,4,4,1];[2,1,1]");
-
// Unknown dims in the critical fields lead to partial inference.
INFER_OK(op, "[1,4,4,1];[2,1,1,1]", "[d0_0,3,2,d1_3]");
INFER_OK(op, "[1,?,4,1];[2,1,1,1]", "[d0_0,?,2,d1_3]");
@@ -538,6 +542,98 @@ TEST(CommonShapeFnsTest, Conv2DShapeTest) {
INFER_OK(op, "[1,4,4,?];[?,?,?,?]", "[d0_0,2,2,d1_3]");
}
+TEST(CommonShapeFnsTest, Conv2DDilatedShapeTest) {
+ ShapeInferenceTestOp op("Conv2D");
+ auto set_op = [&op](const std::vector<int32>& dilations,
+ const std::vector<int32>& strides, const string& padding,
+ const string& data_format) {
+ TF_CHECK_OK(NodeDefBuilder("test", "Conv2D")
+ .Input("input", 0, DT_FLOAT)
+ .Input("filter", 0, DT_FLOAT)
+ .Attr("dilations", dilations)
+ .Attr("strides", strides)
+ .Attr("padding", padding)
+ .Attr("data_format", data_format)
+ .Finalize(&op.node_def));
+ };
+
+ // Invalid rank for dilation
+ set_op({{1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_ERROR("contain 4 values", op, "[1,2,2,1];[1,1,1,1]");
+
+ // Invalid value for dilation
+ set_op({{1, 0, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_ERROR("must be >= 1", op, "[1,2,2,1];[1,1,1,1]");
+
+ // Tests for NHWC
+ // 1x1 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,2,2,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 1x1 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,4,d1_3]");
+
+ // 1x1 filter, 2x1 dilations, 2x2 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 3x3 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
+
+ // 3x3 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 2, 1, 1}}, {{1, 2, 1, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,1,3,d1_3]");
+
+ // 3x3 filter, 1x2 dilations, 2x2 strides
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "VALID", "NHWC");
+ INFER_OK(op, "[1,5,5,1];[3,3,1,1]", "[d0_0,2,1,d1_3]");
+
+ // Tests for NCHW
+ // 1x1 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,2,2];[1,1,1,1]", "[d0_0,d1_3,2,2]");
+
+ // 1x1 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,4]");
+
+ // 1x1 filter, 2x1 dilations, 2x2 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,4,4];[1,1,1,1]", "[d0_0,d1_3,2,2]");
+
+ // 3x3 filter, 2x1 dilations, 1x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 1, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
+
+ // 3x3 filter, 2x1 dilations, 2x1 strides
+ set_op({{1, 1, 2, 1}}, {{1, 1, 2, 1}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,1,3]");
+
+ // 3x3 filter, 1x2 dilations, 2x2 strides
+ set_op({{1, 1, 1, 2}}, {{1, 1, 2, 2}}, "VALID", "NCHW");
+ INFER_OK(op, "[1,1,5,5];[3,3,1,1]", "[d0_0,d1_3,2,1]");
+
+ // Some tests for "SAME" padding
+
+ // 4x4 input, 1x1 filter, 2x1 dilations, 1x1 stride
+ set_op({{1, 2, 1, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[1,1,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+
+ // 3x3 input, 2x2 filter, 2x2 dilations, 1x1 stride
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,3,3,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+
+ // 4x4 input, 2x2 filter, 1x2 dilations, 2x2 stride
+ set_op({{1, 1, 2, 1}}, {{1, 2, 2, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,2,2,d1_3]");
+
+ // 4x4 input, 2x2 filter, 2x2 dilations, 1x1 stride
+ set_op({{1, 2, 2, 1}}, {{1, 1, 1, 1}}, "SAME", "NHWC");
+ INFER_OK(op, "[1,4,4,1];[2,2,1,1]", "[d0_0,d0_1,d0_2,d1_3]");
+}
+
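The SAME-padding cases above follow a simpler rule worth stating explicitly: under the usual SAME convention, padding is chosen so the output size depends only on input size and stride. A small sketch, with names ours:

```python
import math

def same_output_size(input_size, stride):
  # With SAME padding, enough zeros are added so the output size depends
  # only on input size and stride, regardless of filter size or dilation.
  return math.ceil(input_size / stride)

assert same_output_size(4, 2) == 2  # the strided SAME case above
assert same_output_size(4, 1) == 4  # dilation leaves SAME output unchanged
```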
TEST(CommonShapeFnsTest, Conv3DShapeTest) {
ShapeInferenceTestOp op("Conv3D");
auto set_op = [&op](const std::vector<int32>& strides,
diff --git a/tensorflow/core/framework/numeric_types.h b/tensorflow/core/framework/numeric_types.h
index 2b080e13fd..bdd5af064b 100644
--- a/tensorflow/core/framework/numeric_types.h
+++ b/tensorflow/core/framework/numeric_types.h
@@ -58,7 +58,7 @@ struct bfloat16 {
explicit EIGEN_DEVICE_FUNC bfloat16(const T& val)
: bfloat16(static_cast<float>(val)) {}
- EIGEN_DEVICE_FUNC EIGEN_EXPLICIT_CAST(float) const {
+ EIGEN_DEVICE_FUNC explicit operator float() const {
float result;
uint16_t* q = reinterpret_cast<uint16_t*>(&result);
@@ -89,6 +89,10 @@ struct bfloat16 {
return static_cast<int>(float(*this));
}
+ EIGEN_DEVICE_FUNC explicit operator long() const {
+ return static_cast<long>(float(*this));
+ }
+
EIGEN_DEVICE_FUNC explicit operator char() const {
return static_cast<char>(float(*this));
}
@@ -121,15 +125,48 @@ struct bfloat16 {
return static_cast<double>(float(*this));
}
+ static bfloat16 epsilon() {
+ bfloat16 x;
+ x.value = 0x3c00; // 0x1.0p-7
+ return x;
+ }
+
uint16_t value;
};
-inline bool operator==(const bfloat16 a, const bfloat16 b) {
- return a.value == b.value;
+inline bfloat16 operator+(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) + static_cast<float>(b));
}
-
-inline bool operator!=(const bfloat16 a, const bfloat16 b) {
- return a.value != b.value;
+inline bfloat16 operator-(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) - static_cast<float>(b));
+}
+inline bfloat16 operator*(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) * static_cast<float>(b));
+}
+inline bfloat16 operator/(bfloat16 a, bfloat16 b) {
+ return bfloat16(static_cast<float>(a) / static_cast<float>(b));
+}
+inline bfloat16 operator-(bfloat16 a) {
+ a.value ^= 0x8000;
+ return a;
+}
+inline bool operator<(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) < static_cast<float>(b);
+}
+inline bool operator<=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) <= static_cast<float>(b);
+}
+inline bool operator==(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) == static_cast<float>(b);
+}
+inline bool operator!=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) != static_cast<float>(b);
+}
+inline bool operator>(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) > static_cast<float>(b);
+}
+inline bool operator>=(bfloat16 a, bfloat16 b) {
+ return static_cast<float>(a) >= static_cast<float>(b);
}
} // end namespace tensorflow
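The new arithmetic operators all round-trip through float, while negation and `epsilon()` work directly on the bit pattern. A NumPy sketch of that bit-level view; the helper names are ours:

```python
import numpy as np

def f32_to_bf16_bits(x):
  # bfloat16 keeps the upper 16 bits of a float32 (truncation).
  return int(np.float32(x).view(np.uint32)) >> 16

def bf16_bits_to_f32(bits):
  return float(np.uint32(bits << 16).view(np.float32))

# operator-(bfloat16) above just flips the sign bit:
assert bf16_bits_to_f32(f32_to_bf16_bits(3.0) ^ 0x8000) == -3.0
# epsilon() == 0x3c00 decodes to 2**-7, the gap above 1.0 for a
# 7-bit-mantissa format, matching the new Epsilon test.
assert bf16_bits_to_f32(0x3c00) == 2.0 ** -7
```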
diff --git a/tensorflow/core/framework/op_def_builder_test.cc b/tensorflow/core/framework/op_def_builder_test.cc
index c1511ebe34..9b24e3aa00 100644
--- a/tensorflow/core/framework/op_def_builder_test.cc
+++ b/tensorflow/core/framework/op_def_builder_test.cc
@@ -124,22 +124,23 @@ TEST_F(OpDefBuilderTest, AttrWithRestrictions) {
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
- "DT_QINT32, DT_UINT32, DT_UINT64] } } }");
+ "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16] } } }");
ExpectSuccess(
b().Attr("a:{numbertype, variant}"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
"DT_UINT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8, "
- "DT_QINT32, DT_UINT32, DT_UINT64, DT_VARIANT] } } }");
+ "DT_QINT32, DT_UINT32, DT_UINT64, DT_BFLOAT16, DT_VARIANT] } } }");
ExpectSuccess(b().Attr("a:realnumbertype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
- "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64] } } }");
+ "DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
+ "DT_BFLOAT16] } } }");
ExpectSuccess(b().Attr("a:{realnumbertype, variant , string, }"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_HALF, DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, "
"DT_INT16, DT_UINT16, DT_INT8, DT_UINT32, DT_UINT64, "
- "DT_VARIANT, DT_STRING] } } }");
+ "DT_BFLOAT16, DT_VARIANT, DT_STRING] } } }");
ExpectSuccess(b().Attr("a:quantizedtype"),
"attr: { name: 'a' type: 'type' allowed_values { list { type: "
"[DT_QINT8, DT_QUINT8, DT_QINT32, DT_QINT16, DT_QUINT16]} } }");
@@ -216,12 +217,14 @@ TEST_F(OpDefBuilderTest, AttrListOfRestricted) {
b().Attr("a:list(realnumbertype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
- "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64] } } }");
+ "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64"
+ "] } } }");
ExpectSuccess(
b().Attr("a:list({realnumbertype, variant})"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
"[DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8, DT_INT16, "
- "DT_UINT16, DT_INT8, DT_HALF, DT_UINT32, DT_UINT64, DT_VARIANT] } } }");
+ "DT_UINT16, DT_INT8, DT_HALF, DT_BFLOAT16, DT_UINT32, DT_UINT64, "
+ "DT_VARIANT] } } }");
ExpectSuccess(
b().Attr("a:list(quantizedtype)"),
"attr: { name: 'a' type: 'list(type)' allowed_values { list { type: "
diff --git a/tensorflow/core/framework/types.cc b/tensorflow/core/framework/types.cc
index faae19585d..48849f9dda 100644
--- a/tensorflow/core/framework/types.cc
+++ b/tensorflow/core/framework/types.cc
@@ -206,18 +206,18 @@ string DataTypeSliceString(const DataTypeSlice types) {
}
DataTypeVector AllTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
- DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
- DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
- DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT,
- DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16,
+ DT_UINT16, DT_INT8, DT_STRING, DT_COMPLEX64, DT_COMPLEX128,
+ DT_INT64, DT_BOOL, DT_QINT8, DT_QUINT8, DT_QINT16,
+ DT_QUINT16, DT_QINT32, DT_HALF, DT_RESOURCE, DT_VARIANT,
+ DT_UINT32, DT_UINT64, DT_BFLOAT16};
}
#if !defined(IS_MOBILE_PLATFORM) || defined(SUPPORT_SELECTIVE_REGISTRATION)
DataTypeVector RealNumberTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16,
- DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8, DT_INT16,
+ DT_INT8, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64, DT_BFLOAT16};
}
DataTypeVector QuantizedTypes() {
@@ -227,14 +227,14 @@ DataTypeVector QuantizedTypes() {
DataTypeVector RealAndQuantizedTypes() {
return {DT_FLOAT, DT_DOUBLE, DT_INT32, DT_INT64, DT_UINT8,
DT_UINT16, DT_UINT16, DT_INT8, DT_QINT8, DT_QUINT8,
- DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF};
+ DT_QINT16, DT_QUINT16, DT_QINT32, DT_HALF, DT_BFLOAT16};
}
DataTypeVector NumberTypes() {
- return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32,
- DT_UINT8, DT_UINT16, DT_INT16, DT_INT8,
- DT_COMPLEX64, DT_COMPLEX128, DT_QINT8, DT_QUINT8,
- DT_QINT32, DT_HALF, DT_UINT32, DT_UINT64};
+ return {DT_FLOAT, DT_DOUBLE, DT_INT64, DT_INT32, DT_UINT8,
+ DT_UINT16, DT_INT16, DT_INT8, DT_COMPLEX64, DT_COMPLEX128,
+ DT_QINT8, DT_QUINT8, DT_QINT32, DT_HALF, DT_UINT32,
+ DT_UINT64, DT_BFLOAT16};
}
#elif defined(__ANDROID_TYPES_FULL__)
diff --git a/tensorflow/core/grappler/costs/BUILD b/tensorflow/core/grappler/costs/BUILD
index f02cb51038..f1edbbb602 100644
--- a/tensorflow/core/grappler/costs/BUILD
+++ b/tensorflow/core/grappler/costs/BUILD
@@ -50,6 +50,7 @@ cc_library(
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
+ "//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/clusters:cluster",
],
)
diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc
index dd389de636..ec44d11bdd 100644
--- a/tensorflow/core/grappler/costs/graph_properties.cc
+++ b/tensorflow/core/grappler/costs/graph_properties.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/grappler/costs/utils.h"
+#include "tensorflow/core/grappler/utils.h"
namespace tensorflow {
namespace grappler {
@@ -264,6 +265,79 @@ bool IsEnterWithQueue(const Node& node) {
return false;
}
+bool HasAnyUnknownDimensions(const TensorShapeProto& proto) {
+ if (proto.unknown_rank()) {
+ return true;
+ }
+ for (const auto& dim : proto.dim()) {
+ if (dim.size() < 0) {
+ return true;
+ }
+ }
+ return false;
+}
+
+void VerboseLogUnknownDimensionSources(
+ const Graph& graph,
+ const std::map<string, std::vector<OpInfo::TensorProperties>>&
+ input_properties_map,
+ const std::map<string, std::vector<OpInfo::TensorProperties>>&
+ output_properties_map) {
+ if (!VLOG_IS_ON(2)) {
+ return;
+ }
+
+ VLOG(2) << "Nodes with known inputs, but with unknown output dimensions:";
+
+ // Find all nodes in the graph for which we
+ // do not have any unknown dimensions in their inputs, but
+ // we have some unknown dimensions in their outputs.
+ for (const Node* const node : graph.nodes()) {
+ if (node->num_outputs() == 0) {
+ continue;
+ }
+
+ const auto& input_properties = input_properties_map.at(node->name());
+ const auto& output_properties = output_properties_map.at(node->name());
+
+ bool has_unknown_inputs = false;
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ if (HasAnyUnknownDimensions(input_properties[i].shape())) {
+ has_unknown_inputs = true;
+ break;
+ }
+ }
+
+ if (has_unknown_inputs) {
+ continue;
+ }
+
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ if (HasAnyUnknownDimensions(output_properties[i].shape())) {
+ string inputs = "input_shapes=[";
+ for (int i = 0; i < node->num_inputs(); ++i) {
+ inputs +=
+ PartialTensorShape::DebugString(input_properties[i].shape());
+ }
+ inputs += "]";
+
+ string outputs = "output_shapes=[";
+ for (int i = 0; i < node->num_outputs(); ++i) {
+ outputs +=
+ PartialTensorShape::DebugString(output_properties[i].shape());
+ }
+ outputs += "]";
+
+ VLOG(2) << "Node: " << node->name() << ", Op: " << node->def().op()
+ << ", " << inputs << ", " << outputs;
+
+ // Don't log again for this node.
+ break;
+ }
+ }
+ }
+}
+
} // namespace
// Queue of nodes to process. Nodes can be enqueued in any order, but will be
@@ -312,9 +386,15 @@ class SymbolicShapeRefiner {
Status UpdateNode(const Node* node, bool relax, bool* refined) {
return shape_refiner_->UpdateNode(node, relax, refined);
}
- Status SetShape(const Node* node, int output_port,
- shape_inference::ShapeHandle shape) {
- return shape_refiner_->SetShape(node, output_port, shape);
+ Status SetUnknownShape(const Node* node, int output_port) {
+ shape_inference::ShapeHandle shape =
+ GetUnknownOutputShape(node, output_port);
+ InferenceContext* ctx = GetContext(node);
+ if (ctx == nullptr) {
+ return errors::InvalidArgument("Missing context");
+ }
+ ctx->set_output(output_port, shape);
+ return Status::OK();
}
struct ShapeId {
@@ -646,6 +726,23 @@ Status GraphProperties::UpdateMergeNode(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
+Status GraphProperties::OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const {
+ auto it = fed_ports.find(node->name());
+ Status status;
+ if (it != fed_ports.end()) {
+ // It is possible to feed node output ports with tensors of any shape: as a
+ // result, the shape of a fed port is completely unknown.
+ for (const int output_port : it->second) {
+ status.Update(shape_refiner->SetUnknownShape(node, output_port));
+ }
+ new_shapes->push(node);
+ }
+ return status;
+}
+
// Manually propagate the input shape for Enter nodes and update any Merge node
// outputs.
Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
@@ -673,9 +770,10 @@ Status GraphProperties::UpdateEnter(SymbolicShapeRefiner* shape_refiner,
return Status::OK();
}
-Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
- bool relax, const Node* n,
- TopoQueue* new_shapes) {
+Status GraphProperties::UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const {
if (n->IsEnter()) {
// The Enter shape function always forwards an UnknownShape, so do the right
// thing here.
@@ -695,7 +793,9 @@ Status GraphProperties::UpdateShapes(SymbolicShapeRefiner* shape_refiner,
}
}
}
- return Status::OK();
+ // Nodes can be fed with any shape. The TensorFlow shape inference code can't
+ // handle this properly, so overwrite its behavior here.
+ return OverwriteFedPorts(shape_refiner, fed_ports, n, new_shapes);
}
// Propagates the shapes in the transitive fan-out of <new_shapes>.
@@ -703,6 +803,7 @@ Status GraphProperties::PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const {
// Limit the number of iterations to prevent infinite loops in the presence of
// incorrect shape functions. The algorithm should converge in at most
@@ -728,8 +829,8 @@ Status GraphProperties::PropagateShapes(
for (const Edge* e : n->out_edges()) {
if (!e->IsControlEdge()) {
const Node* fanout = e->dst();
- TF_RETURN_IF_ERROR(
- UpdateShapes(shape_refiner, relax, fanout, new_shapes));
+ TF_RETURN_IF_ERROR(UpdateShapes(shape_refiner, relax, fed_ports,
+ fanout, new_shapes));
}
}
}
@@ -803,7 +904,7 @@ Status GraphProperties::UpdateResource(
return Status::OK();
}
-Status GraphProperties::InferStatically() {
+Status GraphProperties::InferStatically(bool assume_valid_feeds) {
Graph graph(OpRegistry::Global());
FunctionLibraryDefinition function_library(graph.op_registry(),
item_.graph.library());
@@ -820,11 +921,21 @@ Status GraphProperties::InferStatically() {
Status s = ImportGraphDef(options, item_.graph, &graph, &shape_refiner);
TF_RETURN_IF_ERROR(s);
+ std::unordered_map<string, std::unordered_set<int>> fed_ports;
+ if (!assume_valid_feeds) {
+ for (const auto& feed : item_.feed) {
+ int port_index = 0;
+ string node_name = ParseNodeName(feed.first, &port_index);
+ fed_ports[node_name].insert(port_index);
+ }
+ }
+
// List the resources and the nodes using them. Also collect the Enter and
// Merge nodes.
std::unordered_map<const Node*, std::unordered_set<const Node*>> resources;
std::unordered_set<const Node*> enter_nodes;
std::unordered_set<const Node*> merge_nodes;
+ std::unordered_set<const Node*> fed_nodes;
int num_loops = 0;
for (const Node* const node : graph.nodes()) {
for (int i = 0; i < node->num_inputs(); ++i) {
@@ -841,6 +952,9 @@ Status GraphProperties::InferStatically() {
} else if (node->IsNextIteration()) {
++num_loops;
}
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ fed_nodes.insert(node);
+ }
}
SymbolicShapeRefiner refiner(&shape_refiner);
@@ -855,15 +969,22 @@ Status GraphProperties::InferStatically() {
// Force the propagation of shapes of Enter nodes manually (the Enter shape
// function always forwards an UnknownShape).
for (const Node* node : enter_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
}
// Seed the propagation of shapes through merge nodes.
for (const Node* node : merge_nodes) {
- TF_RETURN_IF_ERROR(UpdateShapes(&refiner, relax, node, &new_shapes));
+ TF_RETURN_IF_ERROR(
+ UpdateShapes(&refiner, relax, fed_ports, node, &new_shapes));
+ }
+ // Also seed the propagation of shapes in the fanout of fed nodes.
+ for (const Node* node : fed_nodes) {
+ TF_RETURN_IF_ERROR(
+ OverwriteFedPorts(&refiner, fed_ports, node, &new_shapes));
}
// Propagate shapes normally.
- TF_RETURN_IF_ERROR(
- PropagateShapes(&refiner, relax, &new_shapes, resources, num_loops));
+ TF_RETURN_IF_ERROR(PropagateShapes(&refiner, relax, &new_shapes, resources,
+ fed_ports, num_loops));
}
// Track shapes globally across the graph.
@@ -874,6 +995,10 @@ Status GraphProperties::InferStatically() {
if (!node_ctx) {
continue;
}
+ // Skip any information that comes from fed nodes.
+ if (fed_ports.find(node->name()) != fed_ports.end()) {
+ continue;
+ }
for (const auto& merged_shapes : node_ctx->MergedShapes()) {
if (!shape_manager.Merge(merged_shapes.first, merged_shapes.second)
.ok()) {
@@ -948,6 +1073,10 @@ Status GraphProperties::InferStatically() {
}
}
+ // Help trace the unknown dimensions to their origins.
+ VerboseLogUnknownDimensionSources(graph, input_properties_,
+ output_properties_);
+
return Status::OK();
}
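The fed-port bookkeeping above keys on feed names like "node:port". A plain-Python sketch of the simple-case parsing that `ParseNodeName` performs in the loop above, under the usual TensorFlow naming convention (control-input forms like "^node" are left out):

```python
def parse_node_name(feed_name):
  # "queue:1" -> ("queue", 1); a bare "node" refers to output port 0.
  name, sep, port = feed_name.partition(':')
  return name, int(port) if sep else 0

fed_ports = {}
for feed in ['x', 'queue:1']:  # hypothetical feeds from item.feed
  node, port = parse_node_name(feed)
  fed_ports.setdefault(node, set()).add(port)
assert fed_ports == {'x': {0}, 'queue': {1}}
```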
diff --git a/tensorflow/core/grappler/costs/graph_properties.h b/tensorflow/core/grappler/costs/graph_properties.h
index 95bc5044d0..6fc53a7f2e 100644
--- a/tensorflow/core/grappler/costs/graph_properties.h
+++ b/tensorflow/core/grappler/costs/graph_properties.h
@@ -34,12 +34,19 @@ class TopoQueue;
// nodes, and potentially a set of nodes to feed.
class GraphProperties {
public:
- // Factory method for creating a GrapplerShapes from a MetaGraphDef.
- // Returns nullptr if the given meta_graph cannot be converted.
explicit GraphProperties(const GrapplerItem& item) : item_(item) {}
- Status InferStatically();
+ // Infer the shapes through abstract interpretation. Feed information can be
+ // incorrect, so it should be discarded to ensure correctness of the analysis.
+ // However, it can help infer shapes in the fanout of fed nodes (even though
+ // the correctness of these shapes can't be guaranteed), so in some cases
+ // (such as simulation or scheduling) it makes sense to keep these shapes.
+ Status InferStatically(bool assume_valid_feeds);
+ // Infer the shapes by running the graph on the specified cluster and
+ // recording the shapes of the processed tensors.
Status InferDynamically(Cluster* cluster);
+ // Extract the properties from a cost graph. For testing only, since there is
+ // no way to ensure that the cost graph matches the item.
Status InferFromCostGraph(const CostGraphDef& cost_graph);
// Stores `item_.graph` with the inferred output shapes to `output_graph_def`.
@@ -65,12 +72,6 @@ class GraphProperties {
OpInfo::TensorProperties*);
private:
- // Inputs
- GrapplerItem item_;
- std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
- std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
- const std::vector<OpInfo::TensorProperties> missing_properties_;
-
// Merges shapes <shapes_and_types>, determined from an EnqueueV2 node, into
// <*queue_shapes_and_types>.
static Status MergeEnqueueShapesAndTypes(
@@ -99,17 +100,31 @@ class GraphProperties {
static Status UpdateEnter(SymbolicShapeRefiner* shape_refiner,
const Node* node, bool relax,
TopoQueue* new_shapes);
+ // Process a node that is used to feed the model.
+ Status OverwriteFedPorts(
+ SymbolicShapeRefiner* shape_refiner,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* node, TopoQueue* new_shapes) const;
// Update the shapes for node 'n'. If output shapes for n have changed,
// enqueue its fanout in 'new_shapes'.
- static Status UpdateShapes(SymbolicShapeRefiner* shape_refiner, bool relax,
- const Node* n, TopoQueue* new_shapes);
+ Status UpdateShapes(
+ SymbolicShapeRefiner* shape_refiner, bool relax,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
+ const Node* n, TopoQueue* new_shapes) const;
// Propagate the shapes for the nodes enqueued in new_shapes and their
// transitive fanout until a fixed point is reached.
Status PropagateShapes(
SymbolicShapeRefiner* shape_refiner, bool relax, TopoQueue* new_shapes,
const std::unordered_map<const Node*, std::unordered_set<const Node*>>&
resources,
+ const std::unordered_map<string, std::unordered_set<int>>& fed_ports,
int num_loops) const;
+
+ // Data members
+ GrapplerItem item_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> input_properties_;
+ std::map<string, std::vector<OpInfo::TensorProperties>> output_properties_;
+ const std::vector<OpInfo::TensorProperties> missing_properties_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc
index c11af5777a..cc40ff2cfc 100644
--- a/tensorflow/core/grappler/costs/graph_properties_test.cc
+++ b/tensorflow/core/grappler/costs/graph_properties_test.cc
@@ -73,7 +73,7 @@ TEST_F(GraphPropertiesTest, StaticProperties) {
CHECK(fake_input.NextItem(&item));
GraphProperties properties(item);
- Status s = properties.InferStatically();
+ Status s = properties.InferStatically(true);
TF_CHECK_OK(s);
for (const auto& node : item.graph.node()) {
@@ -179,7 +179,7 @@ TEST_F(GraphPropertiesTest, Variables) {
{
GraphProperties static_properties(item);
- TF_CHECK_OK(static_properties.InferStatically());
+ TF_CHECK_OK(static_properties.InferStatically(false));
const auto props = static_properties.GetOutputProperties("Var");
EXPECT_EQ(1, props.size());
@@ -219,7 +219,7 @@ TEST_F(GraphPropertiesTest, VarHandles) {
.Finalize(item.graph.add_node()));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("VarRead");
EXPECT_EQ(1, props.size());
@@ -286,7 +286,7 @@ TEST_F(GraphPropertiesTest, Queues) {
TF_CHECK_OK(root.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props1 = properties.GetOutputProperties("Dequeue1");
ASSERT_EQ(1, props1.size());
@@ -335,7 +335,7 @@ TEST_F(GraphPropertiesTest, MergeWithoutLoops) {
"merge_without_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"cond/Merge", "cond/concat", "cond/concat_1"};
std::vector<string> expected_outputs{"float: [-1,-1,1]", "float: [2,1,1]",
@@ -377,7 +377,7 @@ TEST_F(GraphPropertiesTest, WhileLoop) {
"while_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -435,7 +435,7 @@ TEST_F(GraphPropertiesTest, NestedLoop) {
"nested_loop.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -498,7 +498,7 @@ TEST_F(GraphPropertiesTest, LoopsAndQueues) {
"loops_and_queues.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -556,7 +556,7 @@ TEST_F(GraphPropertiesTest, LoopsAndResourceVars) {
"loops_and_resource_vars.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> outer_nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -608,7 +608,7 @@ TEST_F(GraphPropertiesTest, QueuesAndLoops) {
"queues_and_loops.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
std::vector<string> nodes{"while/Merge_1", "while/NextIteration_1",
"while/Exit_1"};
@@ -657,7 +657,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape) {
item.fetch.push_back("init_restore");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto restore_props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& restore_prop = restore_props[0];
@@ -704,7 +704,7 @@ TEST_F(GraphPropertiesTest, InferRestoreOpShape_WithTwoNodesShareSameOutput) {
item.fetch.push_back("init2");
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("restore");
const OpInfo::TensorProperties& prop = props[0];
@@ -732,7 +732,7 @@ TEST_F(GraphPropertiesTest, FunctionStaticShapeInference) {
"simple_function.pbtxt");
TF_CHECK_OK(ReadGraphDefFromFile(filename, &item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto props = properties.GetOutputProperties("MyAdd_55e046a8_1");
const OpInfo::TensorProperties& prop = props[0];
EXPECT_EQ(DT_FLOAT, prop.dtype());
@@ -766,7 +766,7 @@ TEST_F(GraphPropertiesTest, SymbolicShapes) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_c = properties.GetOutputProperties("c").at(0).shape();
EXPECT_EQ(2, shape_a.dim_size());
@@ -822,7 +822,7 @@ TEST_F(GraphPropertiesTest, DoNotValidateColocationConstraints) {
GraphProperties properties(item);
// This function should return OK, since it doesn't validate the colocation
// constraints internally.
- TF_EXPECT_OK(properties.InferStatically());
+ TF_EXPECT_OK(properties.InferStatically(false));
}
TEST_F(GraphPropertiesTest, ShapeTracking) {
@@ -842,7 +842,7 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
TF_CHECK_OK(s.ToGraphDef(&item.graph));
GraphProperties properties(item);
- TF_CHECK_OK(properties.InferStatically());
+ TF_CHECK_OK(properties.InferStatically(false));
const auto shape_a = properties.GetOutputProperties("a").at(0).shape();
const auto shape_b = properties.GetOutputProperties("b").at(0).shape();
const auto shape_o1 = properties.GetOutputProperties("o1").at(0).shape();
@@ -851,6 +851,65 @@ TEST_F(GraphPropertiesTest, ShapeTracking) {
EXPECT_EQ(shape_b.DebugString(), shape_o2.DebugString());
}
+TEST_F(GraphPropertiesTest, FedNodes) {
+ TrivialTestGraphInputYielder fake_input(4, 1, 10, false,
+ cluster_->GetDeviceNames());
+ GrapplerItem item;
+ CHECK(fake_input.NextItem(&item));
+
+ {
+ // Conservative shape analysis: the shape of fed ports should be unknown.
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(false);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Const") {
+ continue;
+ }
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+
+ if (node.name() == "x") {
+ // x is fed: its input should have a known shape, while its output
+ // doesn't
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(1, in_prop.shape().dim_size());
+ EXPECT_EQ(2, in_prop.shape().dim(0).size());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+ } else if (node.op() == "Square" || node.op() == "AddN") {
+ // These nodes are in the fanout of x: their shapes should be unknown.
+ EXPECT_TRUE(in_prop.shape().unknown_rank());
+ EXPECT_TRUE(out_prop.shape().unknown_rank());
+ }
+ }
+ }
+ {
+ // Optimistic shape analysis: the shape of fed ports should be derived from
+ // the shape of the fanin.
+ GraphProperties properties(item);
+ Status s = properties.InferStatically(true);
+ TF_CHECK_OK(s);
+ for (const auto& node : item.graph.node()) {
+ if (node.op() == "Square" || node.op() == "AddN") {
+ const auto in_props = properties.GetInputProperties(node.name());
+ EXPECT_EQ(1, in_props.size());
+ const OpInfo::TensorProperties& in_prop = in_props[0];
+ EXPECT_EQ(DT_FLOAT, in_prop.dtype());
+ EXPECT_FALSE(in_prop.shape().unknown_rank());
+ EXPECT_EQ(2, in_prop.shape().dim_size());
+ const auto out_props = properties.GetOutputProperties(node.name());
+ EXPECT_EQ(1, out_props.size());
+ const OpInfo::TensorProperties& out_prop = out_props[0];
+ EXPECT_EQ(in_prop.DebugString(), out_prop.DebugString());
+ }
+ }
+ }
+}
+
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/costs/virtual_scheduler.cc b/tensorflow/core/grappler/costs/virtual_scheduler.cc
index e5e1ee3292..6640de668d 100644
--- a/tensorflow/core/grappler/costs/virtual_scheduler.cc
+++ b/tensorflow/core/grappler/costs/virtual_scheduler.cc
@@ -122,7 +122,7 @@ Status VirtualScheduler::Init() {
// Construct graph properties.
Status status;
if (use_static_shapes_) {
- status = graph_properties_.InferStatically();
+ status = graph_properties_.InferStatically(true);
} else {
status = graph_properties_.InferDynamically(cluster_);
}
diff --git a/tensorflow/core/grappler/grappler_item_builder.cc b/tensorflow/core/grappler/grappler_item_builder.cc
index 36c7f92c49..da99777bbc 100644
--- a/tensorflow/core/grappler/grappler_item_builder.cc
+++ b/tensorflow/core/grappler/grappler_item_builder.cc
@@ -173,7 +173,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
<< ", skipping this input.";
return nullptr;
}
- LOG(INFO) << "Will use feed node " << feed_name;
+ VLOG(1) << "Will use feed node " << feed_name;
new_item->feed.emplace_back(feed_name, Tensor());
}
@@ -188,7 +188,7 @@ std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
<< ", skipping this input";
return nullptr;
}
- LOG(INFO) << "Will use fetch node " << name;
+ VLOG(1) << "Will use fetch node " << name;
new_item->fetch.push_back(name);
}
}
diff --git a/tensorflow/core/grappler/optimizers/BUILD b/tensorflow/core/grappler/optimizers/BUILD
index 5d9eb8e0b1..7b4ed10e7e 100644
--- a/tensorflow/core/grappler/optimizers/BUILD
+++ b/tensorflow/core/grappler/optimizers/BUILD
@@ -96,6 +96,7 @@ cc_library(
":graph_optimizer",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:op_types",
@@ -332,6 +333,11 @@ tf_cc_test(
deps = [
":layout_optimizer",
"//tensorflow/cc:cc_ops",
+ "//tensorflow/cc:cc_ops_internal",
+ "//tensorflow/core:all_kernels",
+ "//tensorflow/core:core_cpu",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
index 3cfc4f61e4..efe8ac05a3 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer.cc
@@ -253,6 +253,30 @@ bool IsNumberType(DataType dtype) {
const char kOutputShapesAttr[] = "_output_shapes";
+PartialTensorShape GetInputShape(const string& input, const NodeMap& node_map) {
+ int output_pos;
+ string node_name = ParseNodeName(input, &output_pos);
+ const NodeDef* input_node = node_map.GetNode(node_name);
+ return input_node->attr().at(kOutputShapesAttr).list().shape(output_pos);
+}
+
+bool ShapesEqual(const string& input_x, const string& input_y,
+ const NodeMap& node_map) {
+ PartialTensorShape x_shape = GetInputShape(input_x, node_map);
+ PartialTensorShape y_shape = GetInputShape(input_y, node_map);
+ if (x_shape.unknown_rank() || y_shape.unknown_rank() ||
+ x_shape.dims() != y_shape.dims()) {
+ return false;
+ }
+ for (int i = 0; i < x_shape.dims(); ++i) {
+ if (x_shape.dim_size(i) == -1 || y_shape.dim_size(i) == -1 ||
+ x_shape.dim_size(i) != y_shape.dim_size(i)) {
+ return false;
+ }
+ }
+ return true;
+}
+
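As a quick illustration of the helper's semantics (hypothetical shapes, independent of any particular graph):

    PartialTensorShape known({2, 3});     // fully known
    PartialTensorShape partial({2, -1});  // second dimension unknown
    // ShapesEqual above returns true only for fully known, dimension-for-
    // dimension identical shapes: known vs known passes, known vs partial
    // fails on the -1, and any rank mismatch fails before the element loop.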
// Returns whether `reshape` is an identity op. The tensor that `reshape`
// reshapes is the `output_pos`-th output of node `input`.
bool ReshapeIsIdentity(const NodeDef& reshape, const NodeDef& input,
@@ -868,8 +892,11 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
// multiplication over addition to hoist common factors out of aggregate nodes
// where all the inputs are Mul nodes. This pattern occurs frequently in
// regularization terms for the gradients during training.
- // TODO(rmlarsen): Check shapes and enable for AddN.
- if (IsAdd(*node) && NumNonControlInputs(*node) > 1 &&
+ // For example, we can rewrite an expression of the form:
+ // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
+ // to the following:
+ // Mul(x, AddN(y1, y2, y3, ... yn))
+ if (IsAggregate(*node) && NumNonControlInputs(*node) > 1 &&
!OptimizedNodeExists(StrCat(node->name(), "_hoist_add"))) {
// Determine the set of common factors if the input nodes are all Mul nodes.
std::set<string> common_factors;
@@ -899,24 +926,15 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
}
if (common_factors.size() == 1) {
const string& common_factor = *common_factors.begin();
- // In this case we have an expression of the form
- // AddN(Mul(x, y1), Mul(y2, x), Mul(x, y3), ... Mul(x, yn))
- // that can be rewritten as
- // Mul(x, AddN(y1, y2, y3, ... yn))
-
- // 1. Use a copy of the first Mul node for the outer multiplication.
- NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"),
- node_map_->GetNode(node->input(0)));
- NodeDef* new_add_node = AddNode(StrCat(node->name(), "_hoist_add"), node);
- new_mul_node->set_device(node->device());
- new_mul_node->set_input(0, common_factor);
- node_map_->AddOutput(common_factor, new_mul_node->name());
- new_mul_node->set_input(1, new_add_node->name());
- node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
-
- // 2. Hoist non-shared factors up into the new AddN node.
- nodes_to_simplify->PushBack(new_add_node);
- for (int i = 0; i < node->input_size(); ++i) {
+
+ // Gather up the non-shared factors (the y's in the example).
+ // Unless the aggregation is Add, we have to make sure that all the y's
+ // have the same shape since the other aggregation ops do not support
+ // broadcasting.
+ std::vector<string> unique_factors;
+ unique_factors.reserve(node->input_size());
+ bool shapes_match = true;
+ for (int i = 0; i < node->input_size() && shapes_match; ++i) {
const string& input = node->input(i);
if (IsControlInput(input)) {
break;
@@ -924,15 +942,41 @@ string ArithmeticOptimizer::TrySimplifyAndReplaceUses(
const NodeDef* mul_node = node_map_->GetNode(input);
const int unique_factor_index =
mul_node->input(0) == common_factor ? 1 : 0;
- const string unique_factor = mul_node->input(unique_factor_index);
- new_add_node->set_input(i, unique_factor);
+ unique_factors.push_back(mul_node->input(unique_factor_index));
+ if (i > 0 && !IsAdd(*node)) {
+ shapes_match = ShapesEqual(unique_factors.front(),
+ unique_factors.back(), *node_map_);
+ }
}
- // 4. Add frame dependencies that the original node might have had.
- AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
- {new_add_node});
+ if (shapes_match) {
+ // 1. Use a copy of the first Mul node for the outer multiplication.
+ NodeDef* new_mul_node = AddNode(StrCat(node->name(), "_hoist_mul"),
+ node_map_->GetNode(node->input(0)));
+ NodeDef* new_add_node =
+ AddNode(StrCat(node->name(), "_hoist_add"), node);
+ new_mul_node->set_device(node->device());
+ new_mul_node->set_input(0, common_factor);
+ node_map_->AddOutput(common_factor, new_mul_node->name());
+ new_mul_node->set_input(1, new_add_node->name());
+ node_map_->AddOutput(new_add_node->name(), new_mul_node->name());
+
+ // 2. Hoist non-shared factors up into the new AddN node.
+ nodes_to_simplify->PushBack(new_add_node);
+ for (int i = 0; i < node->input_size(); ++i) {
+ const string& input = node->input(i);
+ if (IsControlInput(input)) {
+ break;
+ }
+ new_add_node->set_input(i, unique_factors[i]);
+ }
- return new_mul_node->name();
+ // 3. Add frame dependencies that the original node might have had.
+ AddFrameControlDeps(node, {new_add_node, new_mul_node}, common_factor,
+ {new_add_node});
+
+ return new_mul_node->name();
+ }
}
}
@@ -1064,13 +1108,10 @@ Status ArithmeticOptimizer::Optimize(Cluster* /*cluster*/,
int num_frames;
TF_RETURN_IF_ERROR(IdentifyFramesWithNodeMap(*optimized_graph_, *node_map_,
&frame_map_, &num_frames));
- if (opt_level_ == RewriterConfig::AGGRESSIVE) {
- graph_properties_.reset(new GraphProperties(item));
- // Shapes are only needed in aggressive mode.
- TF_RETURN_IF_ERROR(graph_properties_->InferStatically());
- TF_RETURN_IF_ERROR(
- graph_properties_->AnnotateOutputShapes(optimized_graph_));
- }
+ graph_properties_.reset(new GraphProperties(item));
+ // Shape information is now needed in all modes, not just aggressive mode.
+ TF_RETURN_IF_ERROR(graph_properties_->InferStatically(false));
+ TF_RETURN_IF_ERROR(graph_properties_->AnnotateOutputShapes(optimized_graph_));
// Perform the optimizations.
DedupComputations();
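To make the hoisting rewrite concrete, here is a sketch of the pattern it targets, built with the same C++ ops API the tests use (constant values are hypothetical):

    tensorflow::Scope s = tensorflow::Scope::NewRootScope();
    Output x  = ops::Const(s.WithOpName("x"),  {1.0f, 2.0f}, {1, 2});
    Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
    Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
    Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
    Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
    // AddN(Mul(x, y1), Mul(y2, x)) is rewritten to Mul(x, AddN(y1, y2)).
    // For AddN the rewrite only fires because y1 and y2 have identical,
    // fully known shapes; AddN, unlike Add, does not broadcast.
    Output sum = ops::AddN(s.WithOpName("add"), {mul1, mul2});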
diff --git a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
index e8a18ff9d9..80f42694d9 100644
--- a/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/arithmetic_optimizer_test.cc
@@ -32,6 +32,21 @@ string OptimizedName(const string& name) {
return AddPrefixToNodeName(name, kArithmeticOptimizer);
}
+void VerifyGraphsMatch(const GraphDef& original_graph,
+ const GraphDef& optimized_graph, int line) {
+ EXPECT_EQ(original_graph.node_size(), optimized_graph.node_size()) << line;
+ for (int i = 0; i < original_graph.node_size(); ++i) {
+ const NodeDef& original = original_graph.node(i);
+ const NodeDef& optimized = optimized_graph.node(i);
+ EXPECT_EQ(original.name(), optimized.name()) << line;
+ EXPECT_EQ(original.op(), optimized.op()) << line;
+ EXPECT_EQ(original.input_size(), optimized.input_size()) << line;
+ for (int j = 0; j < original.input_size(); ++j) {
+ EXPECT_EQ(original.input(j), optimized.input(j)) << line;
+ }
+ }
+}
+
class ArithmeticOptimizerTest : public ::testing::Test {};
TEST_F(ArithmeticOptimizerTest, NoOp) {
@@ -44,18 +59,7 @@ TEST_F(ArithmeticOptimizerTest, NoOp) {
GraphDef output;
Status status = optimizer.Optimize(nullptr, item, &output);
TF_EXPECT_OK(status);
-
- EXPECT_EQ(item.graph.node_size(), output.node_size());
- for (int i = 0; i < item.graph.node_size(); ++i) {
- const NodeDef& original = item.graph.node(i);
- const NodeDef& optimized = output.node(i);
- EXPECT_EQ(original.name(), optimized.name());
- EXPECT_EQ(original.op(), optimized.op());
- EXPECT_EQ(original.input_size(), optimized.input_size());
- for (int j = 0; j < original.input_size(); ++j) {
- EXPECT_EQ(original.input(j), optimized.input(j));
- }
- }
+ VerifyGraphsMatch(item.graph, output, __LINE__);
}
TEST_F(ArithmeticOptimizerTest, OpDedupping) {
@@ -398,39 +402,51 @@ TEST_F(ArithmeticOptimizerTest, TrivialSumsRepeatedAdd) {
}
TEST_F(ArithmeticOptimizerTest, HoistFactor) {
- tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
- Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
- Output y2 = ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2});
- Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
- Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
- Output add = ops::Add(s.WithOpName("add"), mul1, mul2);
- Output id = ops::Identity(s.WithOpName("id"), add);
-
- GrapplerItem item;
- TF_CHECK_OK(s.ToGraphDef(&item.graph));
-
- ArithmeticOptimizer optimizer;
- GraphDef output;
- Status status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
- // Run the optimizer twice to make sure the rewrite is idempotent.
- item.graph.Swap(&output);
- status = optimizer.Optimize(nullptr, item, &output);
- TF_EXPECT_OK(status);
-
- EXPECT_EQ(9, output.node_size());
- const NodeDef& new_add = output.node(8);
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
- EXPECT_EQ("y1", new_add.input(0));
- EXPECT_EQ("y2", new_add.input(1));
- const NodeDef& new_mul = output.node(7);
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
- EXPECT_EQ("x", new_mul.input(0));
- EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
- const NodeDef& new_id = output.node(6);
- EXPECT_EQ("id", new_id.name());
- EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+ for (bool matching_shapes : {true, false}) {
+ for (bool use_addn : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Const(s.WithOpName("x"), {1.0f, 2.0f}, {1, 2});
+ Output y1 = ops::Const(s.WithOpName("y1"), {3.0f, 4.0f}, {1, 2});
+ Output y2 = matching_shapes
+ ? ops::Const(s.WithOpName("y2"), {5.0f, 6.0f}, {1, 2})
+ : ops::Const(s.WithOpName("y2"), {5.0f}, {1, 1});
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, y1);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), y2, x);
+ Output id =
+ use_addn ? ops::Identity(s.WithOpName("id"),
+ ops::AddN(s.WithOpName("add"), {mul1, mul2}))
+ : ops::Identity(s.WithOpName("id"),
+ ops::Add(s.WithOpName("add"), mul1, mul2));
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ArithmeticOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+ // Run the optimizer twice to make sure the rewrite is idempotent.
+ item.graph.Swap(&output);
+ status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ if (use_addn && !matching_shapes) {
+ VerifyGraphsMatch(item.graph, output, __LINE__);
+ } else {
+ EXPECT_EQ(9, output.node_size());
+ const NodeDef& new_add = output.node(8);
+ EXPECT_EQ(OptimizedName("add_hoist_add"), new_add.name());
+ EXPECT_EQ("y1", new_add.input(0));
+ EXPECT_EQ("y2", new_add.input(1));
+ const NodeDef& new_mul = output.node(7);
+ EXPECT_EQ(OptimizedName("add_hoist_mul"), new_mul.name());
+ EXPECT_EQ("x", new_mul.input(0));
+ EXPECT_EQ(OptimizedName("add_hoist_add"), new_mul.input(1));
+ const NodeDef& new_id = output.node(6);
+ EXPECT_EQ("id", new_id.name());
+ EXPECT_EQ(OptimizedName("add_hoist_mul"), new_id.input(0));
+ }
+ }
+ }
}
TEST_F(ArithmeticOptimizerTest, FuseConjAndTranspose) {
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc
index c77b2badf4..e0f39c2931 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding.cc
@@ -30,13 +30,16 @@ limitations under the License.
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
+#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/gtl/cleanup.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
+#include "tensorflow/core/platform/tensor_coding.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/bcast.h"
+#include "tensorflow/core/util/saved_tensor_slice_util.h"
namespace tensorflow {
namespace grappler {
@@ -95,7 +98,38 @@ class DeviceSimple : public DeviceBase {
std::unique_ptr<Eigen::ThreadPoolDevice> eigen_device_;
};
+template <typename T>
+bool AllValuesAre(const TensorProto& tensor, const T& value) {
+ // TensorProto represents the content of the tensor in either <type>_val or
+ // tensor_content.
+ typename checkpoint::SaveTypeTraits<T>::RepeatedField* tensor_values =
+ checkpoint::MutableTensorProtoData<T>(const_cast<TensorProto*>(&tensor));
+ if (!tensor_values->empty()) {
+ for (const T& tensor_value : *tensor_values) {
+ if (tensor_value != value) {
+ return false;
+ }
+ }
+ return true;
+ }
+ const auto tensor_content_size = tensor.tensor_content().size();
+ if (tensor_content_size > 0) {
+ CHECK_EQ(0, tensor_content_size % sizeof(T));
+ std::vector<T> raw_values(tensor_content_size / sizeof(T));
+ port::CopyToArray(tensor.tensor_content(),
+ reinterpret_cast<char*>(raw_values.data()));
+ for (int i = 0; i < tensor_content_size / sizeof(T); ++i) {
+ if (raw_values[i] != value) {
+ return false;
+ }
+ }
+ return true;
+ }
+ return false;
+}
+
} // namespace
+
ConstantFolding::ConstantFolding(RewriterConfig::Toggle opt_level,
DeviceBase* cpu_device)
: opt_level_(opt_level), cpu_device_(cpu_device) {
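The AllValuesAre helper above handles both ways a TensorProto encodes its payload; a minimal sketch of the two encodings (hypothetical values):

    TensorProto repeated;          // typed repeated field
    repeated.add_float_val(1.0f);
    repeated.add_float_val(1.0f);  // AllValuesAre<float>(repeated, 1.0f) == true

    TensorProto packed;            // raw bytes in tensor_content
    const float raw[2] = {1.0f, 1.0f};
    packed.set_tensor_content(
        string(reinterpret_cast<const char*>(raw), sizeof(raw)));
    // Decoded via port::CopyToArray and compared element-wise, so this is
    // also reported as all-ones.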
@@ -190,14 +224,21 @@ Status ConvertShapeToConstant(const string& op, const DataType& type,
return Status::OK();
}
-Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
- const GraphProperties& properties) {
+bool ConstantFolding::IsReallyConstant(const NodeDef& node) const {
+ if (!IsConstant(node)) {
+ return false;
+ }
+ // If the node is fed, it's not constant anymore.
+ return feed_nodes_.find(node.name()) == feed_nodes_.end();
+}
+
+Status ConstantFolding::MaterializeShapes(const GraphProperties& properties) {
// We may add some nodes to the graph to encode control dependencies: there is
// no need to process these, so only iterate over the nodes of the input
// graph.
- const int node_count = graph_.node_size();
+ const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
- NodeDef& node = *graph_.mutable_node(i);
+ NodeDef& node = *graph_->mutable_node(i);
const string op = node.op();
if (op != "Shape" && op != "Size" && op != "Rank" && op != "ShapeN") {
continue;
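Since a fed node's value can be overridden at run time, the new IsReallyConstant guard replaces bare IsConstant checks throughout this file; the distinction in a nutshell (hypothetical feed):

    NodeDef c_node;
    c_node.set_op("Const");
    c_node.set_name("c");
    // With "c" recorded in feed_nodes_, IsConstant(c_node) is still true,
    // but IsReallyConstant(c_node) is false, so the folding passes leave
    // the node alone instead of baking in a value the user may replace.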
@@ -241,7 +282,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// cases where the shape/rank/size would have been run in
// the original graph. Additional inputs are extra control
string ctrl_dep =
- AddControlDependency(node.input(0), &graph_, node_map_.get());
+ AddControlDependency(node.input(0), graph_, node_map_.get());
node.set_input(0, ctrl_dep);
node_map_->AddOutput(NodeName(ctrl_dep), node.name());
} else {
@@ -256,7 +297,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
AddPrefixToNodeName(strings::StrCat(node.name(), "-", j),
kConstantFoldingConst);
if (node_map_->GetNode(const_name) == nullptr) {
- NodeDef* added_node = graph_.add_node();
+ NodeDef* added_node = graph_->add_node();
added_node->set_name(const_name);
added_node->set_op("Const");
added_node->set_device(node.device());
@@ -267,7 +308,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
// We add a control dependency to the original ShapeN node,
// so that the node will only be run if all inputs of the
// original ShapeN node are run.
- string ctrl_dep = AddControlDependency(node.name(), &graph_,
+ string ctrl_dep = AddControlDependency(node.name(), graph_,
node_map_.get());
*added_node->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), added_node->name());
@@ -285,6 +326,7 @@ Status ConstantFolding::MaterializeShapes(const GrapplerItem& item,
return Status::OK();
}
+namespace {
bool ShapesEqual(const TensorShapeProto& shape1,
const TensorShapeProto& shape2) {
if (shape1.unknown_rank() || shape2.unknown_rank()) {
@@ -297,11 +339,13 @@ bool ShapesEqual(const TensorShapeProto& shape1,
if (shape1.dim(i).size() != shape2.dim(i).size()) {
return false;
}
+ if (shape1.dim(i).size() == -1 || shape2.dim(i).size() == -1) {
+ return false;
+ }
}
return true;
}
-namespace {
bool ExtractShape(const NodeDef& shape_node, const GraphProperties& properties,
BCast::Vec* shape, int64* min_id) {
if (shape_node.op() == "Shape") {
@@ -344,9 +388,9 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
const NodeDef* shape_node1 = node_map_->GetNode(node.input(0));
const NodeDef* shape_node2 = node_map_->GetNode(node.input(1));
if (shape_node1 == nullptr ||
- (shape_node1->op() != "Shape" && shape_node1->op() != "Const") ||
+ (shape_node1->op() != "Shape" && !IsReallyConstant(*shape_node1)) ||
shape_node2 == nullptr ||
- (shape_node2->op() != "Shape" && shape_node2->op() != "Const")) {
+ (shape_node2->op() != "Shape" && !IsReallyConstant(*shape_node2))) {
return Status::OK();
}
int64 min_id = 0;
@@ -392,13 +436,13 @@ Status ConstantFolding::MaterializeBroadcastGradientArgs(
strings::StrCat(node.name(), "-", j), kConstantFoldingConst);
out[j] = node_map_->GetNode(const_name);
if (out[j] == nullptr) {
- out[j] = graph_.add_node();
+ out[j] = graph_->add_node();
Tensor value(type, TensorShape({0}));
*out[j] = CreateNodeDef(const_name, TensorValue(&value));
out[j]->set_device(node.device());
node_map_->AddNode(const_name, out[j]);
string ctrl_dep =
- AddControlDependency(node.name(), &graph_, node_map_.get());
+ AddControlDependency(node.name(), graph_, node_map_.get());
*out[j]->add_input() = ctrl_dep;
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
}
@@ -426,7 +470,7 @@ Status ConstantFolding::MaterializeReductionIndices(
return Status::OK();
}
const NodeDef* indices = node_map_->GetNode(node->input(1));
- if (!indices || IsConstant(*indices)) {
+ if (!indices || IsReallyConstant(*indices)) {
// The reduction indices are already constant, there's nothing to do.
return Status::OK();
}
@@ -479,7 +523,7 @@ Status ConstantFolding::MaterializeReductionIndices(
if (node_map_->GetNode(const_name)) {
return Status::OK();
}
- NodeDef* reduction_indices = graph_.add_node();
+ NodeDef* reduction_indices = graph_->add_node();
Tensor value(dtype, TensorShape({rank}));
for (int i = 0; i < rank; ++i) {
if (dtype == DT_INT32) {
@@ -491,7 +535,7 @@ Status ConstantFolding::MaterializeReductionIndices(
*reduction_indices = CreateNodeDef(const_name, TensorValue(&value));
reduction_indices->set_device(node->device());
string ctrl_dep =
- AddControlDependency(node->input(1), &graph_, node_map_.get());
+ AddControlDependency(node->input(1), graph_, node_map_.get());
*reduction_indices->add_input() = ctrl_dep;
node_map_->AddNode(const_name, reduction_indices);
node_map_->AddOutput(NodeName(ctrl_dep), const_name);
@@ -504,10 +548,10 @@ Status ConstantFolding::MaterializeReductionIndices(
}
Status ConstantFolding::MaterializeConstants(
- const GrapplerItem& item, const GraphProperties& properties) {
- const int node_count = graph_.node_size();
+ const GraphProperties& properties) {
+ const int node_count = graph_->node_size();
for (int i = 0; i < node_count; ++i) {
- NodeDef& node = *graph_.mutable_node(i);
+ NodeDef& node = *graph_->mutable_node(i);
const string& op = node.op();
if (op == "BroadcastGradientArgs") {
TF_RETURN_IF_ERROR(MaterializeBroadcastGradientArgs(node, properties));
@@ -523,24 +567,23 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (node.input().empty()) {
return false;
}
-
// Skips nodes that must be preserved except whitelisted nodes.
if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() &&
nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) {
return false;
}
-
- // Skips ops that don't benefit from folding.
- const string& op = node.op();
- // Skip constants, they're already folded
- if (op == "Const") {
+ // Skip control flow nodes, they can't be folded
+ if (ModifiesFrameInfo(node)) {
return false;
}
- // Skip constrol flow nodes, they can't be folded
- if (op == "Enter" || op == "RefEnter" || op == "Exit" || op == "RefExit" ||
- op == "NextIteration" || op == "RefNextIteration") {
+ // Skip constants, they're already folded
+ if (IsConstant(node)) {
return false;
}
+
+ // Skips ops that don't benefit from folding.
+ const string& op = node.op();
+
if (op.find("Placeholder") == 0) {
return false;
}
@@ -594,7 +637,7 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
if (!input_node) {
return false;
}
- bool is_const = IsConstant(*input_node);
+ bool is_const = IsReallyConstant(*input_node);
if (!is_const && !is_merge) {
return false;
}
@@ -612,6 +655,36 @@ bool ConstantFolding::IsFoldable(const NodeDef& node) const {
return true;
}
+namespace {
+
+#define SET_TENSOR_VAL_CASE(DTYPE, TYPE, NAME) \
+ case DTYPE: \
+ t->add_##NAME##_val(static_cast<TYPE>(value)); \
+ break;
+
+Status CreateConstantTensorAttrValue(DataType type, double value,
+ const TensorShapeProto& shape,
+ AttrValue* attr_tensor) {
+ TensorProto* t = attr_tensor->mutable_tensor();
+ *t->mutable_tensor_shape() = shape;
+ switch (type) {
+ SET_TENSOR_VAL_CASE(DT_FLOAT, float, float);
+ SET_TENSOR_VAL_CASE(DT_DOUBLE, double, double);
+ SET_TENSOR_VAL_CASE(DT_INT64, int64, int64);
+ SET_TENSOR_VAL_CASE(DT_INT32, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT16, int32, int);
+ SET_TENSOR_VAL_CASE(DT_INT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_UINT8, int32, int);
+ SET_TENSOR_VAL_CASE(DT_BOOL, bool, bool);
+ default:
+ return errors::InvalidArgument("Unsupported type: ", type);
+ }
+ return Status::OK();
+}
+
+#undef SET_TENSOR_VAL_CASE
+} // namespace
+
// static
NodeDef ConstantFolding::CreateNodeDef(const string& name,
const TensorValue& tensor) {
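A sketch of how the new helper is used when synthesizing a constant (mirrors ReplaceAddOrMulWithConstant further down; the shape is hypothetical):

    AttrValue attr;
    TensorShapeProto shape;
    shape.add_dim()->set_size(1);
    shape.add_dim()->set_size(2);
    // Produces a {1, 2} float tensor proto holding a single float_val of 0,
    // which TensorFlow implicitly splats across every element.
    TF_RETURN_IF_ERROR(
        CreateConstantTensorAttrValue(DT_FLOAT, 0.0, shape, &attr));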
@@ -652,6 +725,14 @@ NodeDef ConstantFolding::CreateNodeDef(const string& name,
POPULATE_TENSOR_PROTO(tensor, t, int64, int64)
} else if (tensor->dtype() == DT_INT32) {
POPULATE_TENSOR_PROTO(tensor, t, int32, int)
+ } else if (tensor->dtype() == DT_INT16) {
+ POPULATE_TENSOR_PROTO(tensor, t, int16, int)
+ } else if (tensor->dtype() == DT_INT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, int8, int)
+ } else if (tensor->dtype() == DT_UINT8) {
+ POPULATE_TENSOR_PROTO(tensor, t, uint8, int)
+ } else if (tensor->dtype() == DT_BOOL) {
+ POPULATE_TENSOR_PROTO(tensor, t, bool, bool)
}
}
if (optimized) {
@@ -720,7 +801,7 @@ Status ConstantFolding::EvaluateOneFoldable(const NodeDef& node,
break;
}
const NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
return Status(error::INVALID_ARGUMENT,
strings::StrCat("Can't fold ", node.name(), ", its ", input,
" isn't constant"));
@@ -774,7 +855,7 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
continue;
}
NodeDef* input_node = node_map_->GetNode(input);
- if (!IsConstant(*input_node)) {
+ if (!IsReallyConstant(*input_node)) {
continue;
}
bool valid_input = true;
@@ -955,8 +1036,8 @@ Status ConstantFolding::FoldNode(NodeDef* node, GraphDef* output_graph) {
Status ConstantFolding::FoldGraph(GraphDef* output) {
std::unordered_set<string> processed_nodes;
std::deque<NodeDef*> queue;
- for (int i = 0; i < graph_.node_size(); i++) {
- auto node = graph_.mutable_node(i);
+ for (int i = 0; i < graph_->node_size(); i++) {
+ auto node = graph_->mutable_node(i);
if (IsFoldable(*node)) {
queue.push_back(node);
}
@@ -995,7 +1076,7 @@ Status ConstantFolding::FoldGraph(GraphDef* output) {
output->mutable_node()->DeleteSubrange(last + 1,
output->node_size() - last - 1);
- for (const auto& node : graph_.node()) {
+ for (const auto& node : graph_->node()) {
// If no fetch nodes is provided, we conservatively
// keep all nodes in the original graph in case users need to fetch
// their values.
@@ -1016,7 +1097,7 @@ bool ConstantFolding::IsSimplifiableReduction(const NodeDef& node) const {
if (IsReduction(node)) {
CHECK_LE(2, node.input_size());
const NodeDef* reductions_indices = node_map_->GetNode(node.input(1));
- if (IsConstant(*reductions_indices)) {
+ if (IsReallyConstant(*reductions_indices)) {
TensorVector output;
Status s = EvaluateNode(*reductions_indices, TensorVector(), &output);
if (!s.ok()) {
@@ -1040,7 +1121,7 @@ bool ConstantFolding::IsSimplifiableReshape(
}
CHECK_LE(2, node.input_size());
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
- if (!IsConstant(*new_shape)) {
+ if (!IsReallyConstant(*new_shape)) {
return false;
}
TensorVector outputs;
@@ -1090,8 +1171,107 @@ bool ConstantFolding::IsSimplifiableReshape(
return shape.IsCompatibleWith(new_dims);
}
+#define IS_VALUE_CASE(DTYPE, VALUE) \
+ case DTYPE: \
+ return AllValuesAre<EnumToDataType<DTYPE>::Type>( \
+ node.attr().at("value").tensor(), EnumToDataType<DTYPE>::Type(VALUE))
+
+#define IS_ONES_CASE(TYPE) IS_VALUE_CASE(TYPE, 1)
+#define IS_ZEROS_CASE(TYPE) IS_VALUE_CASE(TYPE, 0)
+
+bool ConstantFolding::IsOnes(const NodeDef& node) const {
+ if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
+ return false;
+ }
+ if (node.op() == "OnesLike") {
+ return true;
+ }
+ if (node.op() != "Const") {
+ return false;
+ }
+ const auto dtype = node.attr().at("dtype").type();
+ switch (dtype) {
+ // IS_ONES_CASE(DT_HALF);
+ IS_ONES_CASE(DT_FLOAT);
+ IS_ONES_CASE(DT_DOUBLE);
+ IS_ONES_CASE(DT_UINT8);
+ IS_ONES_CASE(DT_INT8);
+ IS_ONES_CASE(DT_UINT16);
+ IS_ONES_CASE(DT_INT16);
+ IS_ONES_CASE(DT_INT32);
+ IS_ONES_CASE(DT_INT64);
+ IS_ONES_CASE(DT_COMPLEX64);
+ IS_ONES_CASE(DT_COMPLEX128);
+ default:
+ LOG(ERROR) << "Unexpected type " << DataTypeString(dtype);
+ return false;
+ }
+ return false;
+}
+
+bool ConstantFolding::IsZeros(const NodeDef& node) const {
+ if (feed_nodes_.find(node.name()) != feed_nodes_.end()) {
+ return false;
+ }
+ if (node.op() == "ZerosLike") {
+ return true;
+ }
+ if (!IsConstant(node)) {
+ return false;
+ }
+ const auto dtype = node.attr().at("dtype").type();
+ switch (dtype) {
+ // IS_ZEROS_CASE(DT_HALF);
+ IS_ZEROS_CASE(DT_FLOAT);
+ IS_ZEROS_CASE(DT_DOUBLE);
+ IS_ZEROS_CASE(DT_UINT8);
+ IS_ZEROS_CASE(DT_INT8);
+ IS_ZEROS_CASE(DT_UINT16);
+ IS_ZEROS_CASE(DT_INT16);
+ IS_ZEROS_CASE(DT_INT32);
+ IS_ZEROS_CASE(DT_INT64);
+ IS_ZEROS_CASE(DT_COMPLEX64);
+ IS_ZEROS_CASE(DT_COMPLEX128);
+ default:
+ LOG(ERROR) << "Unexpected type " << DataTypeString(dtype);
+ return false;
+ }
+ return false;
+}
+
+void ConstantFolding::ReplaceAddOrMulWithIdentity(int input_to_forward,
+ NodeDef* node) {
+ node->set_op("Identity");
+ // Propagate the designated input through the identity.
+ node->mutable_input()->SwapElements(0, input_to_forward);
+ // Add all other inputs as control dependencies.
+ for (int i = 1; i < node->input_size(); ++i) {
+ node->set_input(i, AsControlDependency(node->input(i)));
+ }
+ graph_modified_ = true;
+}
+
+Status ConstantFolding::ReplaceAddOrMulWithConstant(
+ double value, const TensorShapeProto& shape, NodeDef* node) {
+ AttrValue tensor_attr;
+ TF_RETURN_IF_ERROR(CreateConstantTensorAttrValue(node->attr().at("T").type(),
+ value, shape, &tensor_attr));
+ node->mutable_attr()->insert({"value", tensor_attr});
+ node->set_op("Const");
+ // Convert all inputs to control dependencies.
+ for (int i = 0; i < node->input_size(); ++i) {
+ if (IsControlInput(node->input(i))) {
+ break;
+ }
+ node->set_input(i, AsControlDependency(node->input(i)));
+ }
+ graph_modified_ = true;
+ return Status::OK();
+}
+
Status ConstantFolding::SimplifyGraph(GraphDef* output,
- const GraphProperties& properties) {
+ const GraphProperties& properties,
+ bool use_shape_info) {
for (auto& node : *output->mutable_node()) {
if (IsSimplifiableReduction(node)) {
// Replace the reduction node with an identity node, that can be further
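For reference, each IS_ONES_CASE/IS_ZEROS_CASE branch above is a thin wrapper over AllValuesAre; the DT_FLOAT case of IsOnes expands to roughly:

    case DT_FLOAT:
      return AllValuesAre<EnumToDataType<DT_FLOAT>::Type>(  // i.e. float
          node.attr().at("value").tensor(), float(1));      // 0 for IsZeros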
@@ -1116,10 +1296,10 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
*node.add_input() = input;
}
}
- // It's possible to feed a placeholder with a tensor that doesn't have the
- // proper shape, and reshape this tensor later on. Therefore only remove
- // reshapes in graphs that don't have placeholders.
- if (IsSimplifiableReshape(node, properties)) {
+ const bool safe_to_use_shapes =
+ use_shape_info &&
+ (feed_nodes_.empty() || opt_level_ == RewriterConfig::AGGRESSIVE);
+ if (safe_to_use_shapes && IsSimplifiableReshape(node, properties)) {
const NodeDef* new_shape = node_map_->GetNode(node.input(1));
DataType output_type = node.attr().at("T").type();
node.set_op("Identity");
@@ -1134,6 +1314,63 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
*node.add_input() = input;
}
}
+
+ // Simplify multiplication by ones or zeros, and addition of zeros.
+ bool is_mul = IsMul(node);
+ bool is_add = IsAdd(node);
+ if (opt_level_ == RewriterConfig::AGGRESSIVE && use_shape_info &&
+ (is_mul || is_add) && properties.HasInputProperties(node.name()) &&
+ properties.HasOutputProperties(node.name())) {
+ const NodeDef* x = node_map_->GetNode(node.input(0));
+ const NodeDef* y = node_map_->GetNode(node.input(1));
+ if (x == nullptr || y == nullptr) {
+ return errors::InvalidArgument("Invalid inputs to node: ",
+ node.DebugString());
+ }
+ const TensorShapeProto& output_shape =
+ properties.GetOutputProperties(node.name())[0].shape();
+ const TensorShapeProto& x_shape =
+ properties.GetInputProperties(node.name())[0].shape();
+
+ // Simplify multiplication by or addition of zeros.
+ const bool x_is_zero = IsZeros(*x);
+ const bool x_matches_output_shape = ShapesEqual(output_shape, x_shape);
+ if (x_is_zero && x_matches_output_shape) {
+ // 0 * y = 0 or 0 + y = y.
+ ReplaceAddOrMulWithIdentity(is_mul ? 0 : 1, &node);
+ continue;
+ }
+ const TensorShapeProto& y_shape =
+ properties.GetInputProperties(node.name())[1].shape();
+ const bool y_is_zero = IsZeros(*y);
+ const bool y_matches_output_shape = ShapesEqual(output_shape, y_shape);
+ if (y_is_zero && y_matches_output_shape) {
+ // x * 0 = 0 or x + 0 = x.
+ ReplaceAddOrMulWithIdentity(is_mul ? 1 : 0, &node);
+ continue;
+ }
+
+ if (is_mul) {
+ // Simplify multiplication by zeros where the output shape does not
+ // match the shape of the zero input.
+ if (x_is_zero || y_is_zero) {
+ TF_RETURN_IF_ERROR(
+ ReplaceAddOrMulWithConstant(0, output_shape, &node));
+ continue;
+ }
+
+ // Simplify multiplication by ones.
+ if (IsOnes(*x) && y_matches_output_shape) {
+ // 1 * y = y.
+ ReplaceAddOrMulWithIdentity(1, &node);
+ continue;
+ } else if (IsOnes(*y) && x_matches_output_shape) {
+ // x * 1 = x.
+ ReplaceAddOrMulWithIdentity(0, &node);
+ continue;
+ }
+ }
+ }
}
return Status::OK();
}
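Putting the new aggressive-mode rules together, here is a sketch of inputs and the rewrites they trigger (mirrors the NeutralElement test below; all shapes are {1, 2}, so every operand matches the output shape):

    Output x     = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
                                    ops::Placeholder::Shape(TensorShape({1, 2})));
    Output zeros = ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2});
    Output ones  = ops::Const(s.WithOpName("ones"),  {1.0f, 1.0f}, {1, 2});
    Output m0 = ops::Mul(s.WithOpName("m0"), x, zeros);  // -> Identity(zeros)
    Output m1 = ops::Mul(s.WithOpName("m1"), x, ones);   // -> Identity(x)
    Output a0 = ops::Add(s.WithOpName("a0"), x, zeros);  // -> Identity(x)
    // Had the zero input not matched the output shape, the Mul would instead
    // become a Const of zeros sized like the output.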
@@ -1141,7 +1378,7 @@ Status ConstantFolding::SimplifyGraph(GraphDef* output,
Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
const GrapplerItem& item,
GraphDef* output) {
- node_map_.reset(new NodeMap(&graph_));
+ node_map_.reset(new NodeMap(graph_));
nodes_whitelist_.clear();
// Fold fetch nodes iff it has a single fanout. Note that if a fetch node
// has a single fanout, it would be rewritten as a constant with the same
@@ -1158,36 +1395,34 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster,
}
GraphProperties properties(item);
- const bool has_feed = !item.feed.empty();
- bool needs_shapes = !has_feed || opt_level_ == RewriterConfig::AGGRESSIVE;
- Status s = errors::Unknown(
- "The graph properties are needed but were not initialized");
- if (needs_shapes) {
- s = properties.InferStatically();
- }
-
- if (!has_feed && s.ok()) {
- // Only use static shape information when there is no feed in the
- // graph. That's because it's possible to feed a placeholder with a tensor
- // of any shape, which could make the static information inconsistent with
- // the shapes actually fed.
- TF_RETURN_IF_ERROR(MaterializeShapes(item, properties));
- }
- if (opt_level_ == RewriterConfig::AGGRESSIVE && s.ok()) {
- TF_RETURN_IF_ERROR(MaterializeConstants(item, properties));
+ // It's possible to feed a placeholder with a tensor of any shape: make sure
+ // that the shape inference deals with this conservatively unless we're in
+ // aggressive mode.
+ const bool assume_valid_feeds = opt_level_ == RewriterConfig::AGGRESSIVE;
+ Status s = properties.InferStatically(assume_valid_feeds);
+ const bool can_use_shape_info = s.ok();
+
+ if (can_use_shape_info) {
+ TF_RETURN_IF_ERROR(MaterializeShapes(properties));
+
+ if (opt_level_ == RewriterConfig::AGGRESSIVE) {
+ TF_RETURN_IF_ERROR(MaterializeConstants(properties));
+ }
}
TF_RETURN_IF_ERROR(FoldGraph(output));
- if (!has_feed && s.ok()) {
- TF_RETURN_IF_ERROR(SimplifyGraph(output, properties));
- }
+ TF_RETURN_IF_ERROR(SimplifyGraph(output, properties, can_use_shape_info));
+
return Status::OK();
}
Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
GraphDef* output) {
nodes_to_preserve_ = item.NodesToPreserve();
+ for (const auto& feed : item.feed) {
+ feed_nodes_.insert(NodeName(feed.first));
+ }
if (cpu_device_ == nullptr) {
owned_device_.reset(new DeviceSimple());
@@ -1200,13 +1435,13 @@ Status ConstantFolding::Optimize(Cluster* cluster, const GrapplerItem& item,
*output = item.graph;
int64 node_count;
do {
- graph_.Swap(output);
- item_to_optimize.graph = graph_;
+ graph_modified_ = false;
+ item_to_optimize.graph.Swap(output);
+ graph_ = &item_to_optimize.graph;
*output = GraphDef();
- node_count = graph_.node_size();
+ node_count = graph_->node_size();
TF_RETURN_IF_ERROR(RunOptimizationPass(cluster, item_to_optimize, output));
- } while (output->node_size() != node_count);
-
+ } while (graph_modified_ || output->node_size() != node_count);
*output->mutable_library() = item.graph.library();
*output->mutable_versions() = item.graph.versions();
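The driver now iterates to a true fixed point: it stops only when a pass neither rewrites a node in place (graph_modified_) nor removes one (node count). The pattern, in sketch form:

    bool modified;
    int64 node_count;
    do {
      modified = false;                  // set by any in-place rewrite
      node_count = graph->node_size();
      // ... run one optimization pass over `graph` ...
    } while (modified || graph->node_size() != node_count);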
diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h
index f04f413c10..3bb9926338 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding.h
+++ b/tensorflow/core/grappler/optimizers/constant_folding.h
@@ -51,16 +51,16 @@ class ConstantFolding : public GraphOptimizer {
const GraphDef& optimize_output, double result) override;
private:
- Status MaterializeShapes(const GrapplerItem& item,
- const GraphProperties& properties);
+ bool IsReallyConstant(const NodeDef& node) const;
+
+ Status MaterializeShapes(const GraphProperties& properties);
Status MaterializeBroadcastGradientArgs(const NodeDef& node,
const GraphProperties& properties);
Status MaterializeReductionIndices(NodeDef* node,
const GraphProperties& properties);
- Status MaterializeConstants(const GrapplerItem& item,
- const GraphProperties& properties);
+ Status MaterializeConstants(const GraphProperties& properties);
bool IsFoldable(const NodeDef& node) const;
Status EvaluateNode(const NodeDef& node,
@@ -72,12 +72,19 @@ class ConstantFolding : public GraphOptimizer {
Status FoldNode(NodeDef* node, GraphDef* output_graph);
+ bool IsOnes(const NodeDef& node) const;
+ bool IsZeros(const NodeDef& node) const;
+ void ReplaceAddOrMulWithIdentity(int input_to_forward, NodeDef* node);
+ Status ReplaceAddOrMulWithConstant(double value,
+ const TensorShapeProto& shape,
+ NodeDef* node);
Status FoldGraph(GraphDef* output);
bool IsSimplifiableReduction(const NodeDef& node) const;
bool IsSimplifiableReshape(const NodeDef& node,
const GraphProperties& properties) const;
- Status SimplifyGraph(GraphDef* output, const GraphProperties& properties);
+ Status SimplifyGraph(GraphDef* output, const GraphProperties& properties,
+ bool use_shape_info);
Status RunOptimizationPass(Cluster* cluster, const GrapplerItem& item,
GraphDef* output);
@@ -88,11 +95,13 @@ class ConstantFolding : public GraphOptimizer {
std::unique_ptr<DeviceBase> owned_device_;
std::unique_ptr<ResourceMgr> resource_mgr_;
- GraphDef graph_;
+ GraphDef* graph_;
std::unique_ptr<NodeMap> node_map_;
std::unordered_set<string> nodes_to_preserve_;
std::unordered_set<string> nodes_whitelist_;
+ std::unordered_set<string> feed_nodes_;
bool has_fetch_;
+ bool graph_modified_;
};
} // end namespace grappler
diff --git a/tensorflow/core/grappler/optimizers/constant_folding_test.cc b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
index b2d9b02c68..32a691d3ee 100644
--- a/tensorflow/core/grappler/optimizers/constant_folding_test.cc
+++ b/tensorflow/core/grappler/optimizers/constant_folding_test.cc
@@ -77,11 +77,166 @@ TEST_F(ConstantFoldingTest, SimpleFolding) {
test::ExpectTensorEqual<float>(tensors_expected[0], tensors[0]);
}
+TEST_F(ConstantFoldingTest, NeutralElement) {
+ for (bool use_const : {true, false}) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ Output x = ops::Placeholder(s.WithOpName("x"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({1, 2})));
+ Output y = ops::Placeholder(s.WithOpName("y"), DT_FLOAT,
+ ops::Placeholder::Shape(TensorShape({1, 2})));
+ Output zeros =
+ !use_const ? ops::ZerosLike(s.WithOpName("zeros"), x)
+ : ops::Const(s.WithOpName("zeros"), {0.0f, 0.0f}, {1, 2});
+ Output zeros_broadcast =
+ ops::Const(s.WithOpName("zeros_broadcast"), {0.0f}, {1, 1});
+ Output ones = !use_const
+ ? ops::OnesLike(s.WithOpName("ones"), x)
+ : ops::Const(s.WithOpName("ones"), {1.0f, 1.0f}, {1, 2});
+ Output mul1 = ops::Mul(s.WithOpName("mul1"), x, zeros);
+ Output mul2 = ops::Mul(s.WithOpName("mul2"), zeros, y);
+ Output mul3 = ops::Mul(s.WithOpName("mul3"), x, ones);
+ Output mul4 = ops::Mul(s.WithOpName("mul4"), ones, y);
+ Output mul5 = ops::Mul(s.WithOpName("mul5"), x, zeros_broadcast);
+ Output mul6 = ops::Mul(s.WithOpName("mul6"), zeros_broadcast, y);
+ Output add1 = ops::Add(s.WithOpName("add1"), x, zeros);
+ Output add2 = ops::Add(s.WithOpName("add2"), zeros, y);
+ Output addn = ops::AddN(s, {mul1, mul2, mul3, mul4, add1, add2});
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+
+ ConstantFolding optimizer(RewriterConfig::AGGRESSIVE,
+ nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = optimizer.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(14, output.node_size());
+ for (int i = 0; i < output.node_size(); ++i) {
+ const NodeDef& node = output.node(i);
+ const string& name = node.name();
+ if (name == "mul1") {
+ if (use_const) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ } else {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^x", node.input(1));
+ }
+ } else if (name == "mul2") {
+ if (use_const) {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^y", node.input(0));
+ } else {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("zeros", node.input(0));
+ EXPECT_EQ("^y", node.input(1));
+ }
+ } else if (name == "mul3") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else if (name == "mul4") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ } else if (name == "mul5") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^x", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(1, t.float_val_size());
+ EXPECT_EQ(0, t.float_val(0));
+ EXPECT_EQ(2, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ EXPECT_EQ(2, t.tensor_shape().dim(1).size());
+ } else if (name == "mul6") {
+ EXPECT_EQ("Const", node.op());
+ EXPECT_EQ("^y", node.input(0));
+ EXPECT_EQ("^ones", node.input(1));
+ TensorProto t = node.attr().at("value").tensor();
+ EXPECT_EQ(1, t.float_val_size());
+ EXPECT_EQ(0, t.float_val(0));
+ EXPECT_EQ(2, t.tensor_shape().dim_size());
+ EXPECT_EQ(1, t.tensor_shape().dim(0).size());
+ EXPECT_EQ(2, t.tensor_shape().dim(1).size());
+ } else if (name == "add1") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("x", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ } else if (name == "add2") {
+ EXPECT_EQ("Identity", node.op());
+ EXPECT_EQ("y", node.input(0));
+ EXPECT_EQ("^zeros", node.input(1));
+ }
+ }
+ }
+}
+
+TEST_F(ConstantFoldingTest, CreateConstNodes) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+
+#define MAKE_TEST_GRAPH(TYPE) \
+ Output TYPE##_const = \
+ ops::Const(s.WithOpName(#TYPE "_const"), static_cast<TYPE>(10), {5}); \
+ Output TYPE##_mul = \
+ ops::Mul(s.WithOpName(#TYPE "_mul"), TYPE##_const, TYPE##_const); \
+ Output TYPE##_id = ops::Identity(s.WithOpName(#TYPE "_id"), TYPE##_mul)
+
+ MAKE_TEST_GRAPH(float);
+ MAKE_TEST_GRAPH(double);
+ MAKE_TEST_GRAPH(int64);
+ MAKE_TEST_GRAPH(int32);
+ MAKE_TEST_GRAPH(int16);
+ MAKE_TEST_GRAPH(int8);
+ MAKE_TEST_GRAPH(uint8);
+#undef MAKE_TEST_GRAPH
+
+ Output bool_const = ops::Const(s.WithOpName("bool_const"), true, {5});
+ Output bool_and =
+ ops::LogicalAnd(s.WithOpName("bool_and"), bool_const, bool_const);
+ Output bool_id = ops::Identity(s.WithOpName("bool_id"), bool_and);
+
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ ConstantFolding fold(nullptr /* cpu_device */);
+ GraphDef output;
+ Status status = fold.Optimize(nullptr, item, &output);
+ TF_EXPECT_OK(status);
+
+ EXPECT_EQ(24, output.node_size());
+ for (const NodeDef& node : output.node()) {
+#define CHECK_RESULT(TYPE, FIELD) \
+ if (node.name() == #TYPE "_mul") { \
+ EXPECT_EQ(5, \
+ node.attr().at("value").tensor().tensor_shape().dim(0).size()); \
+ EXPECT_EQ(1, node.attr().at("value").tensor().FIELD##_val_size()); \
+ EXPECT_EQ(10 * 10, node.attr().at("value").tensor().FIELD##_val(0)); \
+ }
+
+ CHECK_RESULT(float, float);
+ CHECK_RESULT(double, double);
+ CHECK_RESULT(int64, int64);
+ CHECK_RESULT(int32, int);
+ CHECK_RESULT(int16, int);
+ CHECK_RESULT(int8, int);
+ CHECK_RESULT(uint8, int);
+#undef CHECK_RESULT
+
+ if (node.name() == "bool_and") {
+ EXPECT_EQ(5,
+ node.attr().at("value").tensor().tensor_shape().dim(0).size());
+ EXPECT_EQ(1, node.attr().at("value").tensor().bool_val_size());
+ EXPECT_EQ(true && true, node.attr().at("value").tensor().bool_val(0));
+ }
+ }
+}
+
TEST_F(ConstantFoldingTest, FoldingNodeWithTwoOutputs) {
// Build a simple graph with a few trivially prunable ops.
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
- Output a = ops::Const(s.WithOpName("a"), 10, {3});
+ Output a = ops::Const(s.WithOpName("a"), 10, {5});
auto b = ops::Unique(s.WithOpName("b"), {a});
Output c = ops::Identity(s.WithOpName("c"), {b.y});
Output d = ops::Identity(s.WithOpName("d"), {b.idx});
@@ -963,3 +1118,5 @@ TEST_F(ConstantFoldingTest, MaterializeReductionIndices) {
} // namespace
} // namespace grappler
} // namespace tensorflow
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer.cc b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
index d5563e9d4c..e9436638f0 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer.cc
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
+#include <deque>
#include <unordered_set>
#include "tensorflow/core/framework/attr_value.pb.h"
@@ -69,6 +70,8 @@ std::set<string> GetOpsFormatSupported() {
return ops_format_supported;
}
+// TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+// because of worse performance in some cases.
std::set<string> GetOpsFormatAgnostic() {
std::set<string> ops_format_agnostic = {"Add",
"AddN",
@@ -88,7 +91,7 @@ std::set<string> GetOpsFormatAgnostic() {
"Split",
"SquaredDifference",
"Squeeze",
- "Sub"};
+ /*"Sum",*/ "Sub"};
return ops_format_agnostic;
}
@@ -186,33 +189,6 @@ class GraphProcessor {
return node;
}
- NodeDef* AddNodeReductionConst(const string& name, const string& device) {
- NodeDef* node = graph_->add_node();
- node_map_->AddNode(name, node);
- node->set_name(name);
- node->set_op("Const");
- AttrValue attr_data_type;
- attr_data_type.set_type(DT_INT32);
- node->mutable_attr()->insert({"dtype", attr_data_type});
-
- AttrValue attr_tensor;
- Tensor tensor(DT_INT32, TensorShape({3}));
- std::vector<int> axis = {0, 2, 3};
- for (int i = 0; static_cast<size_t>(i) < axis.size(); i++) {
- tensor.flat<int>()(i) = axis[i];
- }
- tensor.AsProtoTensorContent(attr_tensor.mutable_tensor());
- node->mutable_attr()->insert({"value", attr_tensor});
- string device_name;
- if (device.empty()) {
- device_name = virtual_placer_.get_canonical_device_name(*node);
- } else {
- device_name = device;
- }
- node->set_device(device_name);
- return node;
- }
-
const VirtualPlacer& virtual_placer_;
const std::unordered_set<string>& nodes_to_preserve_;
GraphDef* graph_;
@@ -370,10 +346,20 @@ class NodeProcessor : public GraphProcessor {
LOG(ERROR) << "Failed to parse TensorProto.";
}
if (tensor.dims() == 1) {
- int c = tensor.flat<int>()(3);
- tensor.flat<int>()(3) = tensor.flat<int>()(2);
- tensor.flat<int>()(2) = tensor.flat<int>()(1);
- tensor.flat<int>()(1) = c;
+ if (tensor.flat<int>().size() == 4) {
+ int c = tensor.flat<int>()(3);
+ tensor.flat<int>()(3) = tensor.flat<int>()(2);
+ tensor.flat<int>()(2) = tensor.flat<int>()(1);
+ tensor.flat<int>()(1) = c;
+ } else if (tensor.flat<int>().size() == 3) {
+ tensor.flat<int>()(0) = 0;
+ tensor.flat<int>()(1) = 2;
+ tensor.flat<int>()(2) = 3;
+ } else {
+ return Status(error::INVALID_ARGUMENT,
+ strings::StrCat("Unsupported tensor size: ",
+ tensor.flat<int>().size()));
+ }
} else if (tensor.dims() == 2) {
for (int i = 0; i < 2; i++) {
int c = tensor.matrix<int>()(3, i);
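A concrete example of the size-4 branch above (hypothetical attribute): a vector indexed as [N, H, W, C] is reordered to [N, C, H, W] by rotating its last three entries, so NHWC strides {1, 2, 2, 1} become NCHW strides {1, 1, 2, 2}:

    int v[4] = {1, 2, 2, 1};  // NHWC-ordered attribute, e.g. strides
    int c = v[3];             // save the channel entry
    v[3] = v[2];
    v[2] = v[1];
    v[1] = c;                 // v == {1, 1, 2, 2}, NCHW-ordered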
@@ -394,7 +380,9 @@ class NodeProcessor : public GraphProcessor {
Status UpdateAttrValueOfInput(int input_index) {
auto input_node = node_map_->GetNode(node_->input(input_index));
// We created a copy of the node, so that we don't modify the original node,
- // which might be used elsewhere.
+ // which might be used elsewhere. Note that this copy also copies the
+ // control dependency input in case this node is inside a loop,
+ // to ensure added_node is in the same frame as node_.
NodeDef* added_node = graph_->add_node();
*added_node = *input_node;
string base_name = strings::StrCat(node_->name(), "-", input_node->name());
@@ -411,6 +399,14 @@ class NodeProcessor : public GraphProcessor {
return input_pos;
}
+ virtual std::set<int> GetOutputPos() const {
+ // For most nodes, no need to process control nodes or nodes that use an
+ // output other than the first output: only the first output is of
+ // 4D NCHW/NHWC format and thus relevant here.
+ std::set<int> output_pos = {0};
+ return output_pos;
+ }
+
NodeDef* AddNodeTranspose(const string& node_name, const string& input_name,
const string& const_name, DataType data_type,
const TensorShapeProto& input_shape,
@@ -476,37 +472,28 @@ class NodeProcessor : public GraphProcessor {
auto outputs = node_map_->GetOutputs(node_->name());
string const_name = GetOrAddNodePermNCHWToNHWC();
for (const auto& output : outputs) {
- string base_name = strings::StrCat(node_->name(), "-", output->name());
- string node_name =
- AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
- // TODO(yaozhang): handle the rare case where node A is connected to more
- // than one input of node B.
- auto it = std::find_if(output->mutable_input()->begin(),
- output->mutable_input()->end(),
- [this](const string& input) {
- string node_name = NodeName(input);
- return node_name.compare(node_->name()) == 0;
- });
- if (it == output->mutable_input()->end()) {
- return Status(error::INVALID_ARGUMENT,
- strings::StrCat("Expect ", node_->name(),
- " to be an input of ", output->name()));
- }
- int output_pos = NodePosition(*it);
- // No need to process control nodes or nodes that use an output
- // other than the first output: only the first output is of 4D NCHW/NHWC
- // format and thus relevant here.
- if (output_pos != 0) {
- continue;
+ for (int i = 0; i < output->input_size(); i++) {
+ auto& input = *output->mutable_input(i);
+ int input_port;
+ string input_name = ParseNodeName(input, &input_port);
+ auto output_pos = GetOutputPos();
+ if (input_name == node_->name() &&
+ output_pos.find(input_port) != output_pos.end()) {
+ string base_name =
+ strings::StrCat(node_->name(), "-", output->name(), "-", i);
+ string node_name =
+ AddPrefixToNodeName(base_name, kTransposeNCHWToNHWC, "-");
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
+ TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
+ AddNodeTranspose(
+ node_name, input, const_name, node_->attr().at("T").type(),
+ node_->attr().at("_output_shapes").list().shape(0), false);
+ input = node_name;
+ node_map_->AddOutput(node_->name(), node_name);
+ node_map_->AddOutput(node_name, output->name());
+ }
}
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "T"));
- TF_RETURN_IF_ERROR(HasAttribute(*node_, "_output_shapes"));
- AddNodeTranspose(
- node_name, node_->name(), const_name, node_->attr().at("T").type(),
- node_->attr().at("_output_shapes").list().shape(0), false);
- *it = node_name;
- node_map_->UpdateOutput(node_->name(), output->name(), node_name);
- node_map_->AddOutput(node_name, output->name());
+ node_map_->RemoveOutput(node_->name(), output->name());
}
return Status::OK();
}
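The rewiring above matches every consumer input against the producer's name and output port. A sketch of the name:port parsing that ParseNodeName performs, written against plain std::string; the conventions assumed here (no colon means port 0, a leading '^' marks a control input reported as port -1) are assumptions of this sketch, not the grappler API:

#include <string>

std::string ParseInputString(const std::string& input, int* port) {
  if (!input.empty() && input[0] == '^') {
    *port = -1;  // control dependency, e.g. "^anchor"
    return input.substr(1);
  }
  const auto colon = input.rfind(':');
  if (colon == std::string::npos) {
    *port = 0;  // "node" means output port 0
    return input;
  }
  *port = std::stoi(input.substr(colon + 1));  // "split:1" -> ("split", 1)
  return input.substr(0, colon);
}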
@@ -775,24 +762,52 @@ class AgnosticNodeProcessor : public NodeProcessor {
bool IsNodeAfterNCHWToNHWC() const {
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
- auto node = node_map_->GetNode(node_->name());
- while (node->input_size() > 0) {
- int data_input_pos = 0;
- if (IsConcatV1(*node) || IsSplit(*node)) {
- data_input_pos = 1;
- }
- node = node_map_->GetNode(node->input(data_input_pos));
- if (IsNodeNCHWToNHWC(node->name())) {
+ std::deque<NodeDef*> queue;
+ auto first_node_pos = DataInputPos(*node_);
+ for (const auto& pos : first_node_pos) {
+ auto input_node = node_map_->GetNode(node_->input(pos));
+ queue.push_back(input_node);
+ }
+ // The code will exit this while loop in one iteration in most cases, as the
+ // graph is already topologically sorted.
+ while (!queue.empty()) {
+ NodeDef* current_node = queue.front();
+ queue.pop_front();
+ if (IsNodeNCHWToNHWC(current_node->name())) {
return true;
}
- bool connected =
- ops_format_agnostic.find(node->op()) != ops_format_agnostic.end();
- if (!connected) {
- return false;
+ // We only continue searching if the path is connected through
+ // format-agnostic nodes.
+ if (ops_format_agnostic.find(current_node->op()) !=
+ ops_format_agnostic.end()) {
+ auto current_node_pos = DataInputPos(*current_node);
+ for (const auto& pos : current_node_pos) {
+ auto input_node = node_map_->GetNode(current_node->input(pos));
+ queue.push_back(input_node);
+ }
}
}
return false;
}
+
+ private:
+ std::vector<int> DataInputPos(const NodeDef& node) const {
+ std::vector<int> pos;
+ if (IsSplit(node)) {
+ return {1};
+ }
+ if (IsConcatV1(node)) {
+ return {1};
+ }
+ if (IsAdd(node) || IsMul(node) || IsRealDiv(node) ||
+ IsSquaredDifference(node) || IsSub(node)) {
+ return {0, 1};
+ }
+ if (node.input_size() > 0 && !IsControlInput(node.input(0))) {
+ return {0};
+ }
+ return {};
+ }
};
class AddNProcessor : public AgnosticNodeProcessor {
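IsNodeAfterNCHWToNHWC above now does a breadth-first search backward through data inputs, stopping at the first NCHW-to-NHWC transpose and pruning at ops that are not format-agnostic. A self-contained sketch over a toy adjacency map; the types are stand-ins rather than the grappler NodeMap API, and a visited set is added here as a safety guard:

#include <deque>
#include <map>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
  std::string op;
  std::vector<std::string> data_inputs;  // control inputs already filtered out
};

// Backward BFS: true if some path of format-agnostic ops reaches an
// NCHW-to-NHWC transpose.
bool AfterNCHWToNHWC(const std::map<std::string, ToyNode>& graph,
                     const std::string& start,
                     const std::set<std::string>& agnostic_ops) {
  std::deque<std::string> queue(graph.at(start).data_inputs.begin(),
                                graph.at(start).data_inputs.end());
  std::set<std::string> visited;
  while (!queue.empty()) {
    const std::string name = queue.front();
    queue.pop_front();
    if (!visited.insert(name).second) continue;
    const ToyNode& node = graph.at(name);
    if (node.op == "TransposeNCHWToNHWC") return true;
    if (agnostic_ops.count(node.op)) {  // only search through agnostic ops
      for (const auto& in : node.data_inputs) queue.push_back(in);
    }
  }
  return false;
}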
@@ -815,42 +830,49 @@ class BinaryOpProcessor : public AgnosticNodeProcessor {
public:
explicit BinaryOpProcessor(const OptimizeContext& opt_cxt)
: AgnosticNodeProcessor(opt_cxt) {
- is_4d_with_vector_ = Is4DOperateWithVector();
+ is_4d_with_vector_ = IsNDOperateWithMD(4, 1);
}
protected:
bool ShouldProcess() const override {
+ // TODO(yaozhang): Support IsNDOperateWithMD(1, 4): first input is a vector
+ // and the second input is a 4D tensor; and update CustomizedProcessing()
+ // accordingly.
return !MustPreserve() && IsDimsFour(*node_) && HasOutputs() &&
IsNodeAfterNCHWToNHWC() &&
- (Is4DOperateWithND(4) || Is4DOperateWithScalar() ||
- Is4DOperateWithVector()) &&
+ (IsNDOperateWithMD(4, 0) || IsNDOperateWithMD(4, 1) ||
+ IsNDOperateWithMD(4, 4) || IsNDOperateWithMD(0, 4)) &&
IsOnGPU();
}
std::vector<int> GetInputPos() const override {
- std::vector<int> input_pos = {0};
- if (Is4DOperateWithND(4)) {
+ std::vector<int> input_pos;
+ auto input0 = node_map_->GetNode(node_->input(0));
+ auto input1 = node_map_->GetNode(node_->input(1));
+ if (IsDimsFour(*input0)) {
+ input_pos.push_back(0);
+ }
+ if (IsDimsFour(*input1)) {
input_pos.push_back(1);
}
return input_pos;
}
- bool Is4DOperateWithND(int n) const {
+ bool IsDimsFour(const NodeDef& node) const {
+ return NodeProcessor::IsDimsFour(node) || IsNodeNCHWToNHWC(node.name());
+ }
+
+ bool IsNDOperateWithMD(int n, int m) const {
auto input0 = node_map_->GetNode(node_->input(0));
auto input1 = node_map_->GetNode(node_->input(1));
if (input0 && input1) {
- return (IsDimsFour(*input0) || IsNodeNCHWToNHWC(input0->name())) &&
- ((n == 4)
- ? (IsDimsFour(*input1) || IsNodeNCHWToNHWC(input1->name()))
- : IsDimsN(*input1, n));
+ bool input0_is_n = (n == 4) ? IsDimsFour(*input0) : IsDimsN(*input0, n);
+ bool input1_is_m = (m == 4) ? IsDimsFour(*input1) : IsDimsN(*input1, m);
+ return input0_is_n && input1_is_m;
}
return false;
}
- bool Is4DOperateWithScalar() const { return Is4DOperateWithND(0); }
-
- bool Is4DOperateWithVector() const { return Is4DOperateWithND(1); }
-
NodeDef* AddNodeShapeConst(const string& name, int num_channels) {
NodeDef* node = graph_->add_node();
node_map_->AddNode(name, node);
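After this change BinaryOpProcessor is gated on input rank pairs rather than on a 4-D first input only. A toy version of the accepted pairs, assuming ranks are already known (rank 0 denotes a scalar); vector-first, i.e. (1, 4), is still excluded per the TODO above:

#include <utility>
#include <vector>

bool ShouldProcessRankPair(int rank0, int rank1) {
  static const std::vector<std::pair<int, int>> kAccepted = {
      {4, 0}, {4, 1}, {4, 4}, {0, 4}};
  for (const auto& p : kAccepted) {
    if (rank0 == p.first && rank1 == p.second) return true;
  }
  return false;
}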
@@ -948,7 +970,7 @@ class ConcatProcessor : public AgnosticNodeProcessor {
}
Status CustomizedProcessing() override {
- string concat_const_name = GetOrAddNodeConcatConst();
+ string concat_const_name = AddNodeConcatConst()->name();
node_map_->AddOutput(concat_const_name, node_->name());
*node_->mutable_input(axis_node_pos_) = concat_const_name;
return Status::OK();
@@ -956,8 +978,14 @@ class ConcatProcessor : public AgnosticNodeProcessor {
bool IsAlongDimC() const {
auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ if (!IsConstant(*axis_node)) {
+ return false;
+ }
if (axis_node->attr().find("value") != axis_node->attr().end()) {
- return axis_node->attr().at("value").tensor().int_val(0) == 3;
+ auto tensor = axis_node->attr().at({"value"}).tensor();
+ if (tensor.tensor_shape().dim_size() == 0 && tensor.int_val_size() == 1) {
+ return tensor.int_val(0) == 3;
+ }
}
return false;
}
@@ -965,28 +993,18 @@ class ConcatProcessor : public AgnosticNodeProcessor {
int axis_node_pos_;
private:
- NodeDef* AddNodeConcatConst(const string& suffix, const string& depended_node,
- const string& device) {
- auto const_node = AddNodeConstScalar(
- strings::StrCat(kConcatConst, "-", suffix), device, DT_INT32, 1);
- // This is to ensure the concat node and the const node are
- // in the same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeConcatConst() {
- string const_name;
- if (is_in_frame_) {
- int value_node_pos = (axis_node_pos_ == 0) ? 1 : 0;
- auto const_node = AddNodeConcatConst(
- node_->name(), NodeName(node_->input(value_node_pos)),
- node_->device());
- const_name = const_node->name();
- } else {
- const_name = kConcatConst;
- }
- return const_name;
+ NodeDef* AddNodeConcatConst() {
+ auto axis_node = node_map_->GetNode(node_->input(axis_node_pos_));
+ // We created a copy of the node, so that we don't modify the original node,
+ // which might be used elsewhere. Note that this copy also copies the
+ // control dependency input in the case this node is inside a loop,
+ // to ensure added_node is in the same frame as node_.
+ auto added_node = graph_->add_node();
+ *added_node = *axis_node;
+ added_node->set_name(strings::StrCat(kConcatConst, "-", node_->name()));
+ added_node->mutable_attr()->at({"value"}).mutable_tensor()->set_int_val(0,
+ 1);
+ return added_node;
}
};
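AddNodeConcatConst now clones the existing axis Const instead of synthesizing a fresh one, so any control input, and hence the execution frame, is inherited by the copy. A sketch with a minimal stand-in node type (not the real NodeDef proto):

#include <map>
#include <string>
#include <vector>

struct MiniNode {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;       // may contain "^anchor" control deps
  std::map<std::string, int> int_attrs;  // stand-in for the attr map
};

// Cloning the axis Const keeps its inputs, including any control dependency
// that pins it to a loop frame; only the name and payload change.
MiniNode CloneAxisConst(const MiniNode& axis, const std::string& consumer) {
  MiniNode copy = axis;
  copy.name = "LayoutOptimizerConcatConst-" + consumer;
  copy.int_attrs["value"] = 1;  // concat axis becomes channels in NCHW
  return copy;
}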
@@ -1036,6 +1054,16 @@ class SplitProcessor : public AgnosticNodeProcessor {
return input_pos;
}
+ std::set<int> GetOutputPos() const override {
+ std::set<int> output_pos{0};
+ if (HasAttribute(*node_, "num_split").ok()) {
+ for (int i = 1; i < node_->attr().at("num_split").i(); i++) {
+ output_pos.insert(i);
+ }
+ }
+ return output_pos;
+ }
+
Status CustomizedProcessing() override {
string split_const_name = AddNodeSplitConst()->name();
node_map_->AddOutput(split_const_name, node_->name());
@@ -1073,7 +1101,7 @@ class SplitProcessor : public AgnosticNodeProcessor {
// We created a copy of the node, so that we don't modify the original node,
// which might be used elsewhere. Note that this copy also copies the
// control dependency input in the case this node is inside a loop,
- // to ensure added_node is in the same frame with the Split node.
+ // to ensure added_node is in the same frame as node_.
NodeDef* added_node = graph_->add_node();
*added_node = *dim_node;
added_node->set_name(strings::StrCat(kSplitConst, "-", node_->name()));
@@ -1329,20 +1357,21 @@ class SumProcessor : public AgnosticNodeProcessor {
Status AddLayoutTransposeToOutputs() override { return Status::OK(); }
- Status CustomizedProcessing() override {
- node_map_->AddOutput(kReductionConst, node_->name());
- *node_->mutable_input(1) = GetOrAddNodeReductionConst();
- return Status::OK();
- }
+ Status CustomizedProcessing() override { return UpdateAttrValueOfInput(1); }
private:
bool IsAlongDimNHW() const {
- NodeDef* node = node_map_->GetNode(node_->input(1));
+ NodeDef* reduction_indices = node_map_->GetNode(node_->input(1));
+ if (!IsConstant(*reduction_indices)) {
+ return false;
+ }
Tensor tensor;
- if (node->attr().find({"value"}) == node->attr().end()) {
+ if (reduction_indices->attr().find({"value"}) ==
+ reduction_indices->attr().end()) {
return false;
}
- auto success = tensor.FromProto(node->attr().at({"value"}).tensor());
+ auto success =
+ tensor.FromProto(reduction_indices->attr().at({"value"}).tensor());
if (!success) {
LOG(ERROR) << "Failed to parse TensorProto.";
return false;
@@ -1356,29 +1385,6 @@ class SumProcessor : public AgnosticNodeProcessor {
}
return false;
}
-
- NodeDef* AddNodeReductionConst(const string& suffix,
- const string& depended_node,
- const string& device) {
- auto const_node = GraphProcessor::AddNodeReductionConst(
- strings::StrCat(kReductionConst, "-", suffix), device);
- // This is to ensure the Sum node and the const node are in the
- // same frame.
- *const_node->add_input() = AsControlDependency(depended_node);
- return const_node;
- }
-
- string GetOrAddNodeReductionConst() {
- string const_name;
- if (is_in_frame_) {
- auto const_node = AddNodeReductionConst(
- node_->name(), NodeName(node_->input(0)), node_->device());
- const_name = const_node->name();
- } else {
- const_name = kReductionConst;
- }
- return const_name;
- }
};
class DataLayoutOptimizer : GraphProcessor {
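SumProcessor above only rewrites when the reduction indices are exactly the N, H, W dimensions of an NHWC tensor, so that only the channel dimension survives the reduction. A sketch of that predicate over a plain vector of indices:

#include <vector>

// IsAlongDimNHW accepts only reduction indices {0, 1, 2}, i.e. the N, H, W
// dimensions of an NHWC tensor.
bool IsAlongDimNHW(const std::vector<int>& indices) {
  return indices == std::vector<int>{0, 1, 2};
}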
@@ -1409,18 +1415,10 @@ class DataLayoutOptimizer : GraphProcessor {
return AddNodePermConst(kPermNCHWToNHWC, "", {0, 2, 3, 1});
}
- NodeDef* AddNodeConcatConst() {
- return AddNodeConstScalar(kConcatConst, "", DT_INT32, 1);
- }
-
NodeDef* AddNodeGatherAxisConst() {
return AddNodeConstScalar(kGatherAxisConst, "", DT_INT32, 0);
}
- NodeDef* AddNodeReductionConst() {
- return GraphProcessor::AddNodeReductionConst(kReductionConst, "");
- }
-
// Expand all nodes which are in NHWC but support NCHW or are layout agnostic.
Status Expand() {
int node_size_original = graph_->node_size();
@@ -1474,9 +1472,7 @@ class DataLayoutOptimizer : GraphProcessor {
if (graph_->node_size() > node_size_original) {
NodeDef* n = AddNodePermNHWCToNCHW();
n = AddNodePermNCHWToNHWC();
- n = AddNodeConcatConst();
n = AddNodeGatherAxisConst();
- n = AddNodeReductionConst();
std::set<string> ops_format_agnostic = GetOpsFormatAgnostic();
for (int i = 0; i < graph_->node_size(); i++) {
if (ops_format_agnostic.find(graph_->node(i).op()) !=
@@ -1620,27 +1616,20 @@ Status LayoutOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
virtual_placer_.reset(new VirtualPlacer(cluster));
nodes_to_preserve_ = item.NodesToPreserve();
GraphProperties graph_properties(item);
- auto status = graph_properties.InferStatically();
+ auto status = graph_properties.InferStatically(false);
if (!status.ok()) {
*output = item.graph;
return status;
}
TuningConfig config;
- config.no_gemm = false;
+ config.no_gemm = true;
+ // TODO(yaozhang): Enable tuning with various TuningConfig choices with
+ // the measurement-based estimator.
status = Tune(item, graph_properties, config, output);
- // This is based on an empirical observation that if the introduced Transpose
- // nodes is more than 30, not using GEMM implementation would result in better
- // performance.
- if (status.ok() && GetNumTranspose(*output) > 30) {
- config.no_gemm = true;
- status = Tune(item, graph_properties, config, output);
- }
-
if (!status.ok()) {
*output = item.graph;
}
-
return status;
}
diff --git a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
index 8c89f6744b..363b4c3fd8 100644
--- a/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
+++ b/tensorflow/core/grappler/optimizers/layout_optimizer_test.cc
@@ -298,6 +298,39 @@ TEST_F(LayoutOptimizerTest, Connectivity) {
EXPECT_EQ(node_i2_output->input(0), "i1");
}
+TEST_F(LayoutOptimizerTest, ConnectivityBinaryOpWithInputScalarAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto i1 = ops::Identity(s.WithOpName("i1"), conv);
+ auto i2 = ops::Identity(s.WithOpName("i2"), i1);
+ auto scalar_sub = ops::Const(s.WithOpName("scalar_sub"), 3.0f, {});
+ auto sub = ops::Sub(s.WithOpName("sub"), scalar_sub, i2);
+ auto i3 = ops::Identity(s.WithOpName("i3"), sub);
+ auto i4 = ops::Identity(s.WithOpName("i4"), i3);
+ auto i5 = ops::Identity(s.WithOpName("i5"), i4);
+ auto scalar_mul = ops::Const(s.WithOpName("scalar_mul"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), scalar_mul, i5);
+ auto i6 = ops::Identity(s.WithOpName("i6"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ // Shuffle the graph out of topological order to test the handling of
+ // multi-hop connectivity (here we say two nodes are connected if all nodes
+ // in the middle are layout agnostic). If the graph is already in
+ // topological order, the problem is easier: the layout optimizer only
+ // needs to check single-hop connectivity.
+ NodeMap node_map_original(&item.graph);
+ auto node_i1 = node_map_original.GetNode("i1");
+ auto node_mul = node_map_original.GetNode("mul");
+ node_mul->Swap(node_i1);
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map_output(&output);
+ auto mul_node = node_map_output.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "scalar_mul");
+ EXPECT_EQ(mul_node->input(1), "i5");
+}
+
TEST_F(LayoutOptimizerTest, PreserveFetch) {
tensorflow::Scope s = tensorflow::Scope::NewRootScope();
auto conv = SimpleConv2D(&s, 3, 2, "VALID");
@@ -495,7 +528,175 @@ TEST_F(LayoutOptimizerTest, SplitNonConstDim) {
auto split_node = node_map.GetNode("split");
EXPECT_EQ(split_node->input(0), "i1");
EXPECT_EQ(split_node->input(1),
- "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split");
+ "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-split-1");
+}
+
+TEST_F(LayoutOptimizerTest, SplitSamePortToMultipleInputsOfSameNode) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat =
+ ops::Concat(s.WithOpName("concat"), {split[1], split[1], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split:1");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "split:1");
+ EXPECT_EQ(concat_node->input(3), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Concat) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto axis = ops::Const(s.WithOpName("axis"), 3);
+ auto split = ops::Split(s.WithOpName("split"), axis, conv, 2);
+ auto concat = ops::Concat(s.WithOpName("concat"), {split[0], split[1]}, axis);
+ auto o = ops::Identity(s.WithOpName("o"), concat);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto concat_node = node_map.GetNode("concat");
+ EXPECT_EQ(concat_node->input(0), "split");
+ EXPECT_EQ(concat_node->input(1), "split:1");
+ EXPECT_EQ(concat_node->input(2), "LayoutOptimizerConcatConst-concat");
+ auto concat_dim = node_map.GetNode("LayoutOptimizerConcatConst-concat");
+ EXPECT_EQ(concat_dim->attr().at({"value"}).tensor().int_val(0), 1);
+}
+
+TEST_F(LayoutOptimizerTest, Sum) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto reduction_indices =
+ ops::Const(s.WithOpName("reduction_indices"), {0, 1, 2}, {3});
+ auto sum = ops::Sum(s.WithOpName("sum"), conv, reduction_indices);
+ auto o = ops::Identity(s.WithOpName("o"), sum);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ // TODO(yaozhang): enable SumProcessor with auto-tuning. Currently disabled
+ // because it degrades performance in some cases.
+ /*
+ NodeMap node_map(&output);
+ auto sum_node = node_map.GetNode("sum");
+ EXPECT_EQ(sum_node->input(0), "Conv2D");
+ EXPECT_EQ(sum_node->input(1), "LayoutOptimizer-sum-reduction_indices");
+ auto sum_const = node_map.GetNode("LayoutOptimizer-sum-reduction_indices");
+ Tensor tensor;
+ EXPECT_TRUE(
+ tensor.FromProto(sum_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor tensor_expected(DT_INT32, {3});
+ test::FillValues<int>(&tensor_expected, {0, 2, 3});
+ test::ExpectTensorEqual<int>(tensor_expected, tensor);
+ */
+}
+
+TEST_F(LayoutOptimizerTest, MulScalarAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), scalar, conv);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "scalar");
+ EXPECT_EQ(mul_node->input(1), "Conv2D");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAndScalar) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto scalar = ops::Const(s.WithOpName("scalar"), 3.0f, {});
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, scalar);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "scalar");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto i = ops::Identity(s.WithOpName("i"), conv);
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, i);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "i");
+}
+
+TEST_F(LayoutOptimizerTest, Mul4DAndVector) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
+ auto mul = ops::Mul(s.WithOpName("mul"), conv, vector);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ EXPECT_EQ(mul_node->input(0), "Conv2D");
+ EXPECT_EQ(mul_node->input(1), "LayoutOptimizerReshapeNHWCToNCHW-mul-vector");
+ auto mul_const = node_map.GetNode("LayoutOptimizerReshapeConst-mul-vector");
+ Tensor tensor;
+ EXPECT_TRUE(
+ tensor.FromProto(mul_const->mutable_attr()->at({"value"}).tensor()));
+ Tensor tensor_expected(DT_INT32, {4});
+ test::FillValues<int>(&tensor_expected, {1, 2, 1, 1});
+ test::ExpectTensorEqual<int>(tensor_expected, tensor);
+}
+
+TEST_F(LayoutOptimizerTest, MulVectorAnd4D) {
+ tensorflow::Scope s = tensorflow::Scope::NewRootScope();
+ auto conv = SimpleConv2D(&s, 3, 2, "VALID");
+ auto vector = ops::Const(s.WithOpName("vector"), {3.0f, 7.0f}, {2});
+ auto mul = ops::Mul(s.WithOpName("mul"), vector, conv);
+ auto o = ops::Identity(s.WithOpName("o"), mul);
+ GrapplerItem item;
+ TF_CHECK_OK(s.ToGraphDef(&item.graph));
+ LayoutOptimizer optimizer;
+ GraphDef output;
+ Status status = optimizer.Optimize(virtual_cluster_.get(), item, &output);
+ NodeMap node_map(&output);
+ auto mul_node = node_map.GetNode("mul");
+ // TODO(yaozhang): Support vector as the first input and 4d tensor as the
+ // second input for BinaryOpProcessor.
+ EXPECT_EQ(mul_node->input(0), "vector");
+ EXPECT_EQ(mul_node->input(1),
+ "LayoutOptimizerTransposeNCHWToNHWC-Conv2D-mul-1");
}
} // namespace
diff --git a/tensorflow/core/grappler/optimizers/memory_optimizer.cc b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
index 7c44ce15c6..a2a2680c4f 100644
--- a/tensorflow/core/grappler/optimizers/memory_optimizer.cc
+++ b/tensorflow/core/grappler/optimizers/memory_optimizer.cc
@@ -716,7 +716,7 @@ Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
{
// Estimate the size of the data to swap for each node.
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
for (auto& swap : nodes_to_swap) {
const NodeDef* node = swap.first;
std::vector<OpInfo::TensorProperties> props =
diff --git a/tensorflow/core/grappler/optimizers/static_schedule.cc b/tensorflow/core/grappler/optimizers/static_schedule.cc
index 6ce6deef2c..450e853407 100644
--- a/tensorflow/core/grappler/optimizers/static_schedule.cc
+++ b/tensorflow/core/grappler/optimizers/static_schedule.cc
@@ -86,7 +86,7 @@ Status EstimateEarliestExecutionTimes(
name_map.clear();
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
@@ -154,7 +154,7 @@ Status EstimateRequiredTimes(
}
}
GraphProperties properties(item);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(true));
OpLevelCostEstimator estimator;
VirtualPlacer placer(cluster);
diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD
index 21411097e8..dcffb28513 100644
--- a/tensorflow/core/kernels/BUILD
+++ b/tensorflow/core/kernels/BUILD
@@ -3923,7 +3923,11 @@ tf_kernel_library(
"scatter_nd_op.h",
"scatter_nd_op_gpu.cu.cc",
],
- deps = STATE_DEPS + [":dense_update_functor"],
+ deps = STATE_DEPS + [
+ ":dense_update_functor",
+ ":training_op_helpers",
+ ":variable_ops",
+ ],
)
tf_kernel_library(
@@ -5833,11 +5837,11 @@ cc_library(
srcs = ["dataset.cc"],
hdrs = ["dataset.h"],
deps = [
+ "//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
- "//tensorflow/core/util/tensor_bundle",
],
)
@@ -6125,6 +6129,18 @@ tf_kernel_library(
)
tf_kernel_library(
+ name = "random_dataset_op",
+ srcs = ["random_dataset_op.cc"],
+ deps = [
+ ":dataset",
+ "//tensorflow/core:dataset_ops_op_lib",
+ "//tensorflow/core:framework",
+ "//tensorflow/core:lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+tf_kernel_library(
name = "range_dataset_op",
srcs = ["range_dataset_op.cc"],
deps = [
@@ -6291,6 +6307,7 @@ tf_kernel_library(
":parallel_interleave_dataset_op",
":parallel_map_dataset_op",
":prefetch_dataset_op",
+ ":random_dataset_op",
":range_dataset_op",
":reader_dataset_ops",
":repeat_dataset_op",
diff --git a/tensorflow/core/kernels/conv_grad_filter_ops.cc b/tensorflow/core/kernels/conv_grad_filter_ops.cc
index 3d2bb57aff..1791c51096 100644
--- a/tensorflow/core/kernels/conv_grad_filter_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_filter_ops.cc
@@ -194,7 +194,23 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current Eigen and libxsmm implementations do not "
+ "yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
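The constructors in this file repeat the same dilation checks. A condensed sketch of the rule they enforce on these CPU paths, returning an error message instead of using OP_REQUIRES; the helper is hypothetical and assumes NHWC attribute order (batch, rows, cols, depth):

#include <string>
#include <vector>

std::string ValidateDilationsNHWC(const std::vector<int>& d) {
  if (d.size() != 4) return "dilations must specify 4 dimensions";
  if (d[0] != 1 || d[3] != 1)
    return "dilations in the batch and depth dimensions are not supported";
  if (d[1] != 1 || d[2] != 1)  // TODO in this change: dilated conv on CPU
    return "dilation rates larger than 1 are not yet supported on CPU";
  return std::string();  // OK
}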
@@ -262,6 +278,7 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -290,7 +307,23 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current libxsmm and customized CPU implementations do "
+ "not yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -459,6 +492,7 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -510,10 +544,30 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ int stride_h = GetTensorDim(strides_, data_format_, 'H');
+ int stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
@@ -546,13 +600,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
// do not support striding on the batch or depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, input,
- stride_rows, stride_cols, padding_, filter_backprop,
- data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ filter_backprop, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
bool use_cudnn_;
@@ -566,38 +623,46 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
template <typename T>
void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& out_backprop, const Tensor& input, int row_stride,
- int col_stride, const Padding& padding, Tensor* filter_backprop,
- TensorFormat data_format) {
+ const Tensor& out_backprop, const Tensor& input, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* filter_backprop, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
+ std::vector<int32> dilations(4, 1);
+ dilations[GetTensorDimIndex(data_format, 'H')] = row_dilation;
+ dilations[GetTensorDimIndex(data_format, 'W')] = col_dilation;
+
std::vector<int32> strides(4, 1);
strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
TensorShape filter_shape = filter_backprop->shape();
ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
"Conv2DSlowBackpropFilter", /*num_spatial_dims=*/2,
input.shape(), filter_shape, out_backprop.shape(),
- strides, padding, data_format, &dims));
+ dilations, strides, padding, data_format, &dims));
+ // TODO(yangzihao): The padding computations should be done in
+ // GetWindowedOutputSize() functions.
const int padding_rows =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
+ (dims.spatial_dims[0].filter_size - 1) *
+ dims.spatial_dims[0].dilation +
+ 1 - dims.spatial_dims[0].input_size);
const int padding_cols =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
+ (dims.spatial_dims[1].filter_size - 1) *
+ dims.spatial_dims[1].dilation +
+ 1 - dims.spatial_dims[1].input_size);
// TODO(zhengxq): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -730,7 +795,9 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
+ .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
+ .set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -821,6 +888,8 @@ void LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T>::operator()(
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_grad_input_ops.cc b/tensorflow/core/kernels/conv_grad_input_ops.cc
index d28f6b4d10..736241a029 100644
--- a/tensorflow/core/kernels/conv_grad_input_ops.cc
+++ b/tensorflow/core/kernels/conv_grad_input_ops.cc
@@ -198,7 +198,23 @@ class Conv2DFastBackpropInputOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current Eigen and libxsmm implementations do not "
+ "yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -268,6 +284,7 @@ class Conv2DFastBackpropInputOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -296,7 +313,23 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, strides_[1] > 0 && strides_[2] > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, (dilations_[0] == 1 && dilations_[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ // TODO(yangzihao): Add a CPU implementation for dilated convolution.
+ OP_REQUIRES(context, (dilations_[1] == 1 && dilations_[2] == 1),
+ errors::InvalidArgument(
+ "Current libxsmm and customized CPU implementations do "
+ "not yet support dilation rates larger than 1."));
}
void Compute(OpKernelContext* context) override {
@@ -532,6 +565,7 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
@@ -586,10 +620,30 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
"specify 4 dimensions"));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
+ int stride_h = GetTensorDim(strides_, data_format_, 'H');
+ int stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
+ int dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ int dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ int dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ int dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, (dilation_n == 1 && dilation_c == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
@@ -622,12 +676,16 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
// do not support striding on the batch or depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
launcher_(context, use_cudnn_, cudnn_use_autotune_, out_backprop, filter,
- stride_rows, stride_cols, padding_, in_backprop, data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ in_backprop, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
Padding padding_;
bool use_cudnn_;
@@ -641,39 +699,48 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
template <typename T>
void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& out_backprop, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* in_backprop,
- TensorFormat data_format) {
+ const Tensor& out_backprop, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* in_backprop, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
std::vector<int32> strides(4, 1);
- strides[GetTensorDimIndex(data_format, 'H')] = row_stride;
- strides[GetTensorDimIndex(data_format, 'W')] = col_stride;
+ std::vector<int32> dilations(4, 1);
+ auto input_h = GetTensorDimIndex(data_format, 'H');
+ auto input_w = GetTensorDimIndex(data_format, 'W');
+ strides[input_h] = row_stride;
+ strides[input_w] = col_stride;
+ dilations[input_h] = row_dilation;
+ dilations[input_w] = col_dilation;
TensorShape input_shape = in_backprop->shape();
const TensorShape& filter_shape = filter.shape();
ConvBackpropDimensions dims;
- OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensions(
+ OP_REQUIRES_OK(ctx, ConvBackpropComputeDimensionsV2(
"Conv2DSlowBackpropInput", /*num_spatial_dims=*/2,
input_shape, filter_shape, out_backprop.shape(),
- strides, padding, data_format, &dims));
+ dilations, strides, padding, data_format, &dims));
+ // TODO(yangzihao): The padding computations should be done in
+ // GetWindowedOutputSize() functions.
const int padding_rows =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[0].output_size - 1) *
dims.spatial_dims[0].stride +
- dims.spatial_dims[0].filter_size -
- dims.spatial_dims[0].input_size);
+ (dims.spatial_dims[0].filter_size - 1) *
+ dims.spatial_dims[0].dilation +
+ 1 - dims.spatial_dims[0].input_size);
const int padding_cols =
(padding == VALID)
? 0
: std::max<int>(0, (dims.spatial_dims[1].output_size - 1) *
dims.spatial_dims[1].stride +
- dims.spatial_dims[1].filter_size -
- dims.spatial_dims[1].input_size);
+ (dims.spatial_dims[1].filter_size - 1) *
+ dims.spatial_dims[1].dilation +
+ 1 - dims.spatial_dims[1].input_size);
// TODO(keveman): cuDNN only supports equal padding on both sides, so only
// calling it when that is true. Remove this check when (if?) cuDNN starts
@@ -789,7 +856,9 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
.set_input_feature_map_count(dims.in_depth)
.set_output_feature_map_count(dims.out_depth);
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(dims.spatial_dims[0].stride)
+ conv_desc.set_vertical_dilation_rate(dims.spatial_dims[0].dilation)
+ .set_horizontal_dilation_rate(dims.spatial_dims[1].dilation)
+ .set_vertical_filter_stride(dims.spatial_dims[0].stride)
.set_horizontal_filter_stride(dims.spatial_dims[1].stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -875,6 +944,8 @@ void LaunchConv2DBackpropInputOp<GPUDevice, T>::operator()(
dims.out_depth, // out_depths
{{dims.spatial_dims[0].filter_size, // filter_rows
dims.spatial_dims[1].filter_size}}, // filter_cols
+ {{dims.spatial_dims[0].dilation, // dilation_rows
+ dims.spatial_dims[1].dilation}}, // dilation_cols
{{dims.spatial_dims[0].stride, // stride_rows
dims.spatial_dims[1].stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_grad_ops.h b/tensorflow/core/kernels/conv_grad_ops.h
index e068fb8684..535586d53a 100644
--- a/tensorflow/core/kernels/conv_grad_ops.h
+++ b/tensorflow/core/kernels/conv_grad_ops.h
@@ -175,15 +175,17 @@ template <typename Device, typename T>
struct LaunchConv2DBackpropInputOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& filter,
- int row_stride, int col_stride, const Padding& padding,
- Tensor* in_backprop, TensorFormat data_format);
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding, Tensor* in_backprop,
+ TensorFormat data_format);
};
template <typename Device, typename T>
struct LaunchConv2DBackpropFilterOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};
@@ -191,8 +193,9 @@ struct LaunchConv2DBackpropFilterOp {
template <typename T>
struct LaunchConv2DBackpropInputOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
@@ -200,7 +203,8 @@ template <typename T>
struct LaunchConv2DBackpropFilterOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
const Tensor& out_backprop, const Tensor& input,
- int row_stride, int col_stride, const Padding& padding,
+ int row_dilation, int col_dilation, int row_stride,
+ int col_stride, const Padding& padding,
Tensor* filter_backprop, TensorFormat data_format);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_grad_ops_3d.cc b/tensorflow/core/kernels/conv_grad_ops_3d.cc
index c2d24d1f12..4d0f1ab317 100644
--- a/tensorflow/core/kernels/conv_grad_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_grad_ops_3d.cc
@@ -645,6 +645,9 @@ class Conv3DBackpropInputOp<GPUDevice, T> : public OpKernel {
{{input_size[0], input_size[1], input_size[2]}},
out_depth,
{{filter_size[0], filter_size[1], filter_size[2]}},
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ /*dilations=*/{{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
@@ -1011,6 +1014,7 @@ class Conv3DBackpropFilterOp<GPUDevice, T> : public OpKernel {
{{input_size[0], input_size[1], input_size[2]}},
out_depth,
{{filter_size[0], filter_size[1], filter_size[2]}},
+ {{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{padding_planes, padding_rows, padding_cols}},
dtype,
diff --git a/tensorflow/core/kernels/conv_ops.cc b/tensorflow/core/kernels/conv_ops.cc
index bb67113fb0..ba40c428e4 100644
--- a/tensorflow/core/kernels/conv_ops.cc
+++ b/tensorflow/core/kernels/conv_ops.cc
@@ -112,7 +112,8 @@ struct LaunchGeneric {
template <typename T>
struct LaunchConv2DOp<CPUDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
+ const Tensor& input, const Tensor& filter,
+ int /*row_dilation*/, int /*col_dilation*/, int row_stride,
int col_stride, const Padding& padding, Tensor* output,
TensorFormat data_format) {
if (data_format != FORMAT_NHWC) {
@@ -133,8 +134,10 @@ class LaunchDeepConvOp {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
- Tensor* output, TensorFormat data_format) {
+ int /*out_cols*/, int /*out_depth*/, int /*dilation_rows*/,
+ int /*dilation_cols*/, int /*stride_rows*/,
+ int /*stride_cols*/, Tensor* /*output*/,
+ TensorFormat /*data_format*/) {
return false;
}
};
@@ -147,9 +150,11 @@ class LaunchDeepConvOp<CPUDevice, float> {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
+ int out_cols, int out_depth, int dilation_rows,
+ int dilation_cols, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
- if (data_format != FORMAT_NHWC ||
+ if (data_format != FORMAT_NHWC || dilation_rows != 1 ||
+ dilation_cols != 1 ||
!CanUseDeepConv2D(stride_rows, stride_cols, filter_rows, filter_cols,
in_depth, out_depth, out_rows, out_cols)) {
return false;
@@ -187,7 +192,8 @@ class LaunchXsmmConvOp {
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
- Tensor* output, TensorFormat data_format) {
+ int out_cols, int out_depth, int dilation_rows, int dilation_cols,
+ int stride_rows, int stride_cols, Tensor* output,
+ TensorFormat data_format) {
return false;
}
};
@@ -199,7 +205,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
const Tensor& filter, int batch, int input_rows,
int input_cols, int in_depth, int filter_rows,
int filter_cols, int pad_rows, int pad_cols, int out_rows,
- int out_cols, int out_depth, int stride_rows, int stride_cols,
+ int out_cols, int out_depth, int dilation_rows,
+ int dilation_cols, int stride_rows, int stride_cols,
Tensor* output, TensorFormat data_format) {
auto num_threads =
ctx->device()->tensorflow_cpu_worker_threads()->num_threads;
@@ -228,11 +235,8 @@ class LaunchXsmmConvOp<CPUDevice, float> {
desc.options = LIBXSMM_DNN_CONV_OPTION_WU_EXT_FILTER_REDUCE_OVERWRITE;
desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
- if (!CanUseXsmmConv2D(desc, data_format)) {
- return false;
- }
-
- if (!CanUseXsmmConv2D(desc, data_format)) {
+ if (dilation_rows != 1 || dilation_cols != 1 ||
+ !CanUseXsmmConv2D(desc, data_format)) {
return false;
}
@@ -251,6 +255,7 @@ template <typename Device, typename T>
class Conv2DOp : public BinaryOp<T> {
public:
explicit Conv2DOp(OpKernelConstruction* context) : BinaryOp<T>(context) {
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations_));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
@@ -259,15 +264,35 @@ class Conv2DOp : public BinaryOp<T> {
OP_REQUIRES_OK(context, context->GetAttr("use_cudnn_on_gpu", &use_cudnn_));
use_cudnn_ &= CanUseCudnn();
cudnn_use_autotune_ = CudnnUseAutotune();
+ OP_REQUIRES(context, dilations_.size() == 4,
+ errors::InvalidArgument("Sliding window dilations field must "
+ "specify 4 dimensions"));
OP_REQUIRES(context, strides_.size() == 4,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
+ const int64 stride_h = GetTensorDim(strides_, data_format_, 'H');
+ const int64 stride_w = GetTensorDim(strides_, data_format_, 'W');
OP_REQUIRES(
context, stride_n == 1 && stride_c == 1,
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ OP_REQUIRES(context, stride_h > 0 && stride_w > 0,
+ errors::InvalidArgument(
+ "Row and column strides should be larger than 0."));
+
+ const int64 dilation_n = GetTensorDim(dilations_, data_format_, 'N');
+ const int64 dilation_c = GetTensorDim(dilations_, data_format_, 'C');
+ const int64 dilation_h = GetTensorDim(dilations_, data_format_, 'H');
+ const int64 dilation_w = GetTensorDim(dilations_, data_format_, 'W');
+ OP_REQUIRES(context, dilation_n == 1 && dilation_c == 1,
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
+ OP_REQUIRES(
+ context, dilation_h > 0 && dilation_w > 0,
+ errors::InvalidArgument("Dilated rates should be larger than 0."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
@@ -334,18 +359,22 @@ class Conv2DOp : public BinaryOp<T> {
errors::InvalidArgument("batch is too large"));
const int batch = static_cast<int>(batch_raw);
- // For now we take the stride from the second and third dimensions only (we
- // do not support striding on the batch or depth dimension).
+ // For now we take the stride and dilation from the second and third
+ // dimensions only (we do not support striding or dilation on the batch or
+ // depth dimension).
const int stride_rows = GetTensorDim(strides_, data_format_, 'H');
const int stride_cols = GetTensorDim(strides_, data_format_, 'W');
+ const int dilation_rows = GetTensorDim(dilations_, data_format_, 'H');
+ const int dilation_cols = GetTensorDim(dilations_, data_format_, 'W');
+
int64 out_rows = 0, out_cols = 0, pad_rows = 0, pad_cols = 0;
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_rows, filter_rows, stride_rows,
- padding_, &out_rows, &pad_rows));
- OP_REQUIRES_OK(context,
- GetWindowedOutputSize(input_cols, filter_cols, stride_cols,
- padding_, &out_cols, &pad_cols));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
+ input_rows, filter_rows, dilation_rows,
+ stride_rows, padding_, &out_rows, &pad_rows));
+ OP_REQUIRES_OK(context, GetWindowedOutputSizeV2(
+ input_cols, filter_cols, dilation_cols,
+ stride_cols, padding_, &out_cols, &pad_cols));
TensorShape out_shape =
ShapeFromFormat(data_format_, batch, out_rows, out_cols, out_depth);
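GetWindowedOutputSizeV2 above folds dilation into the window arithmetic via the effective filter extent (filter - 1) * dilation + 1. A sketch of the output-size and total-padding computation under the semantics suggested by the padding comments elsewhere in this change (an assumption of this sketch, not the library implementation):

#include <algorithm>
#include <cstdint>

void WindowedOutputSizeSketch(int64_t in, int64_t filter, int64_t dilation,
                              int64_t stride, bool same_padding, int64_t* out,
                              int64_t* pad_total) {
  const int64_t effective = (filter - 1) * dilation + 1;
  if (same_padding) {
    *out = (in + stride - 1) / stride;  // ceil(in / stride)
    *pad_total = std::max<int64_t>(0, (*out - 1) * stride + effective - in);
  } else {  // VALID
    *out = (in - effective + stride) / stride;
    *pad_total = 0;
  }
}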
@@ -361,6 +390,8 @@ class Conv2DOp : public BinaryOp<T> {
<< ", filter_rows = " << filter_rows
<< ", stride_rows = " << stride_rows
<< ", stride_cols = " << stride_cols
+ << ", dilation_rows = " << dilation_rows
+ << ", dilation_cols = " << dilation_cols
<< ", out_depth = " << out_depth;
// If there is nothing to compute, return.
@@ -372,7 +403,8 @@ class Conv2DOp : public BinaryOp<T> {
if (LaunchXsmmConvOp<Device, T>::Run(
context, input, filter, batch, input_rows, input_cols, in_depth,
filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, stride_rows, stride_cols, output, data_format_)) {
+ out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
+ output, data_format_)) {
return;
}
#endif
@@ -380,15 +412,18 @@ class Conv2DOp : public BinaryOp<T> {
if (LaunchDeepConvOp<Device, T>::Run(
context, input, filter, batch, input_rows, input_cols, in_depth,
filter_rows, filter_cols, pad_rows, pad_cols, out_rows, out_cols,
- out_depth, stride_rows, stride_cols, output, data_format_)) {
+ out_depth, dilation_rows, dilation_cols, stride_rows, stride_cols,
+ output, data_format_)) {
return;
}
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_rows, stride_cols, padding_, output, data_format_);
+ dilation_rows, dilation_cols, stride_rows, stride_cols, padding_,
+ output, data_format_);
}
private:
+ std::vector<int32> dilations_;
std::vector<int32> strides_;
bool use_cudnn_;
Padding padding_;
@@ -443,9 +478,9 @@ typedef AutoTuneSingleton<ConvAutoTuneGroup, ConvParameters,
template <typename T>
void LaunchConv2DOp<GPUDevice, T>::operator()(
OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input_param, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
- TensorFormat data_format) {
+ const Tensor& input_param, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride, const Padding& padding,
+ Tensor* output, TensorFormat data_format) {
using perftools::gputools::dnn::AlgorithmConfig;
using perftools::gputools::dnn::AlgorithmDesc;
using perftools::gputools::dnn::ProfileResult;
@@ -461,8 +496,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
Tensor input = input_param;
- if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_stride == 1 &&
- col_stride == 1 && data_format == FORMAT_NHWC) {
+ if (filter.dim_size(0) == 1 && filter.dim_size(1) == 1 && row_dilation == 1 &&
+ col_dilation == 1 && row_stride == 1 && col_stride == 1 &&
+ data_format == FORMAT_NHWC) {
// 1x1 filter, so call cublas directly.
const uint64 m = input.dim_size(0) * input.dim_size(1) * input.dim_size(2);
const uint64 k = filter.dim_size(2);
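
The 1x1 shortcut above works because, with unit stride and unit dilation, every
output pixel depends on exactly one input pixel, so the convolution collapses
to a single matrix multiply. A sketch of the dimension bookkeeping, with
illustrative sizes:

    #include <cstdint>
    #include <cstdio>

    int main() {
      const uint64_t n = 8, h = 32, w = 32, in_depth = 64, out_depth = 128;
      const uint64_t m = n * h * w;  // one GEMM row per output pixel (NHWC)
      const uint64_t k = in_depth;   // shared inner dimension
      std::printf("gemm: [%llu x %llu] * [%llu x %llu]\n",
                  (unsigned long long)m, (unsigned long long)k,
                  (unsigned long long)k, (unsigned long long)out_depth);
      return 0;
    }
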
@@ -487,7 +523,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
}
return;
} else if (filter.dim_size(0) == input.dim_size(1) &&
- filter.dim_size(1) == input.dim_size(2) && padding == VALID &&
+ filter.dim_size(1) == input.dim_size(2) && row_dilation == 1 &&
+ col_dilation == 1 && padding == VALID &&
data_format == FORMAT_NHWC) {
// The input data and filter have the same height/width, so call cublas
// directly.
@@ -530,17 +567,19 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
const int64 patch_cols = filter.dim_size(1);
if (padding == SAME) {
// Total padding on rows and cols is
- // Pr = (R' - 1) * S + Kr - R
- // Pc = (C' - 1) * S + Kc - C
+ // Pr = (R' - 1) * S + (Kr - 1) * Dr + 1 - R
+ // Pc = (C' - 1) * S + (Kc - 1) * Dc + 1 - C
// where (R', C') are output dimensions, (R, C) are input dimensions, S
- // is stride, (Kr, Kc) are filter dimensions.
+ // is stride, (Dr, Dc) are dilations, (Kr, Kc) are filter dimensions.
    // We pad Pr/2 on the top and Pr - Pr/2 on the bottom, Pc/2 on the left
    // and Pc - Pc/2 on the right. When Pr or Pc is odd, this means
    // we pad more on the bottom and right than on the top and left.
padding_rows =
- std::max<int>(0, (out_rows - 1) * row_stride + patch_rows - in_rows);
+ std::max<int>(0, (out_rows - 1) * row_stride +
+ (patch_rows - 1) * row_dilation + 1 - in_rows);
padding_cols =
- std::max<int>(0, (out_cols - 1) * col_stride + patch_cols - in_cols);
+ std::max<int>(0, (out_cols - 1) * col_stride +
+ (patch_cols - 1) * col_dilation + 1 - in_cols);
const bool rows_odd = (padding_rows % 2 != 0);
const bool cols_odd = (padding_cols % 2 != 0);
if (rows_odd || cols_odd) {
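
Worked numerically, the updated formula pads one extra row at the bottom
whenever the total is odd; a small sketch with assumed sizes:

    #include <algorithm>
    #include <cstdio>

    int main() {
      const int in_rows = 10, patch_rows = 3, row_stride = 2, row_dilation = 2;
      const int out_rows = (in_rows + row_stride - 1) / row_stride;  // SAME: 5
      const int total = std::max(0, (out_rows - 1) * row_stride +
                                        (patch_rows - 1) * row_dilation + 1 -
                                        in_rows);
      // Pr = 4*2 + 2*2 + 1 - 10 = 3 -> pad 1 on top, 2 on the bottom.
      std::printf("total=%d top=%d bottom=%d\n", total, total / 2,
                  total - total / 2);
      return 0;
    }
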
@@ -605,7 +644,9 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
.set_input_feature_map_count(filter.dim_size(2))
.set_output_feature_map_count(filter.dim_size(3));
perftools::gputools::dnn::ConvolutionDescriptor conv_desc;
- conv_desc.set_vertical_filter_stride(row_stride)
+ conv_desc.set_vertical_dilation_rate(row_dilation)
+ .set_horizontal_dilation_rate(col_dilation)
+ .set_vertical_filter_stride(row_stride)
.set_horizontal_filter_stride(col_stride)
.set_zero_padding_height(padding_rows / 2)
.set_zero_padding_width(padding_cols / 2);
@@ -652,6 +693,8 @@ void LaunchConv2DOp<GPUDevice, T>::operator()(
out_depths, // out_depths
{{patch_rows, // filter_rows
patch_cols}}, // filter_cols
+ {{row_dilation, // dilation_rows
+ col_dilation}}, // dilation_cols
{{row_stride, // stride_rows
col_stride}}, // stride_cols
{{padding_rows, // padding_rows
diff --git a/tensorflow/core/kernels/conv_ops.h b/tensorflow/core/kernels/conv_ops.h
index e29271dff2..09a3b78776 100644
--- a/tensorflow/core/kernels/conv_ops.h
+++ b/tensorflow/core/kernels/conv_ops.h
@@ -34,8 +34,9 @@ class OpKernelContext;
template <typename Device, typename T>
struct LaunchConv2DOp {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
@@ -43,8 +44,9 @@ struct LaunchConv2DOp {
template <typename T>
struct LaunchConv2DOp<Eigen::GpuDevice, T> {
void operator()(OpKernelContext* ctx, bool use_cudnn, bool cudnn_use_autotune,
- const Tensor& input, const Tensor& filter, int row_stride,
- int col_stride, const Padding& padding, Tensor* output,
+ const Tensor& input, const Tensor& filter, int row_dilation,
+ int col_dilation, int row_stride, int col_stride,
+ const Padding& padding, Tensor* output,
TensorFormat data_format);
};
#endif // GOOGLE_CUDA
diff --git a/tensorflow/core/kernels/conv_ops_3d.cc b/tensorflow/core/kernels/conv_ops_3d.cc
index 37cb67bc51..39202d7334 100644
--- a/tensorflow/core/kernels/conv_ops_3d.cc
+++ b/tensorflow/core/kernels/conv_ops_3d.cc
@@ -377,6 +377,9 @@ struct LaunchConvOp<GPUDevice, T> {
{{in_planes, in_rows, in_cols}},
out_depth,
{{filter_planes, filter_rows, filter_cols}},
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
+ /*dilations=*/{{1, 1, 1}},
{{strides[0], strides[1], strides[2]}},
{{pad_planes, pad_rows, pad_cols}},
dtype,
diff --git a/tensorflow/core/kernels/conv_ops_gpu.h b/tensorflow/core/kernels/conv_ops_gpu.h
index c852dc9991..6f82698596 100644
--- a/tensorflow/core/kernels/conv_ops_gpu.h
+++ b/tensorflow/core/kernels/conv_ops_gpu.h
@@ -91,13 +91,14 @@ class ConvParameters {
using SpatialArray = gtl::InlinedVector<int64, 3>;
ConvParameters(int64 batch, int64 in_depths, const SpatialArray& in,
int64 out_depths, const SpatialArray& filter,
- const SpatialArray& stride, const SpatialArray& padding,
- DataType dtype, int device_id)
+ const SpatialArray& dilation, const SpatialArray& stride,
+ const SpatialArray& padding, DataType dtype, int device_id)
: batch_(batch),
in_depths_(in_depths),
out_depths_(out_depths),
in_(in),
filter_(filter),
+ dilation_(dilation),
stride_(stride),
padding_(padding),
dtype_(dtype),
@@ -107,6 +108,7 @@ class ConvParameters {
for (int64 val : in) hash_code_ = Hash64Combine(hash_code_, val);
hash_code_ = Hash64Combine(hash_code_, out_depths);
for (int64 val : filter) hash_code_ = Hash64Combine(hash_code_, val);
+ for (int64 val : dilation) hash_code_ = Hash64Combine(hash_code_, val);
for (int64 val : stride) hash_code_ = Hash64Combine(hash_code_, val);
for (int64 val : padding) hash_code_ = Hash64Combine(hash_code_, val);
hash_code_ = Hash64Combine(hash_code_, dtype);
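
Folding the dilation vector into the hash (and into the tuple below) matters
because this key identifies a cached cuDNN algorithm choice: two convolutions
differing only in dilation must not collide. An illustrative standalone
combiner in the same spirit (not TensorFlow's Hash64Combine):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    uint64_t Combine(uint64_t h, uint64_t v) {
      return h ^ (v + 0x9e3779b97f4a7c15ULL + (h << 6) + (h >> 2));
    }

    uint64_t Key(const std::vector<int64_t>& fields) {
      uint64_t h = 0;
      for (int64_t f : fields) h = Combine(h, static_cast<uint64_t>(f));
      return h;
    }

    int main() {
      // batch, depth, rows, cols, filter, dilation, stride
      uint64_t a = Key({32, 3, 224, 224, 3, /*dilation=*/1, 1});
      uint64_t b = Key({32, 3, 224, 224, 3, /*dilation=*/2, 1});
      std::printf("%d\n", a != b);  // 1: keys now diverge on dilation alone
      return 0;
    }
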
@@ -128,6 +130,7 @@ class ConvParameters {
"(", str_util::Join(in_, ", "), "), ",
out_depths_, ", ",
"(", str_util::Join(filter_, ", "), "), ",
+ "(", str_util::Join(dilation_, ", "), "), ",
"(", str_util::Join(stride_, ", "), "), ",
"(", str_util::Join(padding_, ", "), "), ",
dtype_, ", ",
@@ -154,11 +157,11 @@ class ConvParameters {
protected:
using ParameterDataType =
std::tuple<int64, int64, SpatialArray, int64, SpatialArray, SpatialArray,
- SpatialArray, DataType, int>;
+ SpatialArray, SpatialArray, DataType, int>;
ParameterDataType get_data_as_tuple() const {
return std::make_tuple(batch_, in_depths_, in_, out_depths_, filter_,
- stride_, padding_, dtype_, device_id_);
+ dilation_, stride_, padding_, dtype_, device_id_);
}
uint64 hash_code_;
@@ -169,6 +172,7 @@ class ConvParameters {
int64 out_depths_;
SpatialArray in_;
SpatialArray filter_;
+ SpatialArray dilation_;
SpatialArray stride_;
SpatialArray padding_;
DataType dtype_;
diff --git a/tensorflow/core/kernels/conv_ops_test.cc b/tensorflow/core/kernels/conv_ops_test.cc
index ea54d6cf6c..666bca265c 100644
--- a/tensorflow/core/kernels/conv_ops_test.cc
+++ b/tensorflow/core/kernels/conv_ops_test.cc
@@ -43,6 +43,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) {
128, // out_depths
{{3, // filter_rows
3}}, // filter_cols
+ {{1, // dilation_rows
+ 1}}, // dilation_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
@@ -60,6 +62,8 @@ TEST(ConvParameters, WinogradNonfusedAlgoSize) {
768, // out_depths
{{3, // filter_rows
3}}, // filter_cols
+ {{1, // dilation_rows
+ 1}}, // dilation_cols
{{1, // stride_rows
1}}, // stride_cols
{{0, // padding_rows
diff --git a/tensorflow/core/kernels/cwise_op_asinh.cc b/tensorflow/core/kernels/cwise_op_asinh.cc
index e6e1b83b30..0aec6aac34 100644
--- a/tensorflow/core/kernels/cwise_op_asinh.cc
+++ b/tensorflow/core/kernels/cwise_op_asinh.cc
@@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
- http://www.apache.org/licenses/LICENSE-2.0
+ http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
diff --git a/tensorflow/core/kernels/dataset.cc b/tensorflow/core/kernels/dataset.cc
index fcfa2956f7..0972129787 100644
--- a/tensorflow/core/kernels/dataset.cc
+++ b/tensorflow/core/kernels/dataset.cc
@@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/core/kernels/dataset.h"
+#include "tensorflow/core/graph/graph_def_builder.h"
+#include "tensorflow/core/graph/node_builder.h"
+
namespace tensorflow {
namespace {
@@ -70,6 +73,143 @@ class DatasetVariantWrapper {
} // namespace
+Status GraphDefBuilderWrapper::AddDataset(
+ const GraphDatasetBase* dataset,
+ const std::vector<std::pair<size_t, Node*>>& inputs,
+ const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
+ const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
+ Node** output) {
+ const string& op_type_name = dataset->op_name();
+ std::unique_ptr<const GraphDefBuilder::Options> opts(
+ new GraphDefBuilder::Options(b_->opts()));
+ // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
+  // attributes defined. It would be nice to have a consistent pattern.
+ bool has_output_types_attr = HasAttr(op_type_name, "output_types");
+ bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
+ if (has_output_shapes_attr) {
+ opts.reset(new GraphDefBuilder::Options(
+ opts->WithAttr("output_shapes", dataset->output_shapes())));
+ }
+ if (has_output_types_attr) {
+ opts.reset(new GraphDefBuilder::Options(
+ opts->WithAttr("output_types", dataset->output_dtypes())));
+ }
+ for (auto attr : attrs) {
+ opts.reset(
+ new GraphDefBuilder::Options(opts->WithAttr(attr.first, attr.second)));
+ }
+ if (opts->HaveError()) {
+ return errors::Internal("AddDataset: Failed to build Options with error ",
+ opts->StatusToString());
+ }
+ NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
+ opts->op_registry());
+ {
+ size_t total_size = inputs.size() + list_inputs.size();
+ auto inputs_iter = inputs.begin();
+ auto list_inputs_iter = list_inputs.begin();
+    for (size_t i = 0; i < total_size; i++) {
+ if (inputs_iter != inputs.end() && inputs_iter->first == i) {
+ node_builder.Input(NodeBuilder::NodeOut(inputs_iter->second));
+ inputs_iter++;
+ } else if (list_inputs_iter != list_inputs.end() &&
+ list_inputs_iter->first == i) {
+ std::vector<NodeBuilder::NodeOut> nodeout_inputs;
+ nodeout_inputs.reserve(list_inputs_iter->second.size());
+ for (Node* n : list_inputs_iter->second) {
+ nodeout_inputs.emplace_back(n);
+ }
+ node_builder.Input(nodeout_inputs);
+ list_inputs_iter++;
+ } else {
+ return errors::InvalidArgument("No input found for index ", i);
+ }
+ }
+ }
+ *output = opts->FinalizeBuilder(&node_builder);
+ if (*output == nullptr) {
+ return errors::Internal("AddDataset: Failed to build ", op_type_name,
+ " op with error ", opts->StatusToString());
+ }
+ return Status::OK();
+}
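
The loop above merges two position-tagged input lists into one argument
sequence; every index in [0, total) must be claimed by exactly one entry or the
builder fails. A hedged sketch of the same merge with hypothetical argument
names in place of Node pointers:

    #include <cstdio>
    #include <string>
    #include <utility>
    #include <vector>

    int main() {
      std::vector<std::pair<size_t, std::string>> inputs = {
          {0, "input_dataset"}, {2, "batch_size"}};
      std::vector<std::pair<size_t, std::vector<std::string>>> list_inputs = {
          {1, {"arg0", "arg1"}}};
      const size_t total = inputs.size() + list_inputs.size();
      auto it = inputs.begin();
      auto lit = list_inputs.begin();
      for (size_t i = 0; i < total; ++i) {
        if (it != inputs.end() && it->first == i) {
          std::printf("arg %zu: %s\n", i, (it++)->second.c_str());
        } else if (lit != list_inputs.end() && lit->first == i) {
          std::printf("arg %zu: list of %zu\n", i, lit->second.size());
          ++lit;
        } else {
          std::printf("arg %zu: missing -> InvalidArgument\n", i);
        }
      }
      return 0;
    }
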
+
+Status GraphDefBuilderWrapper::AddFunction(OpKernelContext* ctx,
+ const string& function_name) {
+ if (b_->HasFunction(function_name)) {
+    LOG(INFO) << "Function with name " << function_name
+              << " already exists in the graph. It will not be added again.";
+ return Status::OK();
+ }
+ TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
+ const FunctionLibraryDefinition* flib_def =
+ ctx->function_library()->GetFunctionLibraryDefinition();
+ const FunctionDef* f_def = flib_def->Find(function_name);
+ if (f_def == nullptr) {
+ return errors::InvalidArgument("Unable to find FunctionDef for ",
+ function_name, " in the registry.");
+ }
+ FunctionDefLibrary def;
+ *def.add_function() = *f_def;
+ const string gradient_func = flib_def->FindGradient(function_name);
+ if (!gradient_func.empty()) {
+ GradientDef* g_def = def.add_gradient();
+ g_def->set_function_name(function_name);
+ g_def->set_gradient_func(gradient_func);
+ }
+ TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
+
+ // Recursively add functions in inputs of function_name.
+ for (const NodeDef& node_def : f_def->node_def()) {
+ const OpRegistrationData* op_reg_data = nullptr;
+ TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
+ if (op_reg_data->is_function_op) {
+ TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
+ }
+ // Recursively add functions in attrs of this NodeDef.
+ for (const auto& pair : node_def.attr()) {
+ TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
+ }
+ }
+
+ // Recursively add functions in attrs of function_name.
+ for (auto iter = f_def->attr().begin(); iter != f_def->attr().end(); iter++) {
+ TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
+ }
+ return Status::OK();
+}
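
The effect of the recursion above is a transitive closure over the function
library: every referenced function, whether it appears as an op or inside an
attr, is added exactly once. A simplified sketch with a hypothetical
dependency table standing in for the FunctionLibraryDefinition:

    #include <cstdio>
    #include <map>
    #include <set>
    #include <string>
    #include <vector>

    std::map<std::string, std::vector<std::string>> kDeps = {
        {"outer", {"inner_a", "inner_b"}}, {"inner_a", {"inner_b"}},
        {"inner_b", {}}};

    void AddFunction(const std::string& name, std::set<std::string>* added) {
      if (!added->insert(name).second) return;  // already in the graph: skip
      for (const std::string& dep : kDeps[name]) AddFunction(dep, added);
    }

    int main() {
      std::set<std::string> added;
      AddFunction("outer", &added);
      std::printf("%zu functions added\n", added.size());  // 3, each once
      return 0;
    }
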
+
+void GraphDefBuilderWrapper::AddTensorInternal(const Tensor& val,
+ Node** output) {
+ *output = ops::SourceOp(
+ "Const",
+ b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
+}
+
+bool GraphDefBuilderWrapper::HasAttr(const string& op_type_name,
+ const string& attr_name) const {
+ const OpDef* op_def = nullptr;
+ Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
+ if (!s.ok() || op_def == nullptr) {
+ return false;
+ }
+ return HasAttr(op_def, attr_name);
+}
+
+Status GraphDatasetBase::Serialize(OpKernelContext* ctx,
+ string* serialized_graph_def,
+ string* output_node) const {
+ GraphDefBuilder b;
+ DatasetGraphDefBuilder db(&b);
+ Node* node = nullptr;
+ TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
+ *output_node = node->name();
+ GraphDef graph_def;
+ TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
+ graph_def.SerializeToString(serialized_graph_def);
+ return Status::OK();
+}
+
Status GetDatasetFromVariantTensor(const Tensor& tensor,
DatasetBase** out_dataset) {
if (!(tensor.dtype() == DT_VARIANT ||
diff --git a/tensorflow/core/kernels/dataset.h b/tensorflow/core/kernels/dataset.h
index afbebb0692..504a88a309 100644
--- a/tensorflow/core/kernels/dataset.h
+++ b/tensorflow/core/kernels/dataset.h
@@ -19,12 +19,13 @@ limitations under the License.
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/attr_value_util.h"
+#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/variant_encode_decode.h"
#include "tensorflow/core/framework/variant_tensor_data.h"
-#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/tracing.h"
@@ -59,6 +60,12 @@ class IteratorStateWriter {
virtual ~IteratorStateWriter() {}
};
+// Forward declarations to avoid introducing a dependency on headers in
+// "tensorflow/core/graph/...".
+class GraphDefBuilder;
+class GraphDatasetBase;
+class Node;
+
// Wrapper around GraphDefBuilder. Used to serialize Dataset graph.
class GraphDefBuilderWrapper {
public:
@@ -110,10 +117,8 @@ class GraphDefBuilderWrapper {
return Status::OK();
}
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
- Node** output) {
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs, Node** output) {
return AddDataset(dataset, inputs, {}, output);
}
@@ -125,77 +130,23 @@ class GraphDefBuilderWrapper {
// `*output` contains a pointer to the output `Node`. It is guaranteed to be
// non-null if the method returns with an OK status.
// The returned Node pointer is owned by the backing Graph of GraphDefBuilder.
- template <class DatasetType>
- Status AddDataset(const DatasetType* dataset,
- const std::vector<NodeBuilder::NodeOut>& inputs,
+ Status AddDataset(const GraphDatasetBase* dataset,
+ const std::vector<Node*>& inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
Node** output) {
- std::vector<std::pair<size_t, NodeBuilder::NodeOut>> enumerated_inputs(
- inputs.size());
+ std::vector<std::pair<size_t, Node*>> enumerated_inputs(inputs.size());
for (int i = 0; i < inputs.size(); i++) {
enumerated_inputs[i] = std::make_pair(i, inputs[i]);
}
return AddDataset(dataset, enumerated_inputs, {}, attrs, output);
}
- template <class DatasetType>
Status AddDataset(
- const DatasetType* dataset,
- const std::vector<std::pair<size_t, NodeBuilder::NodeOut>>& inputs,
- const std::vector<
- std::pair<size_t, gtl::ArraySlice<NodeBuilder::NodeOut>>>&
- list_inputs,
+ const GraphDatasetBase* dataset,
+ const std::vector<std::pair<size_t, Node*>>& inputs,
+ const std::vector<std::pair<size_t, gtl::ArraySlice<Node*>>>& list_inputs,
const std::vector<std::pair<StringPiece, AttrValue>>& attrs,
- Node** output) {
- const string& op_type_name = dataset->op_name();
- std::unique_ptr<const GraphDefBuilder::Options> opts(
- new GraphDefBuilder::Options(b_->opts()));
- // TODO(srbs|mrry): Not all datasets have output_types and output_shapes
- // attributes defined. It will be nice to have a consistent pattern.
- bool has_output_types_attr = HasAttr(op_type_name, "output_types");
- bool has_output_shapes_attr = HasAttr(op_type_name, "output_shapes");
- if (has_output_shapes_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_shapes", dataset->output_shapes())));
- }
- if (has_output_types_attr) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr("output_types", dataset->output_dtypes())));
- }
- for (auto attr : attrs) {
- opts.reset(new GraphDefBuilder::Options(
- opts->WithAttr(attr.first, attr.second)));
- }
- if (opts->HaveError()) {
- return errors::Internal("AddDataset: Failed to build Options with error ",
- opts->StatusToString());
- }
- NodeBuilder node_builder(opts->GetNameForOp(op_type_name), op_type_name,
- opts->op_registry());
- {
- size_t total_size = inputs.size() + list_inputs.size();
- auto inputs_iter = inputs.begin();
- auto list_inputs_iter = list_inputs.begin();
- for (int i = 0; i < total_size; i++) {
- if (inputs_iter != inputs.end() && inputs_iter->first == i) {
- node_builder.Input(inputs_iter->second);
- inputs_iter++;
- } else if (list_inputs_iter != list_inputs.end() &&
- list_inputs_iter->first == i) {
- node_builder.Input(list_inputs_iter->second);
- list_inputs_iter++;
- } else {
- return errors::InvalidArgument("No input found for index ", i);
- }
- }
- }
- *output = opts->FinalizeBuilder(&node_builder);
- if (*output == nullptr) {
- return errors::Internal("AddDataset: Failed to build ", op_type_name,
- " op with error ", opts->StatusToString());
- }
- return Status::OK();
- }
+ Node** output);
// Adds a user-defined function with name `function_name` to the graph and
// recursively adds all functions it references. If a function with a matching
@@ -203,50 +154,7 @@ class GraphDefBuilderWrapper {
// name `function_name` is not found in the FunctionLibraryDefinition, returns
// an InvalidArgumentError. If the function with name `function_name` or any
// of its dependent functions are stateful, returns an InvalidArgument error.
- Status AddFunction(OpKernelContext* ctx, const string& function_name) {
- if (b_->HasFunction(function_name)) {
- LOG(INFO) << "Function with name " << function_name << "already exists in"
- << " the graph. It will not be added again.";
- return Status::OK();
- }
- TF_RETURN_IF_ERROR(EnsureFunctionIsStateless(ctx, function_name));
- const FunctionLibraryDefinition* flib_def =
- ctx->function_library()->GetFunctionLibraryDefinition();
- const FunctionDef* f_def = flib_def->Find(function_name);
- if (f_def == nullptr) {
- return errors::InvalidArgument("Unable to find FunctionDef for ",
- function_name, " in the registry.");
- }
- FunctionDefLibrary def;
- *def.add_function() = *f_def;
- const string gradient_func = flib_def->FindGradient(function_name);
- if (!gradient_func.empty()) {
- GradientDef* g_def = def.add_gradient();
- g_def->set_function_name(function_name);
- g_def->set_gradient_func(gradient_func);
- }
- TF_RETURN_IF_ERROR(b_->AddFunctionLibrary(def));
-
- // Recursively add functions in inputs of function_name.
- for (const NodeDef& node_def : f_def->node_def()) {
- const OpRegistrationData* op_reg_data = nullptr;
- TF_RETURN_IF_ERROR(flib_def->LookUp(node_def.op(), &op_reg_data));
- if (op_reg_data->is_function_op) {
- TF_RETURN_IF_ERROR(AddFunction(ctx, op_reg_data->op_def.name()));
- }
- // Recursively add functions in attrs of this NodeDef.
- for (const auto& pair : node_def.attr()) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(pair.second, ctx));
- }
- }
-
- // Recursively add functions in attrs of function_name.
- for (auto iter = f_def->attr().begin(); iter != f_def->attr().end();
- iter++) {
- TF_RETURN_IF_ERROR(AddAttrFunctions(iter->second, ctx));
- }
- return Status::OK();
- }
+ Status AddFunction(OpKernelContext* ctx, const string& function_name);
template <typename T>
void BuildAttrValue(const T& value, AttrValue* attr) {
@@ -254,11 +162,7 @@ class GraphDefBuilderWrapper {
}
private:
- void AddTensorInternal(const Tensor& val, Node** output) {
- *output = ops::SourceOp(
- "Const",
- b_->opts().WithAttr("dtype", val.dtype()).WithAttr("value", val));
- }
+ void AddTensorInternal(const Tensor& val, Node** output);
Status EnsureFunctionIsStateless(OpKernelContext* ctx,
const string& function_name) const {
@@ -294,14 +198,7 @@ class GraphDefBuilderWrapper {
HasAttr(op_def, "output_shapes");
}
- bool HasAttr(const string& op_type_name, const string& attr_name) const {
- const OpDef* op_def = nullptr;
- Status s = b_->opts().op_registry()->LookUpOpDef(op_type_name, &op_def);
- if (!s.ok() || op_def == nullptr) {
- return false;
- }
- return HasAttr(op_def, attr_name);
- }
+ bool HasAttr(const string& op_type_name, const string& attr_name) const;
bool HasAttr(const OpDef* op_def, const string& attr_name) const {
for (auto attr : op_def->attr()) {
@@ -548,17 +445,7 @@ class GraphDatasetBase : public DatasetBase {
private:
Status Serialize(OpKernelContext* ctx, string* serialized_graph_def,
- string* output_node) const {
- GraphDefBuilder b;
- DatasetGraphDefBuilder db(&b);
- Node* node = nullptr;
- TF_RETURN_IF_ERROR(AsGraphDefInternal(ctx, &db, &node));
- *output_node = node->name();
- GraphDef graph_def;
- TF_RETURN_IF_ERROR(b.ToGraphDef(&graph_def));
- graph_def.SerializeToString(serialized_graph_def);
- return Status::OK();
- }
+ string* output_node) const;
const string op_name_;
};
diff --git a/tensorflow/core/kernels/dataset_utils.cc b/tensorflow/core/kernels/dataset_utils.cc
index cd58c80912..bd20e20cad 100644
--- a/tensorflow/core/kernels/dataset_utils.cc
+++ b/tensorflow/core/kernels/dataset_utils.cc
@@ -32,7 +32,7 @@ Status MakeIteratorFromInputElement(
// is always 0, so a negative random step ID should suffice.
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [captured_func, ctx](const string& name) {
+ opts.step_id, [captured_func](const string& name) {
captured_func->resource_manager()->Cleanup(name).IgnoreError();
});
opts.step_container = &step_container;
diff --git a/tensorflow/core/kernels/depthwise_conv_op.cc b/tensorflow/core/kernels/depthwise_conv_op.cc
index 2759ecb2f1..a5fd07fbe1 100644
--- a/tensorflow/core/kernels/depthwise_conv_op.cc
+++ b/tensorflow/core/kernels/depthwise_conv_op.cc
@@ -373,8 +373,11 @@ class DepthwiseConv2dNativeOp : public BinaryOp<T> {
// If in_depth==1, this operation is just a standard convolution, so
// invoke that op.
if (std::is_same<T, float>::value && in_depth == 1) {
+ // TODO(yangzihao): Send in arbitrary dilation rates after the dilated
+ // conv is supported.
launcher_(context, use_cudnn_, cudnn_use_autotune_, input, filter,
- stride_, stride_, padding_, output, data_format_);
+ /*row_dilation=*/1, /*col_dilation=*/1, stride_, stride_,
+ padding_, output, data_format_);
return;
}
diff --git a/tensorflow/core/kernels/filter_dataset_op.cc b/tensorflow/core/kernels/filter_dataset_op.cc
index e4d80e4ce3..67417d467d 100644
--- a/tensorflow/core/kernels/filter_dataset_op.cc
+++ b/tensorflow/core/kernels/filter_dataset_op.cc
@@ -95,7 +95,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
@@ -149,7 +149,7 @@ class FilterDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/flat_map_dataset_op.cc b/tensorflow/core/kernels/flat_map_dataset_op.cc
index ac1689e5bf..8fe8489371 100644
--- a/tensorflow/core/kernels/flat_map_dataset_op.cc
+++ b/tensorflow/core/kernels/flat_map_dataset_op.cc
@@ -102,7 +102,7 @@ class FlatMapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
diff --git a/tensorflow/core/kernels/group_by_window_dataset_op.cc b/tensorflow/core/kernels/group_by_window_dataset_op.cc
index 8644bcf9b5..604555a560 100644
--- a/tensorflow/core/kernels/group_by_window_dataset_op.cc
+++ b/tensorflow/core/kernels/group_by_window_dataset_op.cc
@@ -169,7 +169,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts.step_id = CapturedFunction::generate_step_id();
opts.runner = ctx->runner();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_key_func_->resource_manager()
->Cleanup(name)
@@ -198,7 +198,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts2.step_id = CapturedFunction::generate_step_id();
opts2.runner = ctx->runner();
ScopedStepContainer step_container2(
- opts2.step_id, [this, ctx](const string& name) {
+ opts2.step_id, [this](const string& name) {
dataset()
->captured_window_size_func_->resource_manager()
->Cleanup(name)
@@ -257,7 +257,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
opts.step_id = CapturedFunction::generate_step_id();
opts.runner = ctx->runner();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_reduce_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/interleave_dataset_op.cc b/tensorflow/core/kernels/interleave_dataset_op.cc
index cbee68b2db..833e8cb9c5 100644
--- a/tensorflow/core/kernels/interleave_dataset_op.cc
+++ b/tensorflow/core/kernels/interleave_dataset_op.cc
@@ -126,7 +126,7 @@ class InterleaveDatasetOp : public UnaryDatasetOpKernel {
TF_RETURN_IF_ERROR(b->AddScalar(block_length_, &block_length_node));
DataTypeVector other_arguments_types;
other_arguments_types.reserve(captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments;
+ std::vector<Node*> other_arguments;
other_arguments.reserve(captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
diff --git a/tensorflow/core/kernels/map_and_batch_dataset_op.cc b/tensorflow/core/kernels/map_and_batch_dataset_op.cc
index ad1e356dbd..9bd66e681f 100644
--- a/tensorflow/core/kernels/map_and_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/map_and_batch_dataset_op.cc
@@ -239,8 +239,8 @@ class MapAndBatchDatasetOp : public UnaryDatasetOpKernel {
// to unblock a consumer.
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [this, ctx](const string& name) {
+ ScopedStepContainer* step_container =
+ new ScopedStepContainer(opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/map_dataset_op.cc b/tensorflow/core/kernels/map_dataset_op.cc
index 4ba09bc335..29899a987e 100644
--- a/tensorflow/core/kernels/map_dataset_op.cc
+++ b/tensorflow/core/kernels/map_dataset_op.cc
@@ -100,7 +100,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
DataTypeVector other_arguments_types(
captured_func_->captured_inputs().size());
- std::vector<NodeBuilder::NodeOut> other_arguments(
+ std::vector<Node*> other_arguments(
captured_func_->captured_inputs().size());
for (const Tensor& t : captured_func_->captured_inputs()) {
Node* node;
@@ -146,7 +146,7 @@ class MapDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/mkl_batch_matmul_op.cc b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
index 138acdf298..9fee94f946 100644
--- a/tensorflow/core/kernels/mkl_batch_matmul_op.cc
+++ b/tensorflow/core/kernels/mkl_batch_matmul_op.cc
@@ -28,6 +28,7 @@ limitations under the License.
#if defined(INTEL_MKL)
#include <vector>
#include "mkl_cblas.h"
+#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
@@ -72,10 +73,10 @@ class BatchMatMulMkl : public OpKernel {
TensorShape out_shape;
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
- errors::InvalidArgument("lhs.dim(", i, ") and rhs.dim(", i,
- ") must be the same: ",
- lhs.shape().DebugString(), " vs ",
- rhs.shape().DebugString()));
+ errors::InvalidArgument(
+ "lhs.dim(", i, ") and rhs.dim(", i,
+ ") must be the same: ", lhs.shape().DebugString(), " vs ",
+ rhs.shape().DebugString()));
out_shape.AddDim(lhs.dim_size(i));
}
auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
@@ -109,7 +110,7 @@ class BatchMatMulMkl : public OpKernel {
const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);
-
+
std::vector<MKL_INT> m_array(batch_size, M);
std::vector<MKL_INT> n_array(batch_size, N);
std::vector<MKL_INT> k_array(batch_size, K);
@@ -128,7 +129,7 @@ class BatchMatMulMkl : public OpKernel {
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
-
+
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
&k_array[0], &a_array[0], &lda_array[0], &b_array[0],
&ldb_array[0], &c_array[0], &ldc_array[0], 1,
diff --git a/tensorflow/core/kernels/multinomial_op.cc b/tensorflow/core/kernels/multinomial_op.cc
index 8c0109f5c8..d086abb247 100644
--- a/tensorflow/core/kernels/multinomial_op.cc
+++ b/tensorflow/core/kernels/multinomial_op.cc
@@ -40,7 +40,7 @@ typedef Eigen::GpuDevice GPUDevice;
namespace functor {
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor {
void operator()(OpKernelContext* ctx, const Device& d,
typename TTypes<T>::ConstMatrix logits,
@@ -49,11 +49,11 @@ struct MultinomialFunctor {
typename TTypes<float>::Flat scratch, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output);
+ typename TTypes<OutputType>::Matrix output);
};
-template <typename T>
-struct MultinomialFunctor<CPUDevice, T> {
+template <typename T, typename OutputType>
+struct MultinomialFunctor<CPUDevice, T, OutputType> {
void operator()(OpKernelContext* ctx, const CPUDevice& d,
typename TTypes<T>::ConstMatrix logits,
typename TTypes<float>::Flat /* noises */,
@@ -61,7 +61,7 @@ struct MultinomialFunctor<CPUDevice, T> {
typename TTypes<float>::Flat /* scratch */, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output) {
+ typename TTypes<OutputType>::Matrix output) {
auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
// The implementation only parallelizes by batch.
@@ -128,7 +128,7 @@ struct MultinomialFunctor<CPUDevice, T> {
} // namespace functor
// Samples from a multinomial distribution.
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
public:
explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -195,11 +195,11 @@ class MultinomialOp : public OpKernel {
if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
auto rng =
generator_.ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
- functor::MultinomialFunctor<Device, T>()(
+ functor::MultinomialFunctor<Device, T, OutputType>()(
ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
batch_size, num_classes, num_samples, rng,
- samples_t->matrix<int64>());
+ samples_t->matrix<OutputType>());
}
}
@@ -209,10 +209,17 @@ class MultinomialOp : public OpKernel {
TF_DISALLOW_COPY_AND_ASSIGN(MultinomialOp);
};
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER( \
- Name("Multinomial").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
- MultinomialOp<CPUDevice, TYPE>);
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ MultinomialOp<CPUDevice, TYPE, int32>); \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_CPU) \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ MultinomialOp<CPUDevice, TYPE, int64>);
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
@@ -220,12 +227,20 @@ TF_CALL_double(REGISTER);
#undef REGISTER
#if GOOGLE_CUDA
-#define REGISTER(TYPE) \
- REGISTER_KERNEL_BUILDER(Name("Multinomial") \
- .Device(DEVICE_GPU) \
- .HostMemory("num_samples") \
- .TypeConstraint<TYPE>("T"), \
- MultinomialOp<GPUDevice, TYPE>)
+#define REGISTER(TYPE) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT32), \
+ MultinomialOp<GPUDevice, TYPE, int32>) \
+ REGISTER_KERNEL_BUILDER(Name("Multinomial") \
+ .Device(DEVICE_GPU) \
+ .HostMemory("num_samples") \
+ .TypeConstraint<TYPE>("T") \
+ .TypeConstraint("output_dtype", DT_INT64), \
+ MultinomialOp<GPUDevice, TYPE, int64>)
+
TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
diff --git a/tensorflow/core/kernels/multinomial_op.h b/tensorflow/core/kernels/multinomial_op.h
index af5e81f219..6e41060aa4 100644
--- a/tensorflow/core/kernels/multinomial_op.h
+++ b/tensorflow/core/kernels/multinomial_op.h
@@ -21,7 +21,7 @@ namespace tensorflow {
namespace functor {
// Generic helper functor for the Multinomial Op.
-template <typename Device, typename T>
+template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor;
} // namespace functor
diff --git a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
index 19b4f3ca55..5cc5877cce 100644
--- a/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/multinomial_op_gpu.cu.cc
@@ -37,20 +37,22 @@ using GPUDevice = Eigen::GpuDevice;
// Kernel for Multinomial op. Data is interpreted to have the following shapes:
// scores: [B, S, C]; maxima: [B, S]; output: [B, S].
+template <typename OutputType>
__global__ void MultinomialKernel(int32 nthreads, const int32 num_classes,
const int32 num_samples, const float* scores,
- const float* maxima, int64* output) {
+ const float* maxima, OutputType* output) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
const int maxima_idx = index / num_classes;
if (ldg(maxima + maxima_idx) == ldg(scores + index)) {
- CudaAtomicMax(reinterpret_cast<uint64*>(output + maxima_idx),
- static_cast<uint64>(index % num_classes));
+ using UnsignedOutputType = typename std::make_unsigned<OutputType>::type;
+ CudaAtomicMax(reinterpret_cast<UnsignedOutputType*>(output + maxima_idx),
+ static_cast<UnsignedOutputType>(index % num_classes));
}
}
}
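
The reinterpret_cast above is safe because the values written are class indices
in [0, num_classes), and for non-negative integers the signed and unsigned
orderings agree, so an unsigned atomic max computes the signed max. A
host-side sketch of that claim:

    #include <cstdint>
    #include <cstdio>
    #include <type_traits>

    template <typename OutputType>
    bool OrderPreserved(OutputType a, OutputType b) {
      using U = typename std::make_unsigned<OutputType>::type;
      return (a < b) == (static_cast<U>(a) < static_cast<U>(b));
    }

    int main() {
      // Holds for any non-negative pair, i.e. all valid class indices.
      std::printf("%d %d\n", OrderPreserved<int32_t>(3, 7),
                  OrderPreserved<int64_t>(0, 1000));  // 1 1
      return 0;
    }
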
-template <typename T>
-struct MultinomialFunctor<GPUDevice, T> {
+template <typename T, typename OutputType>
+struct MultinomialFunctor<GPUDevice, T, OutputType> {
void operator()(OpKernelContext* ctx, const GPUDevice& d,
typename TTypes<T>::ConstMatrix logits,
typename TTypes<float>::Flat noises,
@@ -58,7 +60,7 @@ struct MultinomialFunctor<GPUDevice, T> {
typename TTypes<float>::Flat maxima, int batch_size,
int num_classes, int num_samples,
const random::PhiloxRandom& gen,
- typename TTypes<int64>::Matrix output) {
+ typename TTypes<OutputType>::Matrix output) {
// Uniform, [0, 1).
typedef random::UniformDistribution<random::PhiloxRandom, float> Dist;
functor::FillPhiloxRandom<GPUDevice, Dist>()(ctx, d, gen, noises.data(),
@@ -111,11 +113,17 @@ struct MultinomialFunctor<GPUDevice, T> {
};
// Explicit instantiation of the GPU functors.
-template struct MultinomialFunctor<GPUDevice, Eigen::half>;
-template struct MultinomialFunctor<GPUDevice, float>;
-template struct MultinomialFunctor<GPUDevice, double>;
-template struct MultinomialFunctor<GPUDevice, int32>;
-template struct MultinomialFunctor<GPUDevice, int64>;
+template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
+template struct MultinomialFunctor<GPUDevice, float, int32>;
+template struct MultinomialFunctor<GPUDevice, double, int32>;
+template struct MultinomialFunctor<GPUDevice, int32, int32>;
+template struct MultinomialFunctor<GPUDevice, int64, int32>;
+
+template struct MultinomialFunctor<GPUDevice, Eigen::half, int64>;
+template struct MultinomialFunctor<GPUDevice, float, int64>;
+template struct MultinomialFunctor<GPUDevice, double, int64>;
+template struct MultinomialFunctor<GPUDevice, int32, int64>;
+template struct MultinomialFunctor<GPUDevice, int64, int64>;
} // namespace functor
} // namespace tensorflow
diff --git a/tensorflow/core/kernels/nn_ops_test.cc b/tensorflow/core/kernels/nn_ops_test.cc
index 0db7c63b8b..a841291ddd 100644
--- a/tensorflow/core/kernels/nn_ops_test.cc
+++ b/tensorflow/core/kernels/nn_ops_test.cc
@@ -653,6 +653,8 @@ BM_ConvFloatDepthwiseFwd(32, 7, 7, 1024, 1, 1024, 3, 3, 1, SAME, conv6);
// Benchmarks with different stride and padding options.
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, SAME, conv7);
BM_ConvFloatDepthwiseFwd(32, 112, 112, 3, 8, 24, 3, 3, 2, VALID, conv8);
+BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 3, 3, 1, SAME, conv9);
+BM_ConvFloatDepthwiseFwd(1, 100, 100, 72, 1, 72, 5, 5, 1, SAME, conv10);
#define BM_ConvFloatDepthwiseBk(BS, R, C, ID, DM, OD, KR, KC, STR, PAD, LABEL) \
static void BM_ConvFloatDepthwiseBkInCPU1_##LABEL(int iters) { \
diff --git a/tensorflow/core/kernels/padded_batch_dataset_op.cc b/tensorflow/core/kernels/padded_batch_dataset_op.cc
index 7c28d955e1..cef5bde156 100644
--- a/tensorflow/core/kernels/padded_batch_dataset_op.cc
+++ b/tensorflow/core/kernels/padded_batch_dataset_op.cc
@@ -242,7 +242,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
Node* batch_size = nullptr;
TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size));
- std::vector<NodeBuilder::NodeOut> padded_shapes;
+ std::vector<Node*> padded_shapes;
padded_shapes.reserve(padded_shapes_.size());
for (int i = 0; i < padded_shapes_.size(); i++) {
Node* node;
@@ -254,7 +254,7 @@ class PaddedBatchDatasetOp : public UnaryDatasetOpKernel {
padded_shapes.emplace_back(node);
}
- std::vector<NodeBuilder::NodeOut> padding_values;
+ std::vector<Node*> padding_values;
padding_values.reserve(padding_values_.size());
for (const Tensor& t : padding_values_) {
Node* node;
diff --git a/tensorflow/core/kernels/parallel_map_dataset_op.cc b/tensorflow/core/kernels/parallel_map_dataset_op.cc
index 2be87f4bde..b9175fe904 100644
--- a/tensorflow/core/kernels/parallel_map_dataset_op.cc
+++ b/tensorflow/core/kernels/parallel_map_dataset_op.cc
@@ -195,8 +195,8 @@ class ParallelMapDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
- ScopedStepContainer* step_container = new ScopedStepContainer(
- opts.step_id, [this, ctx](const string& name) {
+ ScopedStepContainer* step_container =
+ new ScopedStepContainer(opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/quantized_conv_ops.cc b/tensorflow/core/kernels/quantized_conv_ops.cc
index 3b0764bb9b..f83998e0c1 100644
--- a/tensorflow/core/kernels/quantized_conv_ops.cc
+++ b/tensorflow/core/kernels/quantized_conv_ops.cc
@@ -457,6 +457,19 @@ class QuantizedConv2DOp : public OpKernel {
context, (strides_[0] == 1 && strides_[3] == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
+ std::vector<int32> dilations;
+ OP_REQUIRES_OK(context, context->GetAttr("dilations", &dilations));
+ OP_REQUIRES(context, dilations.size() == 4,
+ errors::InvalidArgument("Dilations field must "
+ "specify 4 dimensions"));
+ OP_REQUIRES(context, dilations[1] == 1 && dilations[2] == 1,
+ errors::InvalidArgument(
+                    "Current implementation only supports a dilation rate "
+                    "of 1 in the row and column dimensions."));
+ OP_REQUIRES(context, (dilations[0] == 1 && dilations[3] == 1),
+ errors::InvalidArgument(
+ "Current implementation does not yet support "
+ "dilations in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
diff --git a/tensorflow/core/kernels/random_dataset_op.cc b/tensorflow/core/kernels/random_dataset_op.cc
new file mode 100644
index 0000000000..03d481a593
--- /dev/null
+++ b/tensorflow/core/kernels/random_dataset_op.cc
@@ -0,0 +1,154 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/core/framework/op_kernel.h"
+#include "tensorflow/core/framework/partial_tensor_shape.h"
+#include "tensorflow/core/framework/tensor.h"
+#include "tensorflow/core/kernels/dataset.h"
+#include "tensorflow/core/lib/random/philox_random.h"
+#include "tensorflow/core/lib/random/random.h"
+#include "tensorflow/core/lib/random/random_distributions.h"
+
+namespace tensorflow {
+
+namespace {
+
+// See documentation in ../ops/dataset_ops.cc for a high-level
+// description of the following op.
+
+class RandomDatasetOp : public DatasetOpKernel {
+ public:
+ explicit RandomDatasetOp(OpKernelConstruction* ctx) : DatasetOpKernel(ctx) {}
+
+ void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override {
+ int64 seed;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed", &seed));
+
+ int64 seed2;
+ OP_REQUIRES_OK(ctx, ParseScalarArgument<int64>(ctx, "seed2", &seed2));
+
+    // By TensorFlow convention, passing 0 for both seeds indicates
+    // that the sequence should be seeded non-deterministically.
+ if (seed == 0 && seed2 == 0) {
+ seed = random::New64();
+ seed2 = random::New64();
+ }
+
+ *output = new Dataset(ctx, seed, seed2);
+ }
+
+ private:
+ class Dataset : public GraphDatasetBase {
+ public:
+ Dataset(OpKernelContext* ctx, int64 seed, int64 seed2)
+ : GraphDatasetBase(ctx), seed_(seed), seed2_(seed2) {}
+
+ std::unique_ptr<IteratorBase> MakeIterator(
+ const string& prefix) const override {
+ return std::unique_ptr<IteratorBase>(
+ new Iterator({this, strings::StrCat(prefix, "::Random")}));
+ }
+
+ const DataTypeVector& output_dtypes() const override {
+ static DataTypeVector* dtypes = new DataTypeVector({DT_INT64});
+ return *dtypes;
+ }
+
+ const std::vector<PartialTensorShape>& output_shapes() const override {
+ static std::vector<PartialTensorShape>* shapes =
+ new std::vector<PartialTensorShape>({{}});
+ return *shapes;
+ }
+
+ string DebugString() override {
+ return strings::StrCat("RandomDatasetOp(", seed_, ", ", seed2_,
+ ")::Dataset");
+ }
+
+ protected:
+ Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* seed = nullptr;
+ Node* seed2 = nullptr;
+ TF_RETURN_IF_ERROR(b->AddScalar(seed_, &seed));
+ TF_RETURN_IF_ERROR(b->AddScalar(seed2_, &seed2));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {seed, seed2}, output));
+ return Status::OK();
+ }
+
+ private:
+ class Iterator : public DatasetIterator<Dataset> {
+ public:
+ explicit Iterator(const Params& params)
+ : DatasetIterator<Dataset>(params),
+ parent_generator_(dataset()->seed_, dataset()->seed2_),
+ generator_(&parent_generator_) {}
+
+ Status GetNextInternal(IteratorContext* ctx,
+ std::vector<Tensor>* out_tensors,
+ bool* end_of_sequence) override {
+ mutex_lock l(mu_);
+ Tensor value_tensor(cpu_allocator(), DT_INT64, {});
+ value_tensor.scalar<int64>()() = Random();
+ out_tensors->emplace_back(std::move(value_tensor));
+ *end_of_sequence = false;
+ return Status::OK();
+ }
+
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
+ num_random_samples_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(reader->ReadScalar(full_name("num_random_samples"),
+ &num_random_samples_));
+ parent_generator_ =
+ random::PhiloxRandom(dataset()->seed_, dataset()->seed2_);
+ generator_ = random::SingleSampleAdapter<random::PhiloxRandom>(
+ &parent_generator_);
+ generator_.Skip(num_random_samples_);
+ return Status::OK();
+ }
+
+ private:
+ random::SingleSampleAdapter<random::PhiloxRandom>::ResultType Random()
+ EXCLUSIVE_LOCKS_REQUIRED(mu_) {
+ num_random_samples_++;
+ auto out = generator_();
+ return out;
+ }
+ mutex mu_;
+ random::PhiloxRandom parent_generator_ GUARDED_BY(mu_);
+ random::SingleSampleAdapter<random::PhiloxRandom> generator_
+ GUARDED_BY(mu_);
+ int64 num_random_samples_ GUARDED_BY(mu_) = 0;
+ };
+
+ const int64 seed_;
+ const int64 seed2_;
+ };
+};
+
+REGISTER_KERNEL_BUILDER(Name("RandomDataset").Device(DEVICE_CPU),
+ RandomDatasetOp);
+
+} // namespace
+
+} // namespace tensorflow
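
The iterator in the new file above checkpoints only num_random_samples_;
restore reseeds Philox from the stored seeds and skips forward, so the resumed
stream is bit-identical to the original. A simplified stand-in generator
illustrating the scheme (not Philox):

    #include <cstdint>
    #include <cstdio>

    struct CountingRng {  // stands in for random::PhiloxRandom + adapter
      uint64_t seed;
      uint64_t n = 0;
      uint64_t Next() { return seed * 6364136223846793005ULL + ++n; }
      void Skip(uint64_t k) { n += k; }
    };

    int main() {
      CountingRng rng{42};
      rng.Next(); rng.Next(); rng.Next();
      const uint64_t num_random_samples = 3;  // what SaveInternal persists

      CountingRng restored{42};           // RestoreInternal: reseed...
      restored.Skip(num_random_samples);  // ...then skip consumed samples
      std::printf("%d\n", rng.Next() == restored.Next());  // 1: streams agree
      return 0;
    }
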
diff --git a/tensorflow/core/kernels/reduction_ops_min.cc b/tensorflow/core/kernels/reduction_ops_min.cc
index 807ac0a456..5c537c5b9c 100644
--- a/tensorflow/core/kernels/reduction_ops_min.cc
+++ b/tensorflow/core/kernels/reduction_ops_min.cc
@@ -50,6 +50,7 @@ TF_CALL_REAL_NUMBER_TYPES(REGISTER_CPU_KERNELS);
.TypeConstraint<int64>("Tidx") \
.HostMemory("reduction_indices"), \
ReductionOp<GPUDevice, type, int64, Eigen::internal::MinReducer<type>>);
+REGISTER_GPU_KERNELS(Eigen::half);
REGISTER_GPU_KERNELS(float);
REGISTER_GPU_KERNELS(double);
diff --git a/tensorflow/core/kernels/reduction_ops_test.cc b/tensorflow/core/kernels/reduction_ops_test.cc
index 9bbe993a2f..fe8ea59f1b 100644
--- a/tensorflow/core/kernels/reduction_ops_test.cc
+++ b/tensorflow/core/kernels/reduction_ops_test.cc
@@ -174,6 +174,11 @@ static void BM_Min2DToScalarGPU(int iters, int num_x, int num_y) {
}
BENCHMARK(BM_Min2DToScalarGPU)->RangePair(2048, 8192, 2048, 8192);
+static void BM_Min2DToScalarGPUHalf(int iters, int num_x, int num_y) {
+ ReduceToScalar<Eigen::half>(iters, "gpu", "Min", num_x, num_y);
+}
+BENCHMARK(BM_Min2DToScalarGPUHalf)->RangePair(2048, 8192, 2048, 8192);
+
static void BM_Bool2DToScalarGPU(int iters, int num_x, int num_y) {
ReduceToScalar<bool>(iters, "gpu", "All", num_x, num_y);
}
diff --git a/tensorflow/core/kernels/scan_dataset_op.cc b/tensorflow/core/kernels/scan_dataset_op.cc
index 76c219f1ae..bc52322022 100644
--- a/tensorflow/core/kernels/scan_dataset_op.cc
+++ b/tensorflow/core/kernels/scan_dataset_op.cc
@@ -132,7 +132,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
FunctionLibraryRuntime::Options opts;
opts.step_id = CapturedFunction::generate_step_id();
ScopedStepContainer step_container(
- opts.step_id, [this, ctx](const string& name) {
+ opts.step_id, [this](const string& name) {
dataset()
->captured_func_->resource_manager()
->Cleanup(name)
diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc
index 484932ab01..98c0181afb 100644
--- a/tensorflow/core/kernels/scatter_nd_op.cc
+++ b/tensorflow/core/kernels/scatter_nd_op.cc
@@ -21,6 +21,7 @@ limitations under the License.
#endif // GOOGLE_CUDA
#include "tensorflow/core/kernels/scatter_nd_op.h"
+
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
@@ -28,6 +29,8 @@ limitations under the License.
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/dense_update_functor.h"
#include "tensorflow/core/kernels/fill_functor.h"
+#include "tensorflow/core/kernels/training_op_helpers.h"
+#include "tensorflow/core/kernels/variable_ops.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
@@ -83,7 +86,10 @@ class ScatterNdUpdateOp : public OpKernel {
const DataType dt = DataTypeToEnum<T>::v();
const DataType dt_ref = DataTypeToEnum<T>::ref();
const DataType index_t = DataTypeToEnum<Index>::v();
- if (IsRefType(c->input_type(0))) {
+ dtype_ = c->input_type(0);
+ if (c->input_type(0) == DT_RESOURCE) {
+ // TODO(apassos): what to validate here?
+ } else if (IsRefType(c->input_type(0))) {
OP_REQUIRES_OK(c, c->MatchSignature({dt_ref, index_t, dt}, {dt_ref}));
OP_REQUIRES_OK(c, c->GetAttr("use_locking", &use_exclusive_lock_));
} else {
@@ -93,7 +99,16 @@ class ScatterNdUpdateOp : public OpKernel {
}
void Compute(OpKernelContext* c) override {
- if (use_exclusive_lock_) {
+ if (dtype_ == DT_RESOURCE) {
+ if (use_exclusive_lock_) {
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ mutex_lock m(*v->mu());
+ DoCompute(c);
+ } else {
+ DoCompute(c);
+ }
+ } else if (use_exclusive_lock_) {
// If we're here, it means the input type is a ref.
DCHECK(IsRefType(c->input_dtype(0)));
// Hold mutex while we apply updates
@@ -105,6 +120,7 @@ class ScatterNdUpdateOp : public OpKernel {
}
private:
+ DataType dtype_;
bool use_exclusive_lock_;
void DoCompute(OpKernelContext* c) {
@@ -113,7 +129,20 @@ class ScatterNdUpdateOp : public OpKernel {
Tensor params;
TensorShape params_shape;
- if (IsRefType(c->input_dtype(0))) {
+ if (dtype_ == DT_RESOURCE) {
+ Var* v;
+ OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v));
+ Tensor* t = v->tensor();
+ if (!use_exclusive_lock_) {
+        // We're not holding the lock in the outer scope, so we need it here.
+ mutex_lock m(*v->mu());
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
+ } else {
+ OP_REQUIRES_OK(c, PrepareToUpdateVariable<Device, T>(c, t));
+ }
+ params = *t;
+ params_shape = params.shape();
+ } else if (IsRefType(c->input_dtype(0))) {
params = c->mutable_input(0, use_exclusive_lock_);
params_shape = params.shape();
c->forward_ref_input_to_ref_output(0, 0);
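
For DT_RESOURCE inputs the new code takes the variable's mutex in one of two
places: around the whole update when use_locking is set, or only around the
copy-on-read preparation otherwise. A minimal sketch of the two paths, with a
hypothetical Var type in place of the resource variable:

    #include <cstdio>
    #include <mutex>

    struct Var { std::mutex mu; int tensor = 0; };

    void PrepareToUpdate(Var* v) { (void)v; /* de-alias; caller holds mu */ }

    void DoCompute(Var* v, bool already_locked) {
      if (!already_locked) {
        std::lock_guard<std::mutex> l(v->mu);  // lock only for preparation
        PrepareToUpdate(v);
      } else {
        PrepareToUpdate(v);
      }
      v->tensor += 1;  // apply the scatter update
    }

    int main() {
      Var v;
      const bool use_exclusive_lock = true;
      if (use_exclusive_lock) {
        std::lock_guard<std::mutex> l(v.mu);  // hold across the whole update
        DoCompute(&v, /*already_locked=*/true);
      } else {
        DoCompute(&v, /*already_locked=*/false);
      }
      std::printf("%d\n", v.tensor);
      return 0;
    }
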
@@ -159,6 +188,16 @@ class ScatterNdUpdateOp : public OpKernel {
.TypeConstraint<index_type>("Tindices"), \
ScatterNdUpdateOp<dev##Device, type, index_type, op>)
+#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, index_type, \
+ dev, name, op) \
+ REGISTER_KERNEL_BUILDER( \
+ Name(name) \
+ .Device(DEVICE_##dev) \
+ .TypeConstraint<type>("T") \
+ .TypeConstraint<index_type>("Tindices") \
+ .HostMemory("ref"), \
+ ScatterNdUpdateOp<dev##Device, type, index_type, op>)
+
#define REGISTER_SCATTER_ND_KERNEL(type, dev, name) \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int32, dev, name); \
REGISTER_SCATTER_ND_KERNEL_INDEX(type, int64, dev, name)
@@ -167,6 +206,11 @@ class ScatterNdUpdateOp : public OpKernel {
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, op); \
REGISTER_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
+#define REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL(type, dev, name, op) \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int32, dev, name, \
+ op); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL_INDEX(type, int64, dev, name, op)
+
#define REGISTER_SCATTER_ND_ADD_SUB(type, dev) \
REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdAdd", \
scatter_nd_op::UpdateOp::ADD); \
@@ -178,9 +222,11 @@ class ScatterNdUpdateOp : public OpKernel {
#define REGISTER_SCATTER_ND(type, dev) \
REGISTER_SCATTER_ND_KERNEL(type, dev, "ScatterNd");
-#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
- REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
- scatter_nd_op::UpdateOp::ASSIGN);
+#define REGISTER_SCATTER_ND_UPDATE(type, dev) \
+ REGISTER_SCATTER_ND_UPDATE_KERNEL(type, dev, "ScatterNdUpdate", \
+ scatter_nd_op::UpdateOp::ASSIGN); \
+ REGISTER_RESOURCE_SCATTER_ND_UPDATE_KERNEL( \
+ type, dev, "ResourceScatterNdUpdate", scatter_nd_op::UpdateOp::ASSIGN);
// Registers CPU kernels.
#define REGISTER_SCATTER_ND_ADD_SUB_CPU(type) \
@@ -281,8 +327,7 @@ Status ValidateUpdateShape(const TensorShape& params_shape,
}
template <typename Index>
-Status PrepareAndValidateInputs(OpKernelContext* c,
- const TensorShape& params_shape,
+Status PrepareAndValidateInputs(const TensorShape& params_shape,
const Tensor& indices, const Tensor& updates,
int64* slice_dim, Index* num_updates,
Index* slice_size) {
@@ -396,7 +441,7 @@ Status DoScatterNd(OpKernelContext* c, const Tensor& indices,
Index num_updates;
Index slice_size;
TF_RETURN_IF_ERROR(PrepareAndValidateInputs<Index>(
- c, shape, indices, updates, &slice_dim, &num_updates, &slice_size));
+ shape, indices, updates, &slice_dim, &num_updates, &slice_size));
IndexFlattener<Device, Index> index_flattener;
auto indices_flat = index_flattener(c, indices);
diff --git a/tensorflow/core/kernels/serialize_sparse_op.cc b/tensorflow/core/kernels/serialize_sparse_op.cc
index cfb86904d5..f4159da229 100644
--- a/tensorflow/core/kernels/serialize_sparse_op.cc
+++ b/tensorflow/core/kernels/serialize_sparse_op.cc
@@ -409,186 +409,11 @@ class DeserializeSparseOp : public OpKernel {
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
-template <typename T>
-class DeserializeManySparseOp : public OpKernel {
- public:
- explicit DeserializeManySparseOp(OpKernelConstruction* context)
- : OpKernel(context) {}
-
- void Compute(OpKernelContext* context) override {
- const Tensor& serialized_sparse = context->input(0);
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(serialized_sparse.shape()),
- errors::InvalidArgument(
- "Serialized sparse should be a matrix but received shape ",
- serialized_sparse.shape().DebugString()));
- OP_REQUIRES(
- context, serialized_sparse.shape().dim_size(1) == 3,
- errors::InvalidArgument(
- "Serialized sparse should have 3 columns but received shape ",
- serialized_sparse.shape().DebugString()));
-
- int num_sparse_tensors = serialized_sparse.shape().dim_size(0);
-
- OP_REQUIRES(
- context, num_sparse_tensors > 0,
- errors::InvalidArgument("Must have at least 1 serialized SparseTensor, "
- "but input matrix has 0 rows"));
-
- std::vector<Tensor> indices_to_concat;
- std::vector<Tensor> values_to_concat;
- std::vector<TensorShape> shapes_to_concat;
-
- const auto& serialized_sparse_t = serialized_sparse.matrix<string>();
-
- for (int i = 0; i < num_sparse_tensors; ++i) {
- Tensor output_indices(DT_INT64);
- Tensor output_values(DataTypeToEnum<T>::value);
- Tensor output_shape(DT_INT64);
- TensorProto proto_indices;
- TensorProto proto_values;
- TensorProto proto_shape;
-
- OP_REQUIRES(
- context,
- ParseProtoUnlimited(&proto_indices, serialized_sparse_t(i, 0)),
- errors::InvalidArgument("Could not parse serialized_sparse[", i,
- ", 0]"));
- OP_REQUIRES(context,
- ParseProtoUnlimited(&proto_values, serialized_sparse_t(i, 1)),
- errors::InvalidArgument("Could not parse serialized_sparse[",
- i, ", 1]"));
- OP_REQUIRES(context,
- ParseProtoUnlimited(&proto_shape, serialized_sparse_t(i, 2)),
- errors::InvalidArgument("Could not parse serialized_sparse[",
- i, ", 2]"));
-
- OP_REQUIRES(context, output_indices.FromProto(proto_indices),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 0] (indices)"));
- OP_REQUIRES(context, TensorShapeUtils::IsMatrix(output_indices.shape()),
- errors::InvalidArgument(
- "Expected serialized_sparse[", i,
- ", 0] to represent an index matrix but received shape ",
- output_indices.shape().DebugString()));
- OP_REQUIRES(context, output_values.FromProto(proto_values),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 1] (values)"));
- OP_REQUIRES(context, TensorShapeUtils::IsVector(output_values.shape()),
- errors::InvalidArgument(
- "Expected serialized_sparse[", i,
- ", 1] to represent a values vector but received shape ",
- output_values.shape().DebugString()));
- OP_REQUIRES(context, output_shape.FromProto(proto_shape),
- errors::InvalidArgument(
- "Could not construct Tensor serialized_sparse[", i,
- ", 2] (shape)"));
- OP_REQUIRES(
- context, TensorShapeUtils::IsVector(output_shape.shape()),
- errors::InvalidArgument("Expected serialized_sparse[", i,
- ", 1] to be a shape vector but its shape is ",
- output_shape.shape().DebugString()));
-
- OP_REQUIRES(
- context, DataTypeToEnum<T>::value == output_values.dtype(),
- errors::InvalidArgument(
- "Requested SparseTensor of type ",
- DataTypeString(DataTypeToEnum<T>::value), " but SparseTensor[", i,
- "].values.dtype() == ", DataTypeString(output_values.dtype())));
-
- int64 num_entries = output_indices.dim_size(0);
- OP_REQUIRES(context, num_entries == output_values.dim_size(0),
- errors::InvalidArgument(
- "Expected row counts of SparseTensor[", i,
- "].indices and SparseTensor[", i,
- "].values to match but they do not: ", num_entries,
- " vs. ", output_values.dim_size(0)));
- int rank = output_indices.dim_size(1);
- OP_REQUIRES(
- context, rank == output_shape.dim_size(0),
- errors::InvalidArgument("Expected column counts of SparseTensor[", i,
- "].indices to match size of SparseTensor[", i,
- "].shape "
- "but they do not: ",
- rank, " vs. ", output_shape.dim_size(0)));
-
- // Now we expand each SparseTensors' indices and shape by
- // prefixing a dimension
- Tensor expanded_indices(
- DT_INT64, TensorShape({num_entries, 1 + output_indices.dim_size(1)}));
- Tensor expanded_shape(DT_INT64,
- TensorShape({1 + output_shape.dim_size(0)}));
- const auto& output_indices_t = output_indices.matrix<int64>();
- const auto& output_shape_t = output_shape.vec<int64>();
- auto expanded_indices_t = expanded_indices.matrix<int64>();
- auto expanded_shape_t = expanded_shape.vec<int64>();
- expanded_indices_t.chip<1>(0).setZero();
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_start(0, 1);
- Eigen::DSizes<Eigen::DenseIndex, 2> indices_sizes(num_entries, rank);
- expanded_indices_t.slice(indices_start, indices_sizes) = output_indices_t;
- expanded_shape_t(0) = 1;
- std::copy_n(&output_shape_t(0), rank, &expanded_shape_t(1));
-
- TensorShape expanded_tensor_shape(expanded_shape.vec<int64>());
-
- indices_to_concat.push_back(expanded_indices);
- values_to_concat.push_back(output_values);
- shapes_to_concat.push_back(expanded_tensor_shape);
- }
-
- int rank = -1;
- for (int i = 0; i < num_sparse_tensors; ++i) {
- if (rank < 0) rank = shapes_to_concat[i].dims();
- OP_REQUIRES(context, rank == shapes_to_concat[i].dims(),
- errors::InvalidArgument(
- "Inconsistent rank across SparseTensors: rank prior to "
- "SparseTensor[",
- i, "] was: ", rank, " but rank of SparseTensor[", i,
- "] is: ", shapes_to_concat[i].dims()));
- }
-
- // SparseTensor::Concat requires consistent shape for all but the
- // primary order dimension (dimension 0 in this case). So we get
- // the maximum value across all the input SparseTensors for each
- // dimension and use that.
- TensorShape preconcat_shape(shapes_to_concat[0]);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- for (int d = 0; d < rank; ++d) {
- preconcat_shape.set_dim(d, std::max(preconcat_shape.dim_size(d),
- shapes_to_concat[i].dim_size(d)));
- }
- }
-
- // Dimension 0 is the primary dimension.
- gtl::InlinedVector<int64, 8> std_order(rank);
- std::iota(std_order.begin(), std_order.end(), 0);
-
- std::vector<SparseTensor> tensors_to_concat;
- tensors_to_concat.reserve(num_sparse_tensors);
- for (int i = 0; i < num_sparse_tensors; ++i) {
- tensors_to_concat.emplace_back(indices_to_concat[i], values_to_concat[i],
- preconcat_shape, std_order);
- }
-
- SparseTensor output = SparseTensor::Concat<T>(tensors_to_concat);
-
- Tensor final_output_shape(DT_INT64, TensorShape({output.dims()}));
-
- std::copy_n(output.shape().data(), output.dims(),
- final_output_shape.vec<int64>().data());
-
- context->set_output(0, output.indices());
- context->set_output(1, output.values());
- context->set_output(2, final_output_shape);
- }
-};
-
#define REGISTER_KERNELS(type) \
REGISTER_KERNEL_BUILDER(Name("DeserializeManySparse") \
.Device(DEVICE_CPU) \
.TypeConstraint<type>("dtype"), \
- DeserializeManySparseOp<type>)
+ DeserializeSparseOp<type>)
TF_CALL_ALL_TYPES(REGISTER_KERNELS);
#undef REGISTER_KERNELS
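
The deleted DeserializeManySparseOp duplicated work that the rank-generic DeserializeSparseOp already performs, so the DeserializeManySparse registration now points at the shared kernel. The core step both versions share is prefixing a batch dimension onto each tensor's indices and shape before concatenation; a standalone sketch of that expansion for one rank-1 tensor:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
      // One deserialized tensor: indices {0, 10, 20}, dense shape {50},
      // sitting at position 1 of the batch.
      std::vector<int64_t> indices = {0, 10, 20};
      std::vector<int64_t> shape = {50};
      int64_t batch_index = 1;

      // Expanded indices: a new leading column holding the batch position.
      std::vector<std::vector<int64_t>> expanded;
      for (int64_t ix : indices) expanded.push_back({batch_index, ix});

      // Expanded shape: prefix a batch dimension of size 1; Concat later
      // takes the per-dimension max across all inputs.
      std::vector<int64_t> expanded_shape = {1};
      expanded_shape.insert(expanded_shape.end(), shape.begin(), shape.end());

      for (const auto& row : expanded)
        std::cout << "[" << row[0] << " " << row[1] << "]\n";
    }
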
diff --git a/tensorflow/core/kernels/softmax_op_functor.h b/tensorflow/core/kernels/softmax_op_functor.h
index 1f38bdce8c..d3a267ed87 100644
--- a/tensorflow/core/kernels/softmax_op_functor.h
+++ b/tensorflow/core/kernels/softmax_op_functor.h
@@ -64,23 +64,21 @@ struct SoftmaxEigenImpl {
one_by_class.set(1, num_classes);
#endif
// shifted_logits = logits - max(logits along classes);
- auto shifted_logits = (logits -
- logits.maximum(along_class)
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ auto shifted_logits = (logits - logits.maximum(along_class)
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
if (log) {
// Calculate the log of the softmax
// softmax = logits - max(logits along classes);
softmax.device(d) = shifted_logits;
// softmax = softmax - log(sum(exp(softmax along classes)));
- softmax.device(d) = (softmax -
- softmax.exp()
- .sum(along_class)
- .eval()
- .reshape(batch_by_one)
- .log()
- .broadcast(one_by_class));
+ softmax.device(d) = (softmax - softmax.exp()
+ .sum(along_class)
+ .log()
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
} else {
// NOTE(touts): If you modify this implementation please run
// the BM_ImageNetSoftmaxFwd benchmark in nn_ops_test.cc.
@@ -88,12 +86,11 @@ struct SoftmaxEigenImpl {
// softmax = exp(logits - max(logits along classes));
softmax.device(d) = shifted_logits.exp();
// softmax = softmax * (1 / sum(softmax along classes));
- softmax.device(d) = (softmax *
- softmax.sum(along_class)
- .inverse()
- .eval()
- .reshape(batch_by_one)
- .broadcast(one_by_class));
+ softmax.device(d) = (softmax * softmax.sum(along_class)
+ .inverse()
+ .eval()
+ .reshape(batch_by_one)
+ .broadcast(one_by_class));
}
}
};
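
The reflowed Eigen expressions keep the same numerically stable formulation: the row maximum is subtracted from the logits before exponentiation, which leaves the result unchanged — softmax(x) equals softmax(x - c) for any constant c — while preventing exp() from overflowing. A plain scalar sketch of the same trick:

    #include <algorithm>
    #include <cmath>
    #include <iostream>
    #include <vector>

    // Shifted-logits softmax: subtract the max before exponentiating.
    std::vector<double> StableSoftmax(const std::vector<double>& logits) {
      double max_logit = *std::max_element(logits.begin(), logits.end());
      std::vector<double> out(logits.size());
      double sum = 0.0;
      for (size_t i = 0; i < logits.size(); ++i) {
        out[i] = std::exp(logits[i] - max_logit);
        sum += out[i];
      }
      for (double& v : out) v /= sum;
      return out;
    }

    int main() {
      // Naive exp(1002.0) would overflow to inf; the shifted form is finite.
      for (double p : StableSoftmax({1000.0, 1001.0, 1002.0}))
        std::cout << p << "\n";
    }
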
diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc
index 8fc40db3cc..73b6d4cf6a 100644
--- a/tensorflow/core/kernels/strided_slice_op.cc
+++ b/tensorflow/core/kernels/strided_slice_op.cc
@@ -427,6 +427,7 @@ REGISTER_STRIDED_SLICE(bfloat16);
TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU);
TF_CALL_complex64(REGISTER_GPU);
TF_CALL_complex128(REGISTER_GPU);
+TF_CALL_int64(REGISTER_GPU);
// A special GPU kernel for int32.
// TODO(b/25387198): Also enable int32 in device memory. This kernel
diff --git a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
index a8487f49f4..8ca27e3b92 100644
--- a/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
+++ b/tensorflow/core/kernels/strided_slice_op_gpu.cu.cc
@@ -53,6 +53,7 @@ typedef Eigen::GpuDevice GPUDevice;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_KERNELS);
TF_CALL_complex64(DEFINE_GPU_KERNELS);
TF_CALL_complex128(DEFINE_GPU_KERNELS);
+TF_CALL_int64(DEFINE_GPU_KERNELS);
DEFINE_GPU_KERNELS(int32);
#undef DEFINE_GPU_KERNELS
diff --git a/tensorflow/core/kernels/tensor_dataset_op.cc b/tensorflow/core/kernels/tensor_dataset_op.cc
index fe53434d17..5cf9931188 100644
--- a/tensorflow/core/kernels/tensor_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_dataset_op.cc
@@ -70,7 +70,7 @@ class TensorDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> components;
+ std::vector<Node*> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
diff --git a/tensorflow/core/kernels/tensor_slice_dataset_op.cc b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
index e85f59b584..19d4816ff3 100644
--- a/tensorflow/core/kernels/tensor_slice_dataset_op.cc
+++ b/tensorflow/core/kernels/tensor_slice_dataset_op.cc
@@ -86,7 +86,7 @@ class TensorSliceDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> components;
+ std::vector<Node*> components;
components.reserve(tensors_.size());
for (const Tensor& t : tensors_) {
Node* node;
diff --git a/tensorflow/core/kernels/zip_dataset_op.cc b/tensorflow/core/kernels/zip_dataset_op.cc
index 9381915ae9..31e5737f62 100644
--- a/tensorflow/core/kernels/zip_dataset_op.cc
+++ b/tensorflow/core/kernels/zip_dataset_op.cc
@@ -80,7 +80,7 @@ class ZipDatasetOp : public DatasetOpKernel {
protected:
Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
Node** output) const override {
- std::vector<NodeBuilder::NodeOut> input_graph_nodes;
+ std::vector<Node*> input_graph_nodes;
input_graph_nodes.reserve(inputs_.size());
for (const auto& input : inputs_) {
Node* input_node;
diff --git a/tensorflow/core/lib/core/arena.cc b/tensorflow/core/lib/core/arena.cc
index 2a04f7bd39..55e481d0e6 100644
--- a/tensorflow/core/lib/core/arena.cc
+++ b/tensorflow/core/lib/core/arena.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include <algorithm>
#include <vector>
+#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mem.h"
@@ -113,24 +114,11 @@ void Arena::MakeNewBlock(const uint32 alignment) {
CHECK(SatisfyAlignment(alignment));
}
-// The following simple numeric routines also exist in util/math/mathutil.h
-// but we don't want to depend on that library.
-
-// Euclid's algorithm for Greatest Common Denominator.
-static uint32 GCD(uint32 x, uint32 y) {
- while (y != 0) {
- uint32 r = x % y;
- x = y;
- y = r;
- }
- return x;
-}
-
static uint32 LeastCommonMultiple(uint32 a, uint32 b) {
if (a > b) {
- return (a / GCD(a, b)) * b;
+ return (a / MathUtil::GCD<uint32>(a, b)) * b;
} else if (a < b) {
- return (b / GCD(b, a)) * a;
+ return (b / MathUtil::GCD<uint32>(b, a)) * a;
} else {
return a;
}
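
With arena.cc delegating to the shared MathUtil::GCD added below, only the least-common-multiple wrapper remains here. It deliberately divides by the GCD before multiplying, which keeps the intermediate value small enough to avoid needless uint32 overflow. A standalone sketch of both routines:

    #include <cstdint>
    #include <iostream>

    // Euclid's algorithm, matching the MathUtil::GCD template below.
    uint32_t Gcd(uint32_t x, uint32_t y) {
      while (y != 0) {
        uint32_t r = x % y;
        x = y;
        y = r;
      }
      return x;
    }

    // Divide first, then multiply: (a / gcd) * b stays in range whenever the
    // true LCM does, unlike the naive (a * b) / gcd.
    uint32_t Lcm(uint32_t a, uint32_t b) {
      if (a == 0 || b == 0) return 0;  // Also avoids dividing by Gcd(0, 0).
      return (a / Gcd(a, b)) * b;
    }

    int main() {
      std::cout << Lcm(12, 18) << "\n";  // 36
    }
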
diff --git a/tensorflow/core/lib/math/math_util.h b/tensorflow/core/lib/math/math_util.h
index 6f279865e7..9e71598622 100644
--- a/tensorflow/core/lib/math/math_util.h
+++ b/tensorflow/core/lib/math/math_util.h
@@ -16,6 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_LIB_MATH_MATH_UTIL_H_
#define TENSORFLOW_LIB_MATH_MATH_UTIL_H_
+#include <type_traits>
+
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
@@ -59,6 +61,9 @@ class MathUtil {
template <typename IntegralType, bool ceil>
static IntegralType CeilOrFloorOfRatio(IntegralType numerator,
IntegralType denominator);
+
+ template <typename IntegralType>
+ static IntegralType GCD(IntegralType x, IntegralType y);
};
// ---- CeilOrFloorOfRatio ----
@@ -107,6 +112,18 @@ IntegralType MathUtil::CeilOrFloorOfRatio(IntegralType numerator,
}
}
+template <typename IntegralType>
+IntegralType MathUtil::GCD(IntegralType a, IntegralType b) {
+ static_assert(std::is_unsigned<IntegralType>::value,
+ "signed GCD not supported!");
+ while (b != 0) {
+ IntegralType r = a % b;
+ a = b;
+ b = r;
+ }
+ return a;
+}
+
} // namespace tensorflow
#endif // TENSORFLOW_LIB_MATH_MATH_UTIL_H_
diff --git a/tensorflow/core/lib/math/math_util_test.cc b/tensorflow/core/lib/math/math_util_test.cc
index eaf8c31a43..a96e5467c3 100644
--- a/tensorflow/core/lib/math/math_util_test.cc
+++ b/tensorflow/core/lib/math/math_util_test.cc
@@ -195,4 +195,33 @@ TEST(MathUtil, CeilOfRatio) {
#endif
}
+struct GCDTestCase {
+ unsigned int x;
+ unsigned int y;
+ unsigned int gcd;
+};
+
+TEST(MathUtil, GCD) {
+ std::vector<GCDTestCase> testcases({
+ {10, 20, 10}, //
+ {27, 8, 1}, //
+ {4, 3, 1}, //
+ {6, 8, 2}, //
+ {5, 0, 5}, //
+ {5, 5, 5}, //
+ {0, 0, 0} //
+ });
+
+ for (const auto& tc : testcases) {
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.x, tc.y));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint32>(tc.y, tc.x));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.x, tc.y));
+ EXPECT_EQ(tc.gcd, MathUtil::GCD<uint64>(tc.y, tc.x));
+ }
+
+ const uint64 biggish_prime = 1666666667;
+ EXPECT_EQ(biggish_prime,
+ MathUtil::GCD<uint64>(biggish_prime * 3, biggish_prime * 4));
+}
+
} // namespace tensorflow
diff --git a/tensorflow/core/lib/monitoring/collected_metrics.h b/tensorflow/core/lib/monitoring/collected_metrics.h
index fbef25619f..acdb0d86ed 100644
--- a/tensorflow/core/lib/monitoring/collected_metrics.h
+++ b/tensorflow/core/lib/monitoring/collected_metrics.h
@@ -88,6 +88,7 @@ struct Point {
ValueType value_type;
int64 int64_value;
string string_value;
+ bool bool_value;
HistogramProto histogram_value;
// start_timestamp and end_timestamp indicate the time period over which this
diff --git a/tensorflow/core/lib/monitoring/collection_registry.h b/tensorflow/core/lib/monitoring/collection_registry.h
index 113d37e07d..2c8e250c56 100644
--- a/tensorflow/core/lib/monitoring/collection_registry.h
+++ b/tensorflow/core/lib/monitoring/collection_registry.h
@@ -225,6 +225,12 @@ inline void CollectValue(const string& value, Point* const point) {
}
template <>
+inline void CollectValue(const bool& value, Point* const point) {
+ point->value_type = ValueType::kBool;
+ point->bool_value = value;
+}
+
+template <>
inline void CollectValue(const HistogramProto& value, Point* const point) {
point->value_type = ValueType::kHistogram;
// This is inefficient. If and when we hit snags, we can change the API to do
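
CollectValue is specialized once per supported value type; each specialization stamps the matching type tag and fills the corresponding field on the collected Point. A toy standalone analogue of the pattern, showing the new bool case:

    #include <iostream>
    #include <string>

    enum class ValueType { kInt64, kString, kBool };

    struct Point {
      ValueType value_type;
      long long int64_value = 0;
      std::string string_value;
      bool bool_value = false;  // The field this change adds.
    };

    template <typename T>
    void CollectValue(const T& value, Point* point);

    template <>
    void CollectValue(const bool& value, Point* point) {
      point->value_type = ValueType::kBool;
      point->bool_value = value;
    }

    int main() {
      Point p;
      CollectValue(true, &p);
      std::cout << (p.value_type == ValueType::kBool) << " "
                << p.bool_value << "\n";  // 1 1
    }
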
diff --git a/tensorflow/core/lib/monitoring/gauge.h b/tensorflow/core/lib/monitoring/gauge.h
index 75471cfb22..ec978a9193 100644
--- a/tensorflow/core/lib/monitoring/gauge.h
+++ b/tensorflow/core/lib/monitoring/gauge.h
@@ -86,8 +86,29 @@ class GaugeCell<int64> {
TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell);
};
+// Explicit specialization of GaugeCell<bool>. Compared to the primary
+// template, it uses atomic values as opposed to mutex. This class is
+// thread-safe.
+template <>
+class GaugeCell<bool> {
+ public:
+ explicit GaugeCell(bool value) : value_(value) {}
+ ~GaugeCell() {}
+
+ // Atomically sets the value.
+ void Set(bool value);
+
+ // Retrieves the current value.
+ bool value() const;
+
+ private:
+ std::atomic<bool> value_;
+
+ TF_DISALLOW_COPY_AND_ASSIGN(GaugeCell);
+};
+
// A stateful class for updating a gauge-like metric. Allowed ValueType are
-// int64 and string.
+// int64, string and bool.
//
// This class encapsulates a set of values (or a single value for a label-less
// metric). Each value is identified by a tuple of labels. The class allows the
@@ -117,6 +138,9 @@ class Gauge {
//
// auto* integer_gauge = Gauge<int64, 0>::New("/tensorflow/integer_gauge",
// "Integer gauge")
+ //
+ // auto* bool_gauge = Gauge<bool, 0>::New("/tensorflow/bool_gauge",
+ // "Bool gauge")
template <typename... MetricDefArgs>
static Gauge* New(MetricDefArgs&&... metric_def_args);
@@ -172,12 +196,17 @@ inline void GaugeCell<int64>::Set(int64 value) { value_ = value; }
inline int64 GaugeCell<int64>::value() const { return value_; }
+inline void GaugeCell<bool>::Set(bool value) { value_ = value; }
+
+inline bool GaugeCell<bool>::value() const { return value_; }
+
template <typename ValueType, int NumLabels>
template <typename... MetricDefArgs>
Gauge<ValueType, NumLabels>* Gauge<ValueType, NumLabels>::New(
MetricDefArgs&&... metric_def_args) {
static_assert(std::is_same<ValueType, int64>::value ||
- std::is_same<ValueType, string>::value,
+ std::is_same<ValueType, string>::value ||
+ std::is_same<ValueType, bool>::value,
"Gauge only allows int64 and string types.");
return new Gauge<ValueType, NumLabels>(
MetricDef<MetricKind::kGauge, ValueType, NumLabels>(
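
The GaugeCell<bool> specialization swaps the primary template's mutex for std::atomic<bool>, which is all a single flag needs for thread safety. A standalone sketch of the same design:

    #include <atomic>
    #include <iostream>

    class BoolGaugeCell {
     public:
      explicit BoolGaugeCell(bool value) : value_(value) {}

      // Plain assignment and read on std::atomic<bool> are atomic
      // store/load, so no lock is needed.
      void Set(bool value) { value_ = value; }
      bool value() const { return value_; }

     private:
      std::atomic<bool> value_;
    };

    int main() {
      BoolGaugeCell cell(false);
      cell.Set(true);
      std::cout << cell.value() << "\n";  // 1
    }
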
diff --git a/tensorflow/core/lib/monitoring/gauge_test.cc b/tensorflow/core/lib/monitoring/gauge_test.cc
index f98cfe2a3b..c8f673db38 100644
--- a/tensorflow/core/lib/monitoring/gauge_test.cc
+++ b/tensorflow/core/lib/monitoring/gauge_test.cc
@@ -87,6 +87,28 @@ TEST(GaugeOfStringValue, GetCell) {
EXPECT_EQ("bar", same_cell->value());
}
+auto* bool_gauge =
+ Gauge<bool, 0>::New("/tensorflow/test/bool_gauge", "Gauge of bool value.");
+
+TEST(GaugeOfBoolValue, InitializedWithFalseValue) {
+ EXPECT_EQ(false, bool_gauge->GetCell()->value());
+}
+
+TEST(GaugeOfBoolValue, GetCell) {
+ auto* cell = bool_gauge->GetCell();
+ EXPECT_EQ(false, cell->value());
+
+ cell->Set(true);
+ EXPECT_EQ(true, cell->value());
+
+ auto* same_cell = bool_gauge->GetCell();
+  EXPECT_EQ(true, same_cell->value());
+
+ same_cell->Set(false);
+ EXPECT_EQ(false, cell->value());
+ EXPECT_EQ(false, same_cell->value());
+}
+
} // namespace
} // namespace monitoring
} // namespace tensorflow
diff --git a/tensorflow/core/lib/monitoring/metric_def.h b/tensorflow/core/lib/monitoring/metric_def.h
index a7f14f9c94..f046842618 100644
--- a/tensorflow/core/lib/monitoring/metric_def.h
+++ b/tensorflow/core/lib/monitoring/metric_def.h
@@ -28,16 +28,16 @@ namespace monitoring {
// The different metric kinds available.
//
// Gauge indicates that the metric's values are instantaneous measurements of a
-// (typically) continuously varying quantity or a string value. Examples: a
-// process's current heap size, a queue's current length, the name of the binary
-// used by a process.
+// (typically) continuously varying value. Examples: a process's current heap
+// size, a queue's current length, the name of the binary used by a process,
+// whether a task is complete.
//
// Cumulative indicates that the metric's values represent non-negative changes
// over specified time periods. Example: the number of rpc calls to a service.
enum class MetricKind : int { kGauge = 0, kCumulative };
// The type of the metric values.
-enum class ValueType : int { kInt64 = 0, kHistogram, kString };
+enum class ValueType : int { kInt64 = 0, kHistogram, kString, kBool };
// Everything in the internal namespace is implementation details. Do not depend
// on this.
@@ -61,6 +61,11 @@ inline ValueType GetValueType<string>() {
return ValueType::kString;
}
+template <>
+inline ValueType GetValueType<bool>() {
+ return ValueType::kBool;
+}
+
} // namespace internal
// Abstract base class for a metric definition.
diff --git a/tensorflow/core/ops/array_ops.cc b/tensorflow/core/ops/array_ops.cc
index 9fa6423d59..6f4ea09206 100644
--- a/tensorflow/core/ops/array_ops.cc
+++ b/tensorflow/core/ops/array_ops.cc
@@ -724,8 +724,8 @@ REGISTER_OP("OnesLike")
.Input("x: T")
.Output("y: T")
.Attr(
- "T: {float, double, int8, uint8, int16, uint16, int32, int64, "
- "complex64, complex128, bool}")
+ "T: {bfloat16, float, double, int8, uint8, int16, uint16, int32, "
+ "int64, complex64, complex128, bool}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns a tensor of ones with the same shape and type as x.
@@ -738,7 +738,7 @@ y: a tensor of the same shape and type as x but filled with ones.
REGISTER_OP("Diag")
.Input("diagonal: T")
.Output("output: T")
- .Attr("T: {float, double, int32, int64, complex64, complex128}")
+ .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
TF_RETURN_IF_ERROR(c->WithRankAtLeast(in, 1, &in));
@@ -776,7 +776,7 @@ diagonal: Rank k tensor where k is at most 1.
REGISTER_OP("DiagPart")
.Input("input: T")
.Output("diagonal: T")
- .Attr("T: {float, double, int32, int64, complex64, complex128}")
+ .Attr("T: {bfloat16, float, double, int32, int64, complex64, complex128}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle in = c->input(0);
if (!c->RankKnown(in)) {
@@ -1059,9 +1059,8 @@ REGISTER_OP("Reverse")
.Input("dims: bool")
.Output("output: T")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
- "double, complex64, "
- "complex128, string}")
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, "
+ "float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle dims;
@@ -1137,9 +1136,8 @@ REGISTER_OP("ReverseV2")
.Output("output: T")
.Attr("Tidx: {int32, int64} = DT_INT32")
.Attr(
- "T: {uint8, int8, uint16, int16, int32, int64, bool, half, float, "
- "double, complex64, "
- "complex128, string}")
+ "T: {uint8, int8, uint16, int16, int32, int64, bool, half, bfloat16, "
+ "float, double, complex64, complex128, string}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
ShapeHandle axis;
@@ -1834,7 +1832,7 @@ this operation.
REGISTER_OP("CheckNumerics")
.Input("tensor: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("message: string")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
@@ -4565,12 +4563,12 @@ REGISTER_OP("Bitcast")
.Output("output: type")
// All supported dtypes are listed here to include qint16 and quint16.
.Attr(
- "T: {float, double, int64, int32, uint8, uint16, int8, int16,"
+ "T: {bfloat16, float, double, int64, int32, uint8, uint16, int8, int16,"
" complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
" half}")
.Attr(
- "type: {float, double, int64, int32, uint8, uint16, int8, int16,"
- " complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
+ "type: {bfloat16, float, double, int64, int32, uint8, uint16, int8, "
+ "int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32,"
" half}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input = c->input(0);
@@ -4782,7 +4780,7 @@ REGISTER_OP("QuantizeAndDequantize")
.Attr("input_min: float = 0")
.Attr("input_max: float = 0")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Deprecated(22, "Replaced by QuantizeAndDequantizeV2")
.Doc(R"doc(
@@ -4798,7 +4796,7 @@ REGISTER_OP("QuantizeAndDequantizeV2")
.Attr("num_bits: int = 8")
.Attr("range_given: bool = false")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
@@ -4877,7 +4875,7 @@ REGISTER_OP("QuantizeAndDequantizeV3")
.Attr("signed_input: bool = true")
.Attr("range_given: bool = true")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
diff --git a/tensorflow/core/ops/dataset_ops.cc b/tensorflow/core/ops/dataset_ops.cc
index 6bf226e7a5..be41531347 100644
--- a/tensorflow/core/ops/dataset_ops.cc
+++ b/tensorflow/core/ops/dataset_ops.cc
@@ -469,6 +469,24 @@ stop: corresponds to stop in python's xrange().
step: corresponds to step in python's xrange().
)doc");
+REGISTER_OP("RandomDataset")
+ .Input("seed: int64")
+ .Input("seed2: int64")
+ .Output("handle: variant")
+ .Attr("output_types: list(type) >= 1")
+ .Attr("output_shapes: list(shape) >= 1")
+ .SetIsStateful() // TODO(b/65524810): Source dataset ops must be marked
+ // stateful to inhibit constant folding.
+ .SetShapeFn(shape_inference::ScalarShape)
+ .Doc(R"doc(
+Creates a Dataset that returns pseudorandom numbers.
+
+seed: A scalar seed for the random number generator. If either seed or
+ seed2 is set to be non-zero, the random number generator is seeded
+ by the given seed. Otherwise, a random seed is used.
+seed2: A second scalar seed to avoid seed collision.
+)doc");
+
REGISTER_OP("ShuffleDataset")
.Input("input_dataset: variant")
.Input("buffer_size: int64")
diff --git a/tensorflow/core/ops/math_ops.cc b/tensorflow/core/ops/math_ops.cc
index df75caca37..45ebfa203b 100644
--- a/tensorflow/core/ops/math_ops.cc
+++ b/tensorflow/core/ops/math_ops.cc
@@ -85,7 +85,7 @@ REGISTER_OP("BatchMatMul")
.Input("x: T")
.Input("y: T")
.Output("output: T")
- .Attr("T: {half, float, double, int32, complex64, complex128}")
+ .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
.Attr("adj_x: bool = false")
.Attr("adj_y: bool = false")
.SetShapeFn([](InferenceContext* c) {
@@ -184,7 +184,7 @@ _HostCast requires its input and produces its output in host memory.
REGISTER_OP("Abs")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes the absolute value of a tensor.
@@ -210,29 +210,31 @@ value is computed as \\( \sqrt{a^2 + b^2}\\).
)doc");
// Declares cwise unary operations signature: 't -> 't
-#define UNARY() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}") \
+#define UNARY() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr( \
+ "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_REAL() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double}") \
+#define UNARY_REAL() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr("T: {half, bfloat16, float, double}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_COMPLEX() \
- Input("x: T") \
- .Output("y: T") \
- .Attr("T: {half, float, double, complex64, complex128}") \
+#define UNARY_COMPLEX() \
+ Input("x: T") \
+ .Output("y: T") \
+ .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
-#define UNARY_GRADIENT_COMPLEX() \
- Input("y: T") \
- .Input("dy: T") \
- .Output("z: T") \
- .Attr("T: {half, float, double, complex64, complex128}") \
+#define UNARY_GRADIENT_COMPLEX() \
+ Input("y: T") \
+ .Input("dy: T") \
+ .Output("z: T") \
+ .Attr("T: {half, bfloat16, float, double, complex64, complex128}") \
.SetShapeFn(shape_inference::UnchangedShape)
REGISTER_OP("Neg")
@@ -481,7 +483,7 @@ Computes atan of x element-wise.
REGISTER_OP("IsNan")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are NaN.
@@ -494,7 +496,7 @@ Equivalent to np.isnan
REGISTER_OP("IsInf")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are Inf.
@@ -507,7 +509,7 @@ Equivalent to np.isinf
REGISTER_OP("IsFinite")
.Input("x: T")
.Output("y: bool")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns which elements of x are finite.
@@ -520,7 +522,9 @@ Equivalent to np.isfinite
REGISTER_OP("Sign")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns an element-wise indication of the sign of a number.
@@ -533,7 +537,7 @@ For complex numbers, `y = sign(x) = x / |x|` if `x != 0`, otherwise `y = 0`.
REGISTER_OP("Floor")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns element-wise largest integer not greater than x.
@@ -542,7 +546,7 @@ Returns element-wise largest integer not greater than x.
REGISTER_OP("Ceil")
.Input("x: T")
.Output("y: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns element-wise smallest integer not less than x.
@@ -551,7 +555,7 @@ Returns element-wise smallest integer in not less than x.
REGISTER_OP("Rint")
.Input("x: T")
.Output("y: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Returns element-wise integer closest to x.
@@ -569,22 +573,23 @@ rint([-1.7, -1.5, -0.2, 0.2, 1.5, 1.7, 2.0]) ==> [-2., -2., -0., 0., 2., 2., 2.]
// Declares cwise binary operations signature: 't, 't -> 't.
-#define BINARY_MORE() \
- Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, " \
- "complex64, complex128}")
+#define BINARY_MORE() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {half, bfloat16, float, double, uint8, int8, uint16, int16, int32, " \
+ "int64, complex64, complex128}")
-#define BINARY_FEWER() \
- Input("x: T").Input("y: T").Output("z: T").Attr( \
- "T: {half, float, double, int32, int64, complex64, complex128}")
+#define BINARY_FEWER() \
+ Input("x: T").Input("y: T").Output("z: T").Attr( \
+ "T: {half, bfloat16, float, double, int32, int64, complex64, " \
+ "complex128}")
REGISTER_OP("Add")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
- "complex128, string}")
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "complex64, complex128, string}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns x + y element-wise.
@@ -600,8 +605,8 @@ REGISTER_OP("AddV2")
.Input("y: T")
.Output("z: T")
.Attr(
- "T: {half, float, double, uint8, int8, int16, int32, int64, complex64, "
- "complex128}")
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, int64, "
+ "complex64, complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.SetIsAggregate()
.SetIsCommutative()
@@ -757,7 +762,7 @@ REGISTER_OP("Maximum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
@@ -788,7 +793,7 @@ REGISTER_OP("Minimum")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64}")
+ .Attr("T: {half, bfloat16, float, double, int32, int64}")
.SetIsCommutative()
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
@@ -802,7 +807,7 @@ REGISTER_OP("Mod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. This emulates C semantics in that
@@ -817,7 +822,7 @@ REGISTER_OP("FloorMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. When `x < 0` xor `y < 0` is
@@ -832,7 +837,7 @@ REGISTER_OP("TruncateMod")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {int32, int64, float, double}")
+ .Attr("T: {int32, int64, bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Returns element-wise remainder of division. This emulates C semantics in that
@@ -847,7 +852,9 @@ REGISTER_OP("Pow")
.Input("x: T")
.Input("y: T")
.Output("z: T")
- .Attr("T: {half, float, double, int32, int64, complex64, complex128}")
+ .Attr(
+ "T: {half, bfloat16, float, double, int32, int64, complex64, "
+ "complex128}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Computes the power of one value to another.
@@ -946,7 +953,7 @@ REGISTER_OP("Atan2")
.Input("y: T")
.Input("x: T")
.Output("z: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
.Doc(R"doc(
Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
@@ -1064,15 +1071,15 @@ Returns the truth value of (x >= y) element-wise.
// --------------------------------------------------------------------------
-#define EQUALITY_COMPARISON() \
- Input("x: T") \
- .Input("y: T") \
- .Output("z: bool") \
- .SetIsCommutative() \
- .Attr( \
- "T: {half, float, double, uint8, int8, int16, int32, int64, " \
- "complex64, " \
- "quint8, qint8, qint32, string, bool, complex128}") \
+#define EQUALITY_COMPARISON() \
+ Input("x: T") \
+ .Input("y: T") \
+ .Output("z: bool") \
+ .SetIsCommutative() \
+ .Attr( \
+ "T: {half, bfloat16, float, double, uint8, int8, int16, int32, " \
+ "int64, complex64, quint8, qint8, qint32, string, bool, " \
+ "complex128}") \
.SetShapeFn(shape_inference::BroadcastBinaryOpShapeFn)
REGISTER_OP("Equal")
@@ -1291,7 +1298,7 @@ REGISTER_OP("MatMul")
.Output("product: T")
.Attr("transpose_a: bool = false")
.Attr("transpose_b: bool = false")
- .Attr("T: {half, float, double, int32, complex64, complex128}")
+ .Attr("T: {half, bfloat16, float, double, int32, complex64, complex128}")
.SetShapeFn(shape_inference::MatMulShape)
.Doc(R"doc(
Multiply the matrix "a" by the matrix "b".
@@ -1811,10 +1818,11 @@ output: Has same shape as data, except for dimension 0 which
REGISTER_OP("UnsortedSegmentSum")
.Input("data: T")
.Input("segment_ids: Tindices")
- .Input("num_segments: int32")
+ .Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: numbertype")
.Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the sum along segments of a tensor.
@@ -1849,10 +1857,11 @@ output: Has same shape as data, except for the first `segment_ids.rank`
REGISTER_OP("UnsortedSegmentMax")
.Input("data: T")
.Input("segment_ids: Tindices")
- .Input("num_segments: int32")
+ .Input("num_segments: Tnumsegments")
.Output("output: T")
.Attr("T: realnumbertype")
.Attr("Tindices: {int32,int64}")
+ .Attr("Tnumsegments: {int32,int64} = DT_INT32")
.SetShapeFn(UnsortedSegmentReductionShapeFn)
.Doc(R"doc(
Computes the Max along segments of a tensor.
@@ -2103,7 +2112,7 @@ REGISTER_OP("Range")
.Input("limit: Tidx")
.Input("delta: Tidx")
.Output("output: Tidx")
- .Attr("Tidx: {float, double, int32, int64} = DT_INT32")
+ .Attr("Tidx: {bfloat16, float, double, int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
TF_RETURN_WITH_CONTEXT_IF_ERROR(c->WithRank(c->input(0), 0, &unused),
@@ -2158,7 +2167,7 @@ REGISTER_OP("LinSpace")
.Input("stop: T")
.Input("num: Tidx")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("Tidx: {int32, int64} = DT_INT32")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle unused;
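
Among the math_ops changes, UnsortedSegmentSum and UnsortedSegmentMax gain a Tnumsegments attr so the segment count can be passed as int64 rather than int32 only. A minimal scalar version of the unsorted segment sum itself, with an int64 count:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // out[ids[i]] += data[i]; ids outside [0, num_segments) are simply
    // skipped in this sketch.
    std::vector<double> UnsortedSegmentSum(const std::vector<double>& data,
                                           const std::vector<int64_t>& ids,
                                           int64_t num_segments) {
      std::vector<double> out(num_segments, 0.0);
      for (size_t i = 0; i < data.size(); ++i)
        if (ids[i] >= 0 && ids[i] < num_segments) out[ids[i]] += data[i];
      return out;
    }

    int main() {
      for (double v : UnsortedSegmentSum({1, 2, 3, 4}, {0, 2, 0, 1}, 3))
        std::cout << v << " ";  // 4 4 2
    }
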
diff --git a/tensorflow/core/ops/nn_ops.cc b/tensorflow/core/ops/nn_ops.cc
index 654e890b57..102de94787 100644
--- a/tensorflow/core/ops/nn_ops.cc
+++ b/tensorflow/core/ops/nn_ops.cc
@@ -73,7 +73,7 @@ REGISTER_OP("AvgPool")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::AvgPoolShape)
.Doc(R"doc(
Performs average pooling on the input.
@@ -101,7 +101,7 @@ REGISTER_OP("AvgPoolGrad")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -300,7 +300,7 @@ REGISTER_OP("FusedBatchNormV2")
.Output("batch_variance: U")
.Output("reserve_space_1: U")
.Output("reserve_space_2: U")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
@@ -393,7 +393,7 @@ REGISTER_OP("FusedBatchNormGradV2")
.Output("offset_backprop: U")
.Output("reserve_space_3: U")
.Output("reserve_space_4: U")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("U: {float}")
.Attr("epsilon: float = 0.0001")
.Attr("data_format: string = 'NHWC'")
@@ -508,11 +508,12 @@ REGISTER_OP("Conv2D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn(shape_inference::Conv2DShape)
.Doc(R"doc(
Computes a 2-D convolution given 4-D `input` and `filter` tensors.
@@ -546,7 +547,7 @@ filter: A 4-D tensor of shape
output: A 4-D tensor. The dimension order is determined by the value of
`data_format`, see below for details.
strides: 1-D tensor of length 4. The stride of the sliding window for each
- dimension of `input`. The dimension order is determined by the value of
+ dimension of `input`. The dimension order is determined by the value of
`data_format`, see below for details.
padding: The type of padding algorithm to use.
data_format: Specify the data format of the input and output data. With the
@@ -554,6 +555,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("Conv2DBackpropInput")
@@ -561,11 +567,12 @@ REGISTER_OP("Conv2DBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -589,10 +596,15 @@ padding: The type of padding algorithm to use.
output: 4-D with shape `[batch, in_height, in_width, in_channels]`. Gradient
w.r.t. the input of the convolution.
data_format: Specify the data format of the input and output data. With the
- default format "NHWC", the data is stored in the order of:
- [batch, in_height, in_width, in_channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, in_channels, in_height, in_width].
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
// TODO(jeff): Instead of 'use_cudnn_for_gpu', maybe we should have a
@@ -603,11 +615,12 @@ REGISTER_OP("Conv2DBackpropFilter")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float}")
+ .Attr("T: {half, bfloat16, float}")
.Attr("strides: list(int)")
.Attr("use_cudnn_on_gpu: bool = true")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -632,10 +645,15 @@ output: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
the `filter` input of the convolution.
data_format: Specify the data format of the input and output data. With the
- default format "NHWC", the data is stored in the order of:
- [batch, in_height, in_width, in_channels].
- Alternatively, the format could be "NCHW", the data storage order of:
- [batch, in_channels, in_height, in_width].
+ default format "NHWC", the data is stored in the order of:
+ [batch, in_height, in_width, in_channels].
+ Alternatively, the format could be "NCHW", the data storage order of:
+ [batch, in_channels, in_height, in_width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
namespace {
@@ -819,10 +837,11 @@ REGISTER_OP("DepthwiseConv2dNative")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn(shape_inference::DepthwiseConv2DNativeShape)
.Doc(R"doc(
Computes a 2-D depthwise convolution given 4-D `input` and `filter` tensors.
@@ -845,7 +864,6 @@ for k in 0..in_channels-1
Must have `strides[0] = strides[3] = 1`. For the most common case of the same
horizontal and vertical strides, `strides = [1, stride, stride, 1]`.
-
strides: 1-D of length 4. The stride of the sliding window for each dimension
of `input`.
padding: The type of padding algorithm to use.
@@ -854,6 +872,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
)doc");
REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
@@ -861,10 +884,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -892,6 +916,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
output: 4-D with shape according to `data_format`. For example, if
`data_format` is 'NHWC', output shape is `[batch, in_height,
in_width, in_channels]`. Gradient w.r.t. the input of the
@@ -903,10 +932,11 @@ REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
.Attr(GetConvnetDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -935,6 +965,11 @@ data_format: Specify the data format of the input and output data. With the
[batch, height, width, channels].
Alternatively, the format could be "NCHW", the data storage order of:
[batch, channels, height, width].
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each filter
+ element on that dimension. The dimension order is determined by the value of
+ `data_format`, see above for details. Dilations in the batch and depth
+ dimensions must be 1.
output: 4-D with shape
`[filter_height, filter_width, in_channels, out_channels]`. Gradient w.r.t.
the `filter` input of the convolution.
@@ -945,10 +980,11 @@ REGISTER_OP("Conv3D")
.Input("input: T")
.Input("filter: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn(shape_inference::Conv3DShape)
.Doc(R"doc(
Computes a 3-D convolution given 5-D `input` and `filter` tensors.
@@ -970,6 +1006,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("Conv3DBackpropInput")
@@ -1032,10 +1073,11 @@ REGISTER_OP("Conv3DBackpropInputV2")
.Input("filter: T")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -1061,6 +1103,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
@@ -1069,10 +1116,11 @@ REGISTER_OP("Conv3DBackpropFilterV2")
.Input("filter_sizes: int32")
.Input("out_backprop: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
@@ -1098,6 +1146,11 @@ data_format: The data format of the input and output data. With the
[batch, in_depth, in_height, in_width, in_channels].
Alternatively, the format could be "NCDHW", the data storage order is:
[batch, in_channels, in_depth, in_height, in_width].
+dilations: 1-D tensor of length 5. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
@@ -1110,7 +1163,7 @@ REGISTER_OP("AvgPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D average pooling on the input.
@@ -1137,7 +1190,7 @@ REGISTER_OP("AvgPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float, double}")
+ .Attr("T: {bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
@@ -1172,7 +1225,7 @@ REGISTER_OP("MaxPool3D")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float}")
+ .Attr("T: {bfloat16, float}")
.SetShapeFn(shape_inference::Pool3DShape)
.Doc(R"doc(
Performs 3D max pooling on the input.
@@ -1200,8 +1253,8 @@ REGISTER_OP("MaxPool3DGrad")
.Attr("strides: list(int) >= 5")
.Attr(GetPaddingAttrString())
.Attr(GetConvnet3dDataFormatAttrString())
- .Attr("T: {float} = DT_FLOAT")
- .Attr("TInput: {float} = DT_FLOAT")
+ .Attr("T: {bfloat16, float} = DT_FLOAT")
+ .Attr("TInput: {bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 5);
})
@@ -1266,7 +1319,7 @@ data_format: The data format of the input and output data. With the
REGISTER_OP("L2Loss")
.Input("t: T")
.Output("output: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::ScalarShape)
.Doc(R"doc(
L2 Loss.
@@ -1288,7 +1341,7 @@ REGISTER_OP("LRN")
.Attr("bias: float = 1.0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.5")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
return UnchangedShapeWithRank(c, 4);
})
@@ -1323,7 +1376,7 @@ REGISTER_OP("LRNGrad")
.Attr("bias: float = 1.0")
.Attr("alpha: float = 1.0")
.Attr("beta: float = 0.5")
- .Attr("T: {float, half} = DT_FLOAT")
+ .Attr("T: {half, bfloat16, float} = DT_FLOAT")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle s;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
@@ -1349,8 +1402,8 @@ output: The gradients for LRN.
REGISTER_OP("MaxPool")
.Attr(
- "T: {float, double, int32, int64, uint8, int16, int8, uint16, "
- "half, qint8} = DT_FLOAT")
+ "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
+ "uint16, qint8} = DT_FLOAT")
.Attr("ksize: list(int) >= 4")
.Attr("strides: list(int) >= 4")
.Attr(GetPaddingAttrString())
@@ -1376,8 +1429,8 @@ output: The max pooled output tensor.
REGISTER_OP("MaxPoolV2")
.Attr(
- "T: {float, double, int32, int64, uint8, int16, int8, uint16, "
- "half, qint8} = DT_FLOAT")
+ "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
+ "uint16, qint8} = DT_FLOAT")
.Attr(GetPaddingAttrString())
.Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
.Input("input: T")
@@ -1860,7 +1913,7 @@ backprops: The gradients:
REGISTER_OP("Elu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes exponential linear: `exp(features) - 1` if `features` < 0, `features` otherwise.
@@ -1873,7 +1926,7 @@ REGISTER_OP("EluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the exponential linear (Elu) operation.
@@ -1887,7 +1940,7 @@ backprops: The gradients: `gradients * (outputs + 1)` if outputs < 0,
REGISTER_OP("Selu")
.Input("features: T")
.Output("activations: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Computes scaled exponential linear: `scale * alpha * (exp(features) - 1)`
@@ -1900,7 +1953,7 @@ REGISTER_OP("SeluGrad")
.Input("gradients: T")
.Input("outputs: T")
.Output("backprops: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn(shape_inference::MergeBothInputsShapeFn)
.Doc(R"doc(
Computes gradients for the scaled exponential linear (Selu) operation.
@@ -1962,7 +2015,7 @@ backprops: The gradients: `gradients / (1 + abs(features)) ** 2`.
REGISTER_OP("Softmax")
.Input("logits: T")
.Output("softmax: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
@@ -1982,7 +2035,7 @@ softmax: Same shape as `logits`.
REGISTER_OP("LogSoftmax")
.Input("logits: T")
.Output("logsoftmax: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
})
@@ -2004,7 +2057,7 @@ REGISTER_OP("SoftmaxCrossEntropyWithLogits")
.Input("labels: T")
.Output("loss: T")
.Output("backprop: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle input;
TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &input));
@@ -2033,7 +2086,7 @@ REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
.Input("labels: Tlabels")
.Output("loss: T")
.Output("backprop: T")
- .Attr("T: {half, float, double}")
+ .Attr("T: {half, bfloat16, float, double}")
.Attr("Tlabels: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle features;
@@ -2613,6 +2666,7 @@ REGISTER_OP("QuantizedConv2D")
.Attr("out_type: quantizedtype = DT_QINT32")
.Attr("strides: list(int)")
.Attr(GetPaddingAttrString())
+ .Attr("dilations: list(int) = [1, 1, 1, 1]")
.SetShapeFn([](InferenceContext* c) {
TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
ShapeHandle unused;
@@ -2641,7 +2695,11 @@ min_filter: The float value that the lowest quantized filter value represents.
max_filter: The float value that the highest quantized filter value represents.
min_output: The float value that the lowest quantized output value represents.
max_output: The float value that the highest quantized output value represents.
-
+dilations: 1-D tensor of length 4. The dilation factor for each dimension of
+ `input`. If set to k > 1, there will be k-1 skipped cells between each
+ filter element on that dimension. The dimension order is determined by the
+ value of `data_format`, see above for details. Dilations in the batch and
+ depth dimensions must be 1.
)doc");
REGISTER_OP("QuantizedMaxPool")
diff --git a/tensorflow/core/ops/random_ops.cc b/tensorflow/core/ops/random_ops.cc
index 2429171fa9..31d9c82e53 100644
--- a/tensorflow/core/ops/random_ops.cc
+++ b/tensorflow/core/ops/random_ops.cc
@@ -29,7 +29,7 @@ REGISTER_OP("RandomUniform")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -87,7 +87,7 @@ REGISTER_OP("RandomStandardNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -115,7 +115,7 @@ REGISTER_OP("ParameterizedTruncatedNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -145,7 +145,7 @@ REGISTER_OP("TruncatedNormal")
.Output("output: dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
- .Attr("dtype: {half,float,double}")
+ .Attr("dtype: {half,bfloat16,float,double}")
.Attr("T: {int32, int64}")
.SetShapeFn(shape_inference::RandomShape)
.Doc(R"doc(
@@ -201,10 +201,11 @@ REGISTER_OP("Multinomial")
.SetIsStateful()
.Input("logits: T")
.Input("num_samples: int32")
- .Output("output: int64")
+ .Output("output: output_dtype")
.Attr("seed: int = 0")
.Attr("seed2: int = 0")
.Attr("T: realnumbertype")
+ .Attr("output_dtype: {int32, int64} = DT_INT64")
.SetShapeFn([](InferenceContext* c) {
ShapeHandle logits_shape;
ShapeHandle unused;
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index cdfbec85cf..bf9e673e8e 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -204,7 +204,10 @@ Status VariableShapeShapeFn(InferenceContext* c) {
if (handle_data == nullptr || handle_data->empty()) {
return errors::InvalidArgument("Handle doesn't have shape information.");
}
- c->set_output(0, (*handle_data)[0].shape);
+ ShapeHandle var_shape = (*handle_data)[0].shape;
+ int64 rank = c->RankKnown(var_shape) ? c->Rank(var_shape)
+ : InferenceContext::kUnknownDim;
+ c->set_output(0, c->Vector(rank));
return Status::OK();
}
diff --git a/tensorflow/core/ops/sparse_ops.cc b/tensorflow/core/ops/sparse_ops.cc
index 8414519f0b..772e2531dc 100644
--- a/tensorflow/core/ops/sparse_ops.cc
+++ b/tensorflow/core/ops/sparse_ops.cc
@@ -256,6 +256,48 @@ REGISTER_OP("DeserializeSparse")
.Doc(R"doc(
Deserialize `SparseTensor` objects.
+The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
+the last dimension stores serialized `SparseTensor` objects and the other N
+dimensions (N >= 0) correspond to a batch. The ranks of the original
+`SparseTensor` objects must all match. When the final `SparseTensor` is
+created, its rank is the rank of the incoming `SparseTensor` objects plus N;
+the sparse tensors have been concatenated along new dimensions, one for each
+batch.
+
+The output `SparseTensor` object's shape values for the original dimensions
+are the max across the input `SparseTensor` objects' shape values for the
+corresponding dimensions. The new dimensions match the size of the batch.
+
+The input `SparseTensor` objects' indices are assumed ordered in
+standard lexicographic order. If this is not the case, after this
+step run `SparseReorder` to restore index ordering.
+
+For example, if the serialized input is a `[2 x 3]` matrix representing two
+original `SparseTensor` objects:
+
+ index = [ 0]
+ [10]
+ [20]
+ values = [1, 2, 3]
+ shape = [50]
+
+and
+
+ index = [ 2]
+ [10]
+ values = [4, 5]
+ shape = [30]
+
+then the final deserialized `SparseTensor` will be:
+
+ index = [0 0]
+ [0 10]
+ [0 20]
+ [1 2]
+ [1 10]
+ values = [1, 2, 3, 4, 5]
+ shape = [2 50]
+
serialized_sparse: The serialized `SparseTensor` objects. The last dimension
must have 3 columns.
dtype: The `dtype` of the serialized `SparseTensor` objects.
diff --git a/tensorflow/core/ops/state_ops.cc b/tensorflow/core/ops/state_ops.cc
index da5f091e9f..5b1f5d2477 100644
--- a/tensorflow/core/ops/state_ops.cc
+++ b/tensorflow/core/ops/state_ops.cc
@@ -513,6 +513,62 @@ output_ref: Same as ref. Returned as a convenience for operations that want to
use the updated values after the update is done.
)doc");
+REGISTER_OP("ResourceScatterNdUpdate")
+ .Input("ref: resource")
+ .Input("indices: Tindices")
+ .Input("updates: T")
+ .Attr("T: type")
+ .Attr("Tindices: {int32, int64}")
+ .Attr("use_locking: bool = true")
+ .SetShapeFn(shape_inference::ScatterNdUpdateShape)
+ .Doc(R"doc(
+Applies sparse `updates` to individual values or slices within a given
+variable according to `indices`.
+
+`ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+`indices` must be an integer tensor containing indices into `ref`.
+It must have shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+The innermost dimension of `indices` (with length `K`) corresponds to
+indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+dimension of `ref`.
+
+`updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+```
+[d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+```
+
+For example, say we want to update 4 scattered elements of a rank-1 tensor
+with 8 elements. In Python, that update would look like this:
+
+```python
+    ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+      print(sess.run(update))
+```
+
+The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+See @{tf.scatter_nd} for more details about how to make updates to
+slices.
+
+ref: A resource handle. Must be from a VarHandleOp.
+indices: A Tensor. Must be one of the following types: int32, int64.
+ A tensor of indices into ref.
+updates: A Tensor. Must have the same type as ref. A tensor of updated
+  values to store in ref.
+use_locking: An optional bool. Defaults to True. If True, the assignment will
+ be protected by a lock; otherwise the behavior is undefined,
+ but may exhibit less contention.
+)doc");
+
REGISTER_OP("ScatterNdAdd")
.Input("ref: Ref(T)")
.Input("indices: Tindices")
diff --git a/tensorflow/core/platform/cloud/curl_http_request_test.cc b/tensorflow/core/platform/cloud/curl_http_request_test.cc
index 6c0f081852..d476a1a4db 100644
--- a/tensorflow/core/platform/cloud/curl_http_request_test.cc
+++ b/tensorflow/core/platform/cloud/curl_http_request_test.cc
@@ -263,7 +263,6 @@ TEST(CurlHttpRequestTest, GetRequest) {
std::vector<char> scratch;
scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end());
- StringPiece result;
scratch.reserve(100);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
@@ -594,7 +593,6 @@ TEST(CurlHttpRequestTest, ErrorReturnsNoResponse) {
std::vector<char> scratch;
scratch.insert(scratch.begin(), kTestContent.begin(), kTestContent.end());
- StringPiece result;
scratch.reserve(100);
TF_EXPECT_OK(http_request.SetUri("http://www.testuri.com"));
diff --git a/tensorflow/core/platform/cloud/file_block_cache.cc b/tensorflow/core/platform/cloud/file_block_cache.cc
index a472ae52fc..e1afc7b308 100644
--- a/tensorflow/core/platform/cloud/file_block_cache.cc
+++ b/tensorflow/core/platform/cloud/file_block_cache.cc
@@ -181,7 +181,9 @@ Status FileBlockCache::Read(const string& filename, size_t offset, size_t n,
// The requested offset is at or beyond the end of the file. This can
// happen if `offset` is not block-aligned, and the read returns the last
// block in the file, which does not extend all the way out to `offset`.
- return errors::OutOfRange("EOF at offset ", offset);
+      return errors::OutOfRange("EOF at offset ", offset, " in file ", filename,
+                                " at position ", pos, " with data size ",
+                                data.size());
}
auto begin = data.begin();
if (offset > pos) {
diff --git a/tensorflow/core/platform/cloud/gcs_file_system.cc b/tensorflow/core/platform/cloud/gcs_file_system.cc
index 54d38fe962..45e9b05092 100644
--- a/tensorflow/core/platform/cloud/gcs_file_system.cc
+++ b/tensorflow/core/platform/cloud/gcs_file_system.cc
@@ -697,6 +697,9 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
TF_RETURN_WITH_CONTEXT_IF_ERROR(request->Send(), " when reading gs://",
bucket, "/", object);
+ VLOG(1) << "Successful read of gs://" << bucket << "/" << object << " @ "
+ << offset << " of size: " << out->size();
+
if (out->size() < block_size()) {
// Check stat cache to see if we encountered an interrupted read.
FileStatistics stat;
@@ -706,6 +709,8 @@ Status GcsFileSystem::LoadBufferFromGCS(const string& filename, size_t offset,
"File contents are inconsistent for file: %s @ %lu.",
filename.c_str(), offset));
}
+ VLOG(2) << "Successful integrity check for: gs://" << bucket << "/"
+ << object << " @ " << offset;
}
}
@@ -868,6 +873,11 @@ Status GcsFileSystem::StatForObject(const string& fname, const string& bucket,
TF_RETURN_IF_ERROR(GetStringValue(root, "updated", &updated));
TF_RETURN_IF_ERROR(ParseRfc3339Time(updated, &(stat->mtime_nsec)));
+ VLOG(1) << "Stat of: gs://" << bucket << "/" << object << " -- "
+ << " length: " << stat->length
+ << "; mtime_nsec: " << stat->mtime_nsec
+ << "; updated: " << updated;
+
stat->is_directory = false;
return Status::OK();
};
diff --git a/tensorflow/core/profiler/g3doc/options.md b/tensorflow/core/profiler/g3doc/options.md
index 4c73e372e3..dd12f76d6f 100644
--- a/tensorflow/core/profiler/g3doc/options.md
+++ b/tensorflow/core/profiler/g3doc/options.md
@@ -60,11 +60,14 @@ Currently, profiler only tracks the allocation of memory. As a result, the
accumulated memory request is usually larger than the peak memory of the overall
model.
-bytes: The memory allocations requested by the operation.
-peak_bytes: The peak requested memory (not de-allocated) by the operation.
-residual_bytes: The memory requested by the operation and not de-allocated
+It's recommended to generate a timeline to see the allocator memory usage over
+time.
+
+`bytes`: The memory allocations requested by the operation.
+`peak_bytes`: The peak requested memory (not de-allocated) by the operation.
+`residual_bytes`: The memory requested by the operation and not de-allocated
when Compute finishes.
-output_bytes: The memory output by the operation. It's not necessarily requested
+`output_bytes`: The memory output by the operation. It's not necessarily requested
by the current operation. For example, it can be a tensor
forwarded from input to output, with in-place mutation.
diff --git a/tensorflow/core/profiler/internal/tfprof_node.cc b/tensorflow/core/profiler/internal/tfprof_node.cc
index 671b65d708..5cd1050bcc 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.cc
+++ b/tensorflow/core/profiler/internal/tfprof_node.cc
@@ -139,6 +139,25 @@ void ExecStep::AddMemoryStats(const string& dev,
exec_.accelerator_persistent_bytes() +
step_stat.memory_stats().device_persistent_memory_size());
}
+
+ // TODO(xpan): Make this more accurate:
+  // High level: Memory tracking is suspicious and requires a large-scale
+  // cleanup.
+  // Investigate the memory usage difference between CPU/GPU with OpViewTest.
+ //
+ // 1. OpKernelConstruction::allocate_xxx is not traced. Below, we only
+ // discuss OpKernelContext-related allocations.
+ // 2. allocate_output calls allocate_tensor, which is properly tracked in
+ // 'NodeExecStats.memory'.
+ // 3. allocate_temp is only tracked through record_xxx_temp. It appears
+ // in 'NodeExecStats.memory_stats'.
+ // 4. allocate_persistent calls allocate_tensor, which is properly tracked
+ // in 'NodeExecStats.memory'. However, there is no way to count it as
+ // persistent now.
+ // 5. record_xxx_persistent is called when allocate_persistent
+ // is not used and hence tracks some complementary bytes. It appears in
+  //    'NodeExecStats.memory_stats'. It's suspicious, but we should
+  //    use it for now since it covers the constant op.
int64 residual_bytes = 0;
int64 requested_bytes = 0;
int64 peak_bytes = 0;
@@ -147,6 +166,15 @@ void ExecStep::AddMemoryStats(const string& dev,
requested_bytes += mem.total_bytes();
peak_bytes += mem.peak_bytes();
}
+ residual_bytes +=
+ exec_.host_persistent_bytes() + exec_.accelerator_persistent_bytes();
+ requested_bytes += exec_.host_persistent_bytes() +
+ exec_.accelerator_persistent_bytes() +
+ exec_.host_temp_bytes() + exec_.accelerator_temp_bytes();
+ peak_bytes += exec_.host_persistent_bytes() +
+ exec_.accelerator_persistent_bytes() + exec_.host_temp_bytes() +
+ exec_.accelerator_temp_bytes();
+
exec_.set_requested_bytes(requested_bytes);
exec_.set_residual_bytes(residual_bytes);
exec_.set_peak_bytes(peak_bytes);
diff --git a/tensorflow/core/profiler/internal/tfprof_node.h b/tensorflow/core/profiler/internal/tfprof_node.h
index e2d0563a07..77c14cb792 100644
--- a/tensorflow/core/profiler/internal/tfprof_node.h
+++ b/tensorflow/core/profiler/internal/tfprof_node.h
@@ -593,17 +593,11 @@ class TFGraphNode {
int64 accelerator_persistent_bytes() const {
int64 persistent_bytes = 0;
for (const auto& exec : execs_) {
- persistent_bytes += exec.second.accelerator_persistent_bytes();
+ persistent_bytes = std::max(persistent_bytes,
+ exec.second.accelerator_persistent_bytes());
}
return persistent_bytes;
}
- int64 host_persistent_bytes(int64 step) const {
- auto exec = execs_.find(step);
- if (exec == execs_.end()) {
- return 0;
- }
- return exec->second.host_persistent_bytes();
- }
const std::map<int32, std::pair<int64, uint64>>& output_memory(
int64 step) const {
auto exec = execs_.find(step);
diff --git a/tensorflow/core/profiler/internal/tfprof_show_test.cc b/tensorflow/core/profiler/internal/tfprof_show_test.cc
index 1f19f8c322..98773ae19e 100644
--- a/tensorflow/core/profiler/internal/tfprof_show_test.cc
+++ b/tensorflow/core/profiler/internal/tfprof_show_test.cc
@@ -105,12 +105,13 @@ TEST_F(TFProfShowTest, DumpScopeMode) {
"node name | # parameters | # float_ops | requested bytes | peak bytes | "
"residual bytes | output bytes | total execution time | accelerator "
"execution time | cpu execution time\n_TFProfRoot (--/451 params, --/0 "
- "flops, --/0B, --/0B, --/0B, --/2.56KB, --/13us, --/0us, --/13us)\n DW "
- "(3x3x3x6, 162/162 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, "
- "1.28KB/1.28KB, 2us/2us, 0us/0us, 2us/2us)\n DW2 (2x2x6x12, 288/288 "
- "params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, 1.28KB/1.28KB, 11us/11us, "
- "0us/0us, 11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, "
- "0B/0B, 0B/0B, 0us/0us, 0us/0us, 0us/0us)\n",
+ "flops, --/2.56KB, --/2.56KB, --/2.56KB, --/2.56KB, --/13us, --/0us, "
+ "--/13us)\n DW (3x3x3x6, 162/162 params, 0/0 flops, 1.28KB/1.28KB, "
+ "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 2us/2us, 0us/0us, "
+ "2us/2us)\n DW2 (2x2x6x12, 288/288 params, 0/0 flops, 1.28KB/1.28KB, "
+ "1.28KB/1.28KB, 1.28KB/1.28KB, 1.28KB/1.28KB, 11us/11us, 0us/0us, "
+ "11us/11us)\n ScalarW (1, 1/1 params, 0/0 flops, 0B/0B, 0B/0B, 0B/0B, "
+ "0B/0B, 0us/0us, 0us/0us, 0us/0us)\n",
dump_str);
EXPECT_EQ(dump_str, TestToFromProto("scope", opts));
@@ -178,22 +179,22 @@ TEST_F(TFProfShowTest, DumpOpMode) {
EXPECT_EQ(
"nodename|requestedbytes|totalexecutiontime|acceleratorexecutiontime|"
"cpuexecutiontime|#parameters|#float_ops|opoccurrence(run|defined)|"
- "inputshapes\nVariableV20B(0.00%,0.00%),13us(100.00%,0.26%),0us(100.00%,"
- "0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_ops(100.00%"
- ",0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:13us\n\nAdd0B("
- "0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99.71%,0.00%),"
- "0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_type:0:1,"
- "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12,\t1:1\t("
- "run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t(run*0|"
- "defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0.00%),"
- "0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_ops("
- "100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_"
+ "inputshapes\nVariableV22.56KB(100.00%,8.40%),13us(100.00%,0.26%),0us("
+ "100.00%,0.00%),13us(100.00%,0.29%),451params(100.00%,100.00%),0float_"
+ "ops(100.00%,0.00%),2|3\n\ninput_type:\t(run*2|defined*3)\texec_time:"
+ "13us\n\nAdd0B(0.00%,0.00%),0us(99.74%,0.00%),0us(100.00%,0.00%),0us(99."
+ "71%,0.00%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),0|3\n\ninput_"
+ "type:0:1,\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:2x2x6x12,"
+ "\t1:1\t(run*0|defined*1)\texec_time:0us\ninput_type:0:3x3x3x6,\t1:1\t("
+ "run*0|defined*1)\texec_time:0us\n\nAssign0B(0.00%,0.00%),0us(99.74%,0."
+ "00%),0us(100.00%,0.00%),0us(99.71%,0.00%),0params(0.00%,0.00%),0float_"
+ "ops(100.00%,0.00%),0|3\n\ninput_type:0:1,\t1:1\t(run*0|defined*1)\texec_"
"time:0us\ninput_type:0:2x2x6x12,\t1:2x2x6x12\t(run*0|defined*1)\texec_"
"time:0us\ninput_type:0:3x3x3x6,\t1:3x3x3x6\t(run*0|defined*1)\texec_"
"time:0us\n\nConst0B(0.00%,0.00%),2us(99.74%,0.04%),0us(100.00%,0.00%),"
"2us(99.71%,0.04%),0params(0.00%,0.00%),0float_ops(100.00%,0.00%),1|"
- "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D14.59KB("
- "100.00%,100.00%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99."
+ "10\n\ninput_type:\t(run*1|defined*10)\texec_time:2us\n\nConv2D27.90KB("
+ "91.60%,91.60%),4.89ms(99.70%,98.87%),404us(100.00%,100.00%),4.49ms(99."
"67%,98.77%),0params(0.00%,0.00%),10.44kfloat_ops(100.00%,100.00%),2|"
"2\n\ninput_type:0:2x3x3x6,\t1:2x2x6x12\t(run*1|defined*1)\texec_time:"
"597us\ninput_type:0:2x6x6x3,\t1:3x3x3x6\t(run*1|defined*1)\texec_time:4."
diff --git a/tensorflow/core/profiler/internal/tfprof_stats_test.cc b/tensorflow/core/profiler/internal/tfprof_stats_test.cc
index 2f2101d76b..b86a83cb1b 100644
--- a/tensorflow/core/profiler/internal/tfprof_stats_test.cc
+++ b/tensorflow/core/profiler/internal/tfprof_stats_test.cc
@@ -89,21 +89,27 @@ TEST_F(TFProfStatsTest, CustomOpType) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
- "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: "
- "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n "
- "total_exec_micros: 2\n total_parameters: 162\n devices: "
+ "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: "
+ "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: "
+ "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n "
+ " total_requested_bytes: 1280\n total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n "
"total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: "
- "288\n total_exec_micros: 11\n total_parameters: 288\n devices: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n "
+ "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n "
+ "total_exec_micros: 11\n total_requested_bytes: 1280\n "
+ "total_parameters: 288\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n "
"total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n "
- "total_parameters: 1\n total_definition_count: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n "
+ "parameters: 1\n total_parameters: 1\n total_definition_count: "
"1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: "
- "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n",
+ "2\ntotal_definition_count: 3\ntotal_peak_bytes: "
+ "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -119,21 +125,27 @@ TEST_F(TFProfStatsTest, CheckPointOpType) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
- "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_parameters: "
- "451\nchildren {\n name: \"DW\"\n exec_micros: 2\n parameters: 162\n "
- "total_exec_micros: 2\n total_parameters: 162\n devices: "
+ "name: \"_TFProfRoot\"\ntotal_exec_micros: 13\ntotal_requested_bytes: "
+ "2560\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n exec_micros: "
+ "2\n requested_bytes: 1280\n parameters: 162\n total_exec_micros: 2\n "
+ " total_requested_bytes: 1280\n total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 2\n "
"total_cpu_exec_micros: 2\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"DW2\"\n exec_micros: 11\n parameters: "
- "288\n total_exec_micros: 11\n total_parameters: 288\n devices: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"DW2\"\n "
+ "exec_micros: 11\n requested_bytes: 1280\n parameters: 288\n "
+ "total_exec_micros: 11\n total_requested_bytes: 1280\n "
+ "total_parameters: 288\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n cpu_exec_micros: 11\n "
"total_cpu_exec_micros: 11\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n output_bytes: 1280\n total_output_bytes: "
- "1280\n}\nchildren {\n name: \"ScalarW\"\n parameters: 1\n "
- "total_parameters: 1\n total_definition_count: "
+ "total_definition_count: 1\n peak_bytes: 1280\n residual_bytes: 1280\n "
+ " output_bytes: 1280\n total_peak_bytes: 1280\n total_residual_bytes: "
+ "1280\n total_output_bytes: 1280\n}\nchildren {\n name: \"ScalarW\"\n "
+ "parameters: 1\n total_parameters: 1\n total_definition_count: "
"1\n}\ntotal_cpu_exec_micros: 13\ntotal_run_count: "
- "2\ntotal_definition_count: 3\ntotal_output_bytes: 2560\n",
+ "2\ntotal_definition_count: 3\ntotal_peak_bytes: "
+ "2560\ntotal_residual_bytes: 2560\ntotal_output_bytes: 2560\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -150,7 +162,7 @@ TEST_F(TFProfStatsTest, TestGraph) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: "
+ "30464\ntotal_parameters: 451\nchildren {\n name: "
"\"DW/Initializer/random_normal/mul\"\n children {\n name: "
"\"DW/Initializer/random_normal/RandomStandardNormal\"\n children {\n "
" name: \"DW/Initializer/random_normal/shape\"\n "
@@ -166,7 +178,7 @@ TEST_F(TFProfStatsTest, TestGraph) {
"4\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: "
"404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: "
"6\ntotal_definition_count: 32\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -181,9 +193,9 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n "
- "exec_micros: 4292\n requested_bytes: 9472\n total_exec_micros: 4292\n "
- " total_requested_bytes: 9472\n devices: "
+ "30464\ntotal_parameters: 451\nchildren {\n name: \"Conv2D\"\n "
+ "exec_micros: 4292\n requested_bytes: 18176\n total_exec_micros: "
+ "4292\n total_requested_bytes: 18176\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 5832\n "
"total_float_ops: 5832\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 6\n "
@@ -194,11 +206,11 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
"6\n }\n }\n }\n accelerator_exec_micros: 226\n "
"cpu_exec_micros: 4066\n total_accelerator_exec_micros: 226\n "
"total_cpu_exec_micros: 4066\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 5888\n residual_bytes: 768\n "
- "output_bytes: 768\n total_peak_bytes: 5888\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 14592\n residual_bytes: 768\n "
+ " output_bytes: 768\n total_peak_bytes: 14592\n total_residual_bytes: "
"768\n total_output_bytes: 768\n}\nchildren {\n name: \"Conv2D_1\"\n "
- "exec_micros: 597\n requested_bytes: 5120\n total_exec_micros: 597\n "
- "total_requested_bytes: 5120\n devices: "
+ "exec_micros: 597\n requested_bytes: 9728\n total_exec_micros: 597\n "
+ "total_requested_bytes: 9728\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n "
"total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 3\n "
@@ -209,12 +221,12 @@ TEST_F(TFProfStatsTest, TestFloatOps) {
"12\n }\n }\n }\n accelerator_exec_micros: 178\n "
"cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n "
"total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n "
- "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n "
+ "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: "
"512\n total_output_bytes: 512\n}\ntotal_float_ops: "
"10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: "
"4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -231,9 +243,9 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 597\ntotal_requested_bytes: "
- "5120\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n "
- "requested_bytes: 5120\n total_exec_micros: 597\n "
- "total_requested_bytes: 5120\n devices: "
+ "9728\nchildren {\n name: \"Conv2D_1\"\n exec_micros: 597\n "
+ "requested_bytes: 9728\n total_exec_micros: 597\n "
+ "total_requested_bytes: 9728\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n float_ops: 4608\n "
"total_float_ops: 4608\n input_shapes {\n key: 0\n value {\n "
"dim {\n size: 2\n }\n dim {\n size: 3\n "
@@ -244,12 +256,12 @@ TEST_F(TFProfStatsTest, TestAccountShownNameOnly) {
"12\n }\n }\n }\n accelerator_exec_micros: 178\n "
"cpu_exec_micros: 419\n total_accelerator_exec_micros: 178\n "
"total_cpu_exec_micros: 419\n run_count: 1\n total_run_count: 1\n "
- "total_definition_count: 1\n peak_bytes: 4096\n residual_bytes: 512\n "
- "output_bytes: 512\n total_peak_bytes: 4096\n total_residual_bytes: "
+ "total_definition_count: 1\n peak_bytes: 8704\n residual_bytes: 512\n "
+ "output_bytes: 512\n total_peak_bytes: 8704\n total_residual_bytes: "
"512\n total_output_bytes: 512\n}\ntotal_float_ops: "
"4608\ntotal_accelerator_exec_micros: 178\ntotal_cpu_exec_micros: "
"419\ntotal_run_count: 1\ntotal_definition_count: 2\ntotal_peak_bytes: "
- "4096\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n",
+ "8704\ntotal_residual_bytes: 512\ntotal_output_bytes: 512\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
@@ -265,8 +277,9 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
GraphNodeProto expected;
CHECK(protobuf::TextFormat::ParseFromString(
"name: \"_TFProfRoot\"\ntotal_exec_micros: 4945\ntotal_requested_bytes: "
- "14592\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n "
- "exec_micros: 2\n parameters: 162\n total_exec_micros: 2\n "
+ "30464\ntotal_parameters: 451\nchildren {\n name: \"DW\"\n "
+ "exec_micros: 2\n requested_bytes: 1280\n parameters: 162\n "
+ "total_exec_micros: 2\n total_requested_bytes: 1280\n "
"total_parameters: 162\n devices: "
"\"/job:localhost/replica:0/task:0/gpu:0\"\n tensor_value {\n dtype: "
"DT_FLOAT\n value_double: -0.000534315\n value_double: "
@@ -351,11 +364,13 @@ TEST_F(TFProfStatsTest, TestShowTensorValue) {
"value_double: 0.000374641\n value_double: -0.00149603\n "
"value_double: -0.000317367\n value_double: -0.000417829\n }\n "
"cpu_exec_micros: 2\n total_cpu_exec_micros: 2\n run_count: 1\n "
- "total_run_count: 1\n total_definition_count: 10\n output_bytes: "
- "1280\n total_output_bytes: 1280\n}\ntotal_float_ops: "
- "10440\ntotal_accelerator_exec_micros: 404\ntotal_cpu_exec_micros: "
- "4541\ntotal_run_count: 6\ntotal_definition_count: 35\ntotal_peak_bytes: "
- "9984\ntotal_residual_bytes: 1280\ntotal_output_bytes: 4864\n",
+ "total_run_count: 1\n total_definition_count: 10\n peak_bytes: 1280\n "
+ "residual_bytes: 1280\n output_bytes: 1280\n total_peak_bytes: 1280\n "
+ "total_residual_bytes: 1280\n total_output_bytes: "
+ "1280\n}\ntotal_float_ops: 10440\ntotal_accelerator_exec_micros: "
+ "404\ntotal_cpu_exec_micros: 4541\ntotal_run_count: "
+ "6\ntotal_definition_count: 35\ntotal_peak_bytes: "
+ "25856\ntotal_residual_bytes: 3840\ntotal_output_bytes: 4864\n",
&expected));
EXPECT_EQ(expected.DebugString(), root.DebugString());
}
diff --git a/tensorflow/core/profiler/tfprof_log.proto b/tensorflow/core/profiler/tfprof_log.proto
index f92301133a..b49bdf64ac 100644
--- a/tensorflow/core/profiler/tfprof_log.proto
+++ b/tensorflow/core/profiler/tfprof_log.proto
@@ -124,9 +124,10 @@ message ExecProfile {
int64 residual_bytes = 9;
// Total bytes output by the op (not necessarily requested by the op).
int64 output_bytes = 10;
- // Total temporary bytes allocated and released by the op.
+ // NOTE: Please don't depend on the following 4 fields yet. Due to
+ // TensorFlow internal tracing issues, the numbers can be quite wrong.
+ // TODO(xpan): Fix the TensorFlow internal tracing.
int64 host_temp_bytes = 11;
- // Total persistent bytes (e.g. variable) allocated by the op.
int64 host_persistent_bytes = 12;
int64 accelerator_temp_bytes = 13;
int64 accelerator_persistent_bytes = 14;
diff --git a/tensorflow/docs_src/api_guides/python/reading_data.md b/tensorflow/docs_src/api_guides/python/reading_data.md
index b3ebaa0f0a..4594887349 100644
--- a/tensorflow/docs_src/api_guides/python/reading_data.md
+++ b/tensorflow/docs_src/api_guides/python/reading_data.md
@@ -1,11 +1,11 @@
# Reading data
Note: The preferred way to feed data into a TensorFlow program is using the
-@{$datasets$Datasets API}.
+@{$datasets$`tf.data` API}.
There are four methods of getting data into a TensorFlow program:
-* `Dataset` API: Easily construct a complex input pipeline. (preferred method)
+* `tf.data` API: Easily construct a complex input pipeline. (preferred method)
* Feeding: Python code provides the data when running each step.
* `QueueRunner`: a queue-based input pipeline reads the data from files
at the beginning of a TensorFlow graph.
@@ -14,26 +14,27 @@ There are four methods of getting data into a TensorFlow program:
[TOC]
-## Dataset API
+## `tf.data` API
See the @{$datasets$programmer's guide} for an in-depth explanation of
-@{tf.data.Dataset}. The `Dataset` API allows you to extract and preprocess data
-from different input/file formats, and apply transformations such as batch,
-shuffle, and map to the dataset. This is an improved version of the old input
-methods, feeding and `QueueRunner`.
+@{tf.data.Dataset}. The `tf.data` API enables you to extract and preprocess data
+from different input/file formats, and apply transformations such as batching,
+shuffling, and mapping functions over the dataset. This is an improved version
+of the old input methods---feeding and `QueueRunner`---which are described
+below for historical purposes.
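+
+For instance, a minimal sketch of such a pipeline (assuming a user-supplied
+`parse_fn` that decodes a single serialized record) might look like:
+
+```python
+# Read serialized examples, decode them, then shuffle and batch.
+dataset = tf.data.TFRecordDataset(["input.tfrecords"])  # hypothetical file
+dataset = dataset.map(parse_fn).shuffle(buffer_size=1000).batch(32)
+next_batch = dataset.make_one_shot_iterator().get_next()
+```
+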
## Feeding
+Warning: "Feeding" is the least efficient way to feed data into a TensorFlow
+program and should only be used for small experiments and debugging.
+
TensorFlow's feed mechanism lets you inject data into any Tensor in a
-computation graph. A python computation can thus feed data directly into the
+computation graph. A Python computation can thus feed data directly into the
graph.
Supply feed data through the `feed_dict` argument to a run() or eval() call
that initiates computation.
-Warning: "Feeding" is the least efficient way to feed data into a tensorflow
-program and should only be used for small experiments and debugging.
-
```python
with tf.Session():
input = tf.placeholder(tf.float32)
@@ -55,6 +56,10 @@ and is described in the @{$mechanics$MNIST tutorial}.
## `QueueRunner`
+Warning: This section discusses implementing input pipelines using the
+queue-based APIs which can be cleanly replaced by the @{$datasets$`tf.data`
+API}.
+
A typical queue-based pipeline for reading records from files has the following stages:
1. The list of filenames
@@ -66,9 +71,6 @@ A typical queue-based pipeline for reading records from files has the following
7. *Optional* preprocessing
8. Example queue
-Warning: This section discusses implementing input pipelines using the
-queue-based APIs which can be cleanly replaced by the @{$datasets$Datasets API}.
-
### Filenames, shuffling, and epoch limits
For the list of filenames, use either a constant string Tensor (like
@@ -499,7 +501,7 @@ You can have the train and eval in the same graph in the same process, and share
their trained variables or layers. See @{$variables$the shared variables tutorial}.
To support the single-graph approach
-@{$programmers_guide/datasets$Datasets} also supplies
+@{$programmers_guide/datasets$`tf.data`} also supplies
@{$programmers_guide/datasets#creating_an_iterator$advanced iterator types}
that allow the user to change the input pipeline without rebuilding the graph or
session.
diff --git a/tensorflow/docs_src/get_started/custom_estimators.md b/tensorflow/docs_src/get_started/custom_estimators.md
new file mode 100644
index 0000000000..e347aa6bd0
--- /dev/null
+++ b/tensorflow/docs_src/get_started/custom_estimators.md
@@ -0,0 +1,576 @@
+
+# Creating Custom Estimators
+This document introduces custom Estimators. In particular, this document
+demonstrates how to create a custom @{tf.estimator.Estimator$Estimator} that
+mimics the behavior of the pre-made Estimator
+@{tf.estimator.DNNClassifier$`DNNClassifier`} in solving the Iris problem. See
+the @{$get_started/estimator$Pre-Made Estimators chapter} for details.
+
+If you are feeling impatient, feel free to compare and contrast the following
+full programs:
+
+* Iris implemented with the [pre-made DNNClassifier Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/premade_estimator.py).
+* Iris implemented with a [custom Estimator](https://github.com/tensorflow/models/blob/master/samples/core/get_started/custom_estimator.py).
+
+## Pre-made vs. custom
+
+As the following figure shows, pre-made Estimators are subclasses of the
+@{tf.estimator.Estimator} base class, while custom Estimators are an instance
+of tf.estimator.Estimator:
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%"
+ alt="Premade estimators are sub-classes of `Estimator`. Custom Estimators are usually (direct) instances of `Estimator`"
+ src="../images/custom_estimators/estimator_types.png">
+</div>
+<div style="text-align: center">
+Pre-made and custom Estimators are all Estimators.
+</div>
+
+Pre-made Estimators are fully baked. Sometimes though, you need more control
+over an Estimator's behavior. That's where custom Estimators come in. You can
+create a custom Estimator to do just about anything. If you want hidden layers
+connected in some unusual fashion, write a custom Estimator. If you want to
+calculate a unique
+[metric](https://developers.google.com/machine-learning/glossary/#metric)
+for your model, write a custom Estimator. Basically, if you want an Estimator
+optimized for your specific problem, write a custom Estimator.
+
+A model function (or `model_fn`) implements the ML algorithm. The
+only difference between working with pre-made Estimators and custom Estimators
+is:
+
+* With pre-made Estimators, someone already wrote the model function for you.
+* With custom Estimators, you must write the model function.
+
+Your model function could implement a wide range of algorithms, defining all
+sorts of hidden layers and metrics. Like input functions, all model functions
+must accept a standard group of input parameters and return a standard group of
+output values. Just as input functions can leverage the Dataset API, model
+functions can leverage the Layers API and the Metrics API.
+
+Let's see how to solve the Iris problem with a custom Estimator. A quick
+reminder: here's the organization of the Iris model that we're trying to mimic:
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="A diagram of the network architecture: Inputs, 2 hidden layers, and outputs"
+ src="../images/custom_estimators/full_network.png">
+</div>
+<div style="text-align: center">
+Our implementation of Iris contains four features, two hidden layers,
+and a logits output layer.
+</div>
+
+## Write an Input function
+
+In our custom Estimator implementation, we'll reuse the input function we used
+in the pre-made Estimator implementation. Namely:
+
+```python
+def train_input_fn(features, labels, batch_size):
+ """An input function for training"""
+ # Convert the inputs to a Dataset.
+ dataset = tf.data.Dataset.from_tensor_slices((features, labels))
+
+ # Shuffle, repeat, and batch the examples.
+ dataset = dataset.shuffle(1000).repeat().batch(batch_size)
+
+ # Return the read end of the pipeline.
+ return dataset.make_one_shot_iterator().get_next()
+```
+
+This input function builds an input pipeline that yields batches of
+`(features, labels)` pairs, where `features` is a dictionary of features.
+
+## Create feature columns
+
+<!-- TODO(markdaoust): link to feature_columns when it exists-->
+As detailed in @{$get_started/estimator$Premade Estimators}, you must define
+your model's feature columns to specify how the model should use each feature.
+Whether working with pre-made Estimators or custom Estimators, you define
+feature columns in the same fashion.
+
+The following code creates a simple `numeric_column` for each input feature,
+indicating that the value of the input feature should be used directly as an
+input to the model:
+
+```python
+# Feature columns describe how to use the input.
+my_feature_columns = []
+for key in train_x.keys():
+ my_feature_columns.append(tf.feature_column.numeric_column(key=key))
+```
+
+## Write a model function
+
+The model function we'll use has the following call signature:
+
+```python
+def my_model_fn(
+ features, # This is batch_features from input_fn
+ labels, # This is batch_labels from input_fn
+ mode, # An instance of tf.estimator.ModeKeys
+ params): # Additional configuration
+```
+
+The first two arguments are the batches of features and labels returned from
+the input function; that is, `features` and `labels` are the handles to the
+data your model will use. The `mode` argument indicates whether the caller is
+requesting training, predicting, or evaluation.
+
+The caller may pass `params` to an Estimator's constructor. The `params` passed
+to the constructor become the `params` passed to `model_fn`.
+
+```python
+  # Build a DNN with 2 hidden layers of 10 units each.
+ classifier = tf.estimator.Estimator(
+ model_fn=my_model,
+ params={
+ 'feature_columns': my_feature_columns,
+ # Two hidden layers of 10 nodes each.
+ 'hidden_units': [10, 10],
+ # The model must choose between 3 classes.
+ 'n_classes': 3,
+ })
+```
+
+To implement a typical model function, you must do the following:
+
+* [Define the model](#define_the_model).
+* Specify additional calculations for each of
+ the [three different modes](#modes):
+ * [Predict](#predict)
+ * [Evaluate](#evaluate)
+ * [Train](#train)
+
+## Define the model
+
+The basic deep neural network model must define the following three sections:
+
+* An [input layer](https://developers.google.com/machine-learning/glossary/#input_layer)
+* One or more [hidden layers](https://developers.google.com/machine-learning/glossary/#hidden_layer)
+* An [output layer](https://developers.google.com/machine-learning/glossary/#output_layer)
+
+### Define the input layer
+
+Call @{tf.feature_column.input_layer} to convert your feature dictionary and
+feature columns into input for your model. For example:
+
+```python
+ # Use `input_layer` to apply the feature columns.
+ net = tf.feature_column.input_layer(features, params['feature_columns'])
+```
+
+The preceding line applies the transformations defined by your feature columns,
+creating the input layer of our model.
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="A diagram of the input layer, in this case a 1:1 mapping from raw-inputs to features."
+ src="../images/custom_estimators/input_layer.png">
+</div>
+
+
+### Hidden Layers
+
+If you are creating a deep neural network, you must define one or more hidden
+layers. The Layers API provides a rich set of functions to define all types of
+hidden layers, including convolutional, pooling, and dropout layers. For Iris,
+we're simply going to call @{tf.layers.dense} to create hidden layers, with
+dimensions defined by `params['hidden_units']`. In a `dense` layer each node
+is connected to every node in the preceding layer. Here's the relevant code:
+
+``` python
+ # Build the hidden layers, sized according to the 'hidden_units' param.
+ for units in params['hidden_units']:
+ net = tf.layers.dense(net, units=units, activation=tf.nn.relu)
+```
+* The `units` parameter defines the number of output neurons in a given layer.
+* The `activation` parameter defines the [activation function](https://developers.google.com/machine-learning/glossary/#a) —
+  [ReLU](https://developers.google.com/machine-learning/glossary/#ReLU) in this
+ case.
+
+The variable `net` here signifies the current top layer of the network. During
+the first iteration, `net` signifies the input layer. On each loop iteration
+`tf.layers.dense` creates a new layer, which takes the previous layer as its
+input. So, the loop uses `net` to pass the previously created layer as input
+to the layer being created.
+
+After creating two hidden layers, our network looks as follows. For
+simplicity, the figure only shows four hidden units in each layer.
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="The input layer with two hidden layers added."
+ src="../images/custom_estimators/add_hidden_layer.png">
+</div>
+
+Note that @{tf.layers.dense} provides many additional capabilities, including
+the ability to set a multitude of regularization parameters. For the sake of
+simplicity, though, we're going to simply accept the default values of the
+other parameters.
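+
+For example, here is a sketch of one such capability: attaching an L2 weight
+regularizer to a hidden layer (the `scale` value is purely illustrative):
+
+```python
+  # Same as above, but penalize large kernel weights with L2 regularization.
+  net = tf.layers.dense(
+      net, units=units, activation=tf.nn.relu,
+      kernel_regularizer=tf.contrib.layers.l2_regularizer(scale=0.01))
+```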
+
+### Output Layer
+
+We'll define the output layer by calling @{tf.layers.dense} yet again, this
+time without an activation function:
+
+```python
+ # Compute logits (1 per class).
+ logits = tf.layers.dense(net, params['n_classes'], activation=None)
+```
+
+Here, `net` signifies the final hidden layer. Therefore, the full set of layers
+is now connected as follows:
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="A logit output layer connected to the top hidden layer"
+ src="../images/custom_estimators/add_logits.png">
+</div>
+<div style="text-align: center">
+The final hidden layer feeds into the output layer.
+</div>
+
+When defining an output layer, the `units` parameter specifies the number of
+outputs. So, by setting `units` to `params['n_classes']`, the model produces
+one output value per class. Each element of the output vector will contain the
+score, or "logit", calculated for the associated class of Iris: Setosa,
+Versicolor, or Virginica, respectively.
+
+Later on, these logits will be transformed into probabilities by the
+@{tf.nn.softmax} function.
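+
+For instance, applying @{tf.nn.softmax} to the example logit values used later
+in this document yields (a sketch, not part of the model function):
+
+```python
+logits = tf.constant([[-1.3, 2.6, -0.9]])
+probabilities = tf.nn.softmax(logits)  # approximately [[0.02, 0.95, 0.03]]
+```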
+
+## Implement training, evaluation, and prediction {#modes}
+
+The final step in creating a model function is to write branching code that
+implements prediction, evaluation, and training.
+
+The model function gets invoked whenever someone calls the Estimator's `train`,
+`evaluate`, or `predict` methods. Recall that the signature for the model
+function looks like this:
+
+``` python
+def my_model_fn(
+ features, # This is batch_features from input_fn
+ labels, # This is batch_labels from input_fn
+ mode): # An instance of tf.estimator.ModeKeys, see below
+```
+
+Focus on that third argument, `mode`. As the following table shows, when
+someone calls `train`, `evaluate`, or `predict`, the Estimator framework
+invokes your model function with the `mode` parameter set as follows:
+
+| Estimator method | Estimator Mode |
+|:---------------------------------|:------------------|
+|@{tf.estimator.Estimator.train$`train()`} |@{tf.estimator.ModeKeys.TRAIN$`ModeKeys.TRAIN`} |
+|@{tf.estimator.Estimator.evaluate$`evaluate()`} |@{tf.estimator.ModeKeys.EVAL$`ModeKeys.EVAL`} |
+|@{tf.estimator.Estimator.predict$`predict()`}|@{tf.estimator.ModeKeys.PREDICT$`ModeKeys.PREDICT`} |
+
+For example, suppose you instantiate a custom Estimator to generate an object
+named `classifier`. Then, you make the following call:
+
+``` python
+classifier = tf.estimator.Estimator(...)
+classifier.train(input_fn=lambda: my_input_fn(FILE_TRAIN, True, 500))
+```
+The Estimator framework then calls your model function with mode set to
+`ModeKeys.TRAIN`.
+
+Your model function must provide code to handle all three of the mode values.
+For each mode value, your code must return an instance of
+`tf.estimator.EstimatorSpec`, which contains the information the caller
+requires. Let's examine each mode.
+
+### Predict
+
+When the Estimator's `predict` method is called, the `model_fn` receives
+`mode = ModeKeys.PREDICT`. In this case, the model function must return a
+`tf.estimator.EstimatorSpec` containing the prediction.
+
+The model must have been trained prior to making a prediction. The trained model
+is stored on disk in the `model_dir` directory established when you
+instantiated the Estimator.
+
+The code to generate the prediction for this model looks as follows:
+
+```python
+# Compute predictions.
+predicted_classes = tf.argmax(logits, 1)
+if mode == tf.estimator.ModeKeys.PREDICT:
+ predictions = {
+ 'class_ids': predicted_classes[:, tf.newaxis],
+ 'probabilities': tf.nn.softmax(logits),
+ 'logits': logits,
+ }
+ return tf.estimator.EstimatorSpec(mode, predictions=predictions)
+```
+The prediction dictionary contains everything that your model returns when run
+in prediction mode.
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="Additional outputs added to the output layer."
+ src="../images/custom_estimators/full_network.png">
+</div>
+
+The `predictions` dictionary holds the following three key/value pairs:
+
+* `class_ids` holds the class id (0, 1, or 2) representing the model's
+ prediction of the most likely species for this example.
+* `probabilities` holds the three probabilities (in this example, 0.02, 0.95,
+  and 0.03).
+* `logits` holds the raw logit values (in this example, -1.3, 2.6, and -0.9).
+
+We return that dictionary to the caller via the `predictions` parameter of the
+@{tf.estimator.EstimatorSpec}. The Estimator's
+@{tf.estimator.Estimator.predict$`predict`} method will yield these
+dictionaries.
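+
+For example, a caller might consume these dictionaries as follows (a sketch;
+`predict_input_fn` is a hypothetical input function over unlabeled examples):
+
+```python
+for pred_dict in classifier.predict(input_fn=predict_input_fn):
+  class_id = pred_dict['class_ids'][0]
+  probability = pred_dict['probabilities'][class_id]
+```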
+
+### Calculate the loss
+
+For both [training](#train) and [evaluation](#evaluate) we need to calculate the
+model's loss. This is the
+[objective](https://developers.google.com/machine-learning/glossary/#objective)
+that will be optimized.
+
+Before we calculate loss, we must first convert the labels from a list of
+indices `(0, 1, 2)` to a
+[one-hot representation](https://developers.google.com/machine-learning/glossary/#one-hot_encoding)
+by calling @{tf.one_hot}. Then, we can calculate the loss by calling
+@{tf.losses.softmax_cross_entropy}. Here's the complete code:
+
+
+```python
+ # Convert the labels to a one-hot tensor of shape (length of features, 3)
+  # and with an on-value of 1 for each one-hot vector of length 3.
+ onehot_labels = tf.one_hot(labels, 3, 1, 0)
+
+ # Compute loss.
+ loss = tf.losses.softmax_cross_entropy(
+ onehot_labels=onehot_labels, logits=logits)
+```
+
+### Evaluate
+
+When the Estimator's `evaluate` method is called, the `model_fn` receives
+`mode = ModeKeys.EVAL`. In this case, the model function must return a
+`tf.estimator.EstimatorSpec` containing the model's loss and optionally one
+or more metrics.
+
+Although returning metrics is optional, most custom Estimators do return at
+least one metric. TensorFlow provides a Metrics module @{tf.metrics} to
+calculate common metrics. For brevity's sake, we'll only return accuracy. The
+@{tf.metrics.accuracy} function compares our predictions against the
+true values, that is, against the labels provided by the input function. The
+@{tf.metrics.accuracy} function requires the labels and predictions to have the
+same shape. Here's the call to @{tf.metrics.accuracy}:
+
+``` python
+ # Compute evaluation metrics.
+ accuracy = tf.metrics.accuracy(labels=labels,
+ predictions=predicted_classes,
+ name='acc_op')
+```
+
+The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for evaluation
+typically contains the following information:
+
+* `loss`, which is the model's loss
+* `eval_metric_ops`, which is an optional dictionary of metrics.
+
+So, we'll create a dictionary containing our sole metric. If we had calculated
+other metrics, we would have added them as additional key/value pairs to that
+same dictionary. Then, we'll pass that dictionary in the `eval_metric_ops`
+argument of `tf.estimator.EstimatorSpec`. Here's the code:
+
+```python
+ metrics = {'accuracy': accuracy}
+ tf.summary.scalar('accuracy', accuracy[1])
+
+ if mode == tf.estimator.ModeKeys.EVAL:
+ return tf.estimator.EstimatorSpec(
+ mode, loss=loss, eval_metric_ops=metrics)
+```
+
+The @{tf.summary.scalar} call will make accuracy available to TensorBoard
+(more on this later).
+
+### Train
+
+When the Estimator's `train` method is called, the `model_fn` is called
+with `mode = ModeKeys.TRAIN`. In this case, the model function must return an
+`EstimatorSpec` that contains the loss and a training operation.
+
+Building the training operation will require an optimizer. We will use
+@{tf.train.AdagradOptimizer} because we're mimicking the `DNNClassifier`, which
+also uses `Adagrad` by default. The `tf.train` package provides many other
+optimizers—feel free to experiment with them.
+
+Here is the code that builds the optimizer:
+
+``` python
+ # Instantiate an optimizer.
+ optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
+```
+
+Next, we train the model using the optimizer's
+@{tf.train.Optimizer.minimize$`minimize`} method on the loss we calculated
+earlier.
+
+The `minimize` method also takes a `global_step` parameter. TensorFlow uses this
+parameter to count the number of training steps that have been processed
+(to know when to end a training run). Furthermore, the `global_step` is
+essential for TensorBoard graphs to work correctly. Simply call
+@{tf.train.get_global_step} and pass the result to the `global_step`
+argument of `minimize`.
+
+Here's the code to train the model:
+
+``` python
+ # Train the model by establishing an objective, which is to
+ # minimize loss using that optimizer.
+ train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
+```
+
+The @{tf.estimator.EstimatorSpec$`EstimatorSpec`} returned for training
+must have the following fields set:
+
+* `loss`, which contains the value of the loss function.
+* `train_op`, which executes a training step.
+
+Here's our code to call `EstimatorSpec`:
+
+```python
+ # Return training information.
+ return tf.estimator.EstimatorSpec(
+ mode=tf.estimator.ModeKeys.TRAIN,
+ loss=loss,
+ train_op=train_op)
+```
+
+The model function is now complete.
+
+## The custom Estimator
+
+Instantiate the custom Estimator through the Estimator base class as follows:
+
+```python
+  # Build a DNN with 2 hidden layers of 10 units each.
+ classifier = tf.estimator.Estimator(
+ model_fn=my_model,
+ params={
+ 'feature_columns': my_feature_columns,
+ # Two hidden layers of 10 nodes each.
+ 'hidden_units': [10, 10],
+ # The model must choose between 3 classes.
+ 'n_classes': 3,
+ })
+```
+Here the `params` dictionary serves the same purpose as the keyword
+arguments of `DNNClassifier`; that is, the `params` dictionary lets you
+configure your Estimator without modifying the code in the `model_fn`.
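+
+For comparison, here is a sketch of the equivalent pre-made Estimator
+configuration, expressed through `DNNClassifier`'s keyword arguments:
+
+```python
+  classifier = tf.estimator.DNNClassifier(
+      feature_columns=my_feature_columns,
+      hidden_units=[10, 10],
+      n_classes=3)
+```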
+
+The rest of the code to train, evaluate, and generate predictions using our
+Estimator is the same as for the pre-made `DNNClassifier`. For example, the
+following line will train the model:
+
+```python
+ # Train the Model.
+ classifier.train(
+    input_fn=lambda: train_input_fn(train_x, train_y, args.batch_size),
+ steps=args.train_steps)
+```
+
+## TensorBoard
+
+You can view training results for your custom Estimator in TensorBoard. To see
+this reporting, start TensorBoard from your command line as follows:
+
+```bash
+# Replace PATH with the actual path passed as model_dir
+tensorboard --logdir=PATH
+```
+
+Then, open TensorBoard by browsing to: [http://localhost:6006](http://localhost:6006)
+
+All the pre-made Estimators automatically log a lot of information to
+TensorBoard. With custom Estimators, however, TensorBoard only provides one
+default log (a graph of the loss) plus the information you explicitly tell
+TensorBoard to log. For the custom Estimator you just created, TensorBoard
+generates the following:
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="height:260px"
+ alt="Accuracy, steps/second, and loss 'scalar' graphs from tensorboard"
+ src="../images/custom_estimators/tensorboard.png">
+</div>
+<div style="text-align: center">
+TensorBoard displays three graphs.
+</div>
+
+In brief, here's what the three graphs tell you:
+
+* `global_step/sec`: A performance indicator showing how many batches (gradient
+  updates) we processed per second as the model trains.
+
+* `loss`: The loss reported.
+
+* `accuracy`: The accuracy is recorded by the following two lines:
+
+    * `eval_metric_ops={'accuracy': accuracy}`, during evaluation.
+ * `tf.summary.scalar('accuracy', accuracy[1])`, during training.
+
+These TensorBoard graphs are one of the main reasons it's important to pass a
+`global_step` to your optimizer's `minimize` method. The model can't record
+the x-coordinate for these graphs without it.
+
+Note the following in the `accuracy` and `loss` graphs:
+
+* The orange line represents training.
+* The blue dot represents evaluation.
+
+During training, summaries (the orange line) are recorded periodically as
+batches are processed, which is why the line spans the full x-axis range.
+
+By contrast, evaluation produces only a single point on the graph for each call
+to `evaluate`. This point contains the average over the entire evaluation call.
+It has no width on the graph because the evaluation is computed entirely from
+the model state at a particular training step (from a single checkpoint).
+
+As suggested in the following figure, you can selectively enable or disable
+reporting using the controls on the left side.
+
+<div style="width:100%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="margin:auto;display:block;"
+ alt="Check-boxes allowing the user to select which runs are shown."
+ src="../images/custom_estimators/select_run.jpg">
+</div>
+<div style="text-align: center">
+Enable or disable reporting.
+</div>
+
+
+## Summary
+
+Although pre-made Estimators can be an effective way to quickly create new
+models, you will often need the additional flexibility that custom Estimators
+provide. Fortunately, pre-made and custom Estimators follow the same
+programming model. The only practical difference is that you must write a model
+function for custom Estimators; everything else is the same.
+
+For more details, be sure to check out:
+
+* The
+[official TensorFlow implementation of MNIST](https://github.com/tensorflow/models/tree/master/official/mnist),
+which uses a custom Estimator.
+
+* The TensorFlow
+[official models repository](https://github.com/tensorflow/models/tree/master/official),
+which contains more curated examples using custom Estimators.
+
+* This [TensorBoard video](https://youtu.be/eBbEDRsCmv4), which introduces
+TensorBoard.
+
+
diff --git a/tensorflow/docs_src/get_started/feature_columns.md b/tensorflow/docs_src/get_started/feature_columns.md
new file mode 100644
index 0000000000..f9537927b7
--- /dev/null
+++ b/tensorflow/docs_src/get_started/feature_columns.md
@@ -0,0 +1,570 @@
+# Feature Columns
+
+This document details feature columns. Think of **feature columns** as the
+intermediaries between raw data and Estimators. Feature columns are very rich,
+enabling you to transform a diverse range of raw data into formats that
+Estimators can use, allowing easy experimentation.
+
+In @{$get_started/estimator$Premade Estimators}, we used the premade Estimator,
+@{tf.estimator.DNNClassifier$`DNNClassifier`} to train a model to predict
+different types of Iris flowers from four input features. That example created
+only numerical feature columns (of type @{tf.feature_column.numeric_column}).
+Although numerical feature columns model the lengths of petals and sepals
+effectively, real world data sets contain all kinds of features, many of which
+are non-numerical.
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/feature_cloud.jpg">
+</div>
+<div style="text-align: center">
+Some real-world features (such as longitude) are numerical, but many are not.
+</div>
+
+## Input to a Deep Neural Network
+
+What kind of data can a deep neural network operate on? The answer
+is, of course, numbers (for example, `tf.float32`). After all, every neuron in
+a neural network performs multiplication and addition operations on weights and
+input data. Real-life input data, however, often contains non-numerical
+(categorical) data. For example, consider a `product_class` feature that can
+contain the following three non-numerical values:
+
+* `kitchenware`
+* `electronics`
+* `sports`
+
+ML models generally represent categorical values as simple vectors in which a
+1 represents the presence of a value and a 0 represents the absence of a value.
+For example, when `product_class` is set to `sports`, an ML model would usually
+represent `product_class` as `[0, 0, 1]`, meaning:
+
+* `0`: `kitchenware` is absent
+* `0`: `electronics` is absent
+* `1`: `sports` is present
+
+So, although raw data can be numerical or categorical, an ML model represents
+all features as numbers.
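+
+To make that mapping concrete, here is a minimal plain-Python sketch of the
+one-hot encoding (the `one_hot` helper is illustrative, not a TensorFlow API):
+
+```python
+product_classes = ["kitchenware", "electronics", "sports"]
+
+def one_hot(value):
+  # Place a 1 at the index of the matching class and a 0 everywhere else.
+  return [1 if c == value else 0 for c in product_classes]
+
+one_hot("sports")  # ==> [0, 0, 1]
+```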
+
+## Feature Columns
+
+As the following figure suggests, you specify the input to a model through the
+`feature_columns` argument of an Estimator (`DNNClassifier` for Iris).
+Feature Columns bridge input data (as returned by `input_fn`) with your model.
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/inputs_to_model_bridge.jpg">
+</div>
+<div style="text-align: center">
+Feature columns bridge raw data with the data your model needs.
+</div>
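+
+For example, the Iris case wires up roughly as follows (a minimal sketch; the
+`hidden_units` and `n_classes` values are placeholders):
+
+```python
+import tensorflow as tf
+
+feature_columns = [
+    tf.feature_column.numeric_column(key)
+    for key in ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth"]]
+
+classifier = tf.estimator.DNNClassifier(
+    feature_columns=feature_columns,
+    hidden_units=[10, 10],
+    n_classes=3)
+```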
+
+To create feature columns, call functions from the
+@{tf.feature_column} module. This document explains nine of the functions in
+that module. As the following figure shows, all nine functions return either a
+Categorical-Column or a Dense-Column object, except `bucketized_column`, which
+inherits from both classes:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/some_constructors.jpg">
+</div>
+<div style="text-align: center">
+Feature column methods fall into two main categories and one hybrid category.
+</div>
+
+Let's look at these functions in more detail.
+
+### Numeric column
+
+The Iris classifier calls the @{tf.feature_column.numeric_column} function for
+all input features:
+
+ * `SepalLength`
+ * `SepalWidth`
+ * `PetalLength`
+ * `PetalWidth`
+
+Although `numeric_column` provides optional arguments, calling
+`numeric_column` without any arguments, as follows, is a fine way to specify
+a numerical value with the default data type (`tf.float32`) as input to your
+model:
+
+```python
+# Defaults to a tf.float32 scalar.
+numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength")
+```
+
+To specify a non-default numerical data type, use the `dtype` argument. For
+example:
+
+``` python
+# Represent a tf.float64 scalar.
+numeric_feature_column = tf.feature_column.numeric_column(key="SepalLength",
+ dtype=tf.float64)
+```
+
+By default, a numeric column creates a single value (scalar). Use the `shape`
+argument to specify another shape. For example:
+
+<!--TODO(markdaoust) link to full example-->
+```python
+# Represent a 10-element vector in which each cell contains a tf.float32.
+vector_feature_column = tf.feature_column.numeric_column(key="Bowling",
+ shape=10)
+
+# Represent a 10x5 matrix in which each cell contains a tf.float32.
+matrix_feature_column = tf.feature_column.numeric_column(key="MyMatrix",
+ shape=[10,5])
+```
+
+### Bucketized column
+
+Often, you don't want to feed a number directly into the model, but instead
+split its value into different categories based on numerical ranges. To do so,
+create a @{tf.feature_column.bucketized_column$bucketized column}. For
+example, consider raw data that represents the year a house was built. Instead
+of representing that year as a scalar numeric column, we could split the year
+into the following four buckets:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/bucketized_column.jpg">
+</div>
+<div style="text-align: center">
+Dividing year data into four buckets.
+</div>
+
+The model will represent the buckets as follows:
+
+|Date Range |Represented as... |
+|:----------|:-----------------|
+|< 1960 | [1, 0, 0, 0] |
+|>= 1960 but < 1980 | [0, 1, 0, 0] |
+|>= 1980 but < 2000 | [0, 0, 1, 0] |
+|>= 2000 | [0, 0, 0, 1] |
+
+Why would you want to split a number—a perfectly valid input to your
+model—into a categorical value? Well, notice that the categorization splits a
+single input number into a four-element vector. Therefore, the model now can
+learn _four individual weights_ rather than just one; four weights create a
+richer model than one weight. More importantly, bucketizing enables the model
+to clearly distinguish between different year categories since only one of the
+elements is set (1) and the other three elements are cleared (0). When we just
+use a single number (a year) as input, the model can only learn a linear
+relationship. So, bucketing provides the model with additional flexibility that
+the model can use to learn.
+
+The following code demonstrates how to create a bucketized feature:
+
+<!--TODO(markdaoust) link to full example - housing price grid?-->
+```python
+# First, convert the raw input to a numeric column.
+numeric_feature_column = tf.feature_column.numeric_column("Year")
+
+# Then, bucketize the numeric column on the years 1960, 1980, and 2000.
+bucketized_feature_column = tf.feature_column.bucketized_column(
+ source_column = numeric_feature_column,
+ boundaries = [1960, 1980, 2000])
+```
+Note that specifying a _three_-element boundaries vector creates a
+_four_-element bucketized vector.
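+
+To see why a year maps to exactly one of the four buckets, here is a minimal
+plain-Python sketch of the assignment (the `bucketize` helper is illustrative,
+not part of TensorFlow):
+
+```python
+boundaries = [1960, 1980, 2000]
+
+def bucketize(year):
+  index = sum(year >= b for b in boundaries)  # which bucket the year falls in
+  one_hot = [0] * (len(boundaries) + 1)       # four buckets for three boundaries
+  one_hot[index] = 1
+  return one_hot
+
+bucketize(1975)  # ==> [0, 1, 0, 0]
+```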
+
+
+### Categorical identity column
+
+**Categorical identity columns** can be seen as a special case of bucketized
+columns. In traditional bucketized columns, each bucket represents a range of
+values (for example, from 1960 to 1979). In a categorical identity column, each
+bucket represents a single, unique integer. For example, let's say you want to
+represent the integer range `[0, 4)`. That is, you want to represent the
+integers 0, 1, 2, or 3. In this case, the categorical identity mapping looks
+like this:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg">
+</div>
+<div style="text-align: center">
+A categorical identity column mapping. Note that this is a one-hot
+encoding, not a binary numerical encoding.
+</div>
+
+As with bucketized columns, a model can learn a separate weight for each class
+in a categorical identity column. For example, instead of using a string to
+represent the `product_class`, let's represent each class with a unique integer
+value. That is:
+
+* `0="kitchenware"`
+* `1="electronics"`
+* `2="sports"`
+
+Call @{tf.feature_column.categorical_column_with_identity} to implement a
+categorical identity column. For example:
+
+``` python
+# Create categorical output for an integer feature named "my_feature_b".
+# The values of my_feature_b must be >= 0 and < num_buckets.
+identity_feature_column = tf.feature_column.categorical_column_with_identity(
+ key='my_feature_b',
+ num_buckets=4) # Values [0, 4)
+
+# In order for the preceding call to work, the input_fn() must return
+# a dictionary containing 'my_feature_b' as a key. Furthermore, the values
+# assigned to 'my_feature_b' must belong to the set [0, 4).
+def input_fn():
+ ...
+ return ({ 'my_feature_a':[7, 9, 5, 2], 'my_feature_b':[3, 1, 2, 2] },
+ [Label_values])
+```
+
+### Categorical vocabulary column
+
+We cannot input strings directly to a model. Instead, we must first map strings
+to numeric or categorical values. Categorical vocabulary columns provide a good
+way to represent strings as a one-hot vector. For example:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/categorical_column_with_vocabulary.jpg">
+</div>
+<div style="text-align: center">
+Mapping string values to vocabulary columns.
+</div>
+
+As you can see, categorical vocabulary columns are kind of an enum version of
+categorical identity columns. TensorFlow provides two different functions to
+create categorical vocabulary columns:
+
+* @{tf.feature_column.categorical_column_with_vocabulary_list}
+* @{tf.feature_column.categorical_column_with_vocabulary_file}
+
+`categorical_column_with_vocabulary_list` maps each string to an integer based
+on an explicit vocabulary list. For example:
+
+```python
+# Given input "feature_name_from_input_fn" which is a string,
+# create a categorical feature by mapping the input to one of
+# the elements in the vocabulary list.
+vocabulary_feature_column = (
+    tf.feature_column.categorical_column_with_vocabulary_list(
+        key="a feature returned by input_fn()",
+        vocabulary_list=["kitchenware", "electronics", "sports"]))
+```
+
+The preceding function is pretty straightforward, but it has a significant
+drawback. Namely, there's way too much typing when the vocabulary list is long.
+For these cases, call
+`tf.feature_column.categorical_column_with_vocabulary_file` instead, which lets
+you place the vocabulary words in a separate file. For example:
+
+```python
+# Given input "feature_name_from_input_fn" which is a string,
+# create a categorical feature for our model by mapping the input to one of
+# the elements in the vocabulary file.
+vocabulary_feature_column = (
+    tf.feature_column.categorical_column_with_vocabulary_file(
+        key="a feature returned by input_fn()",
+        vocabulary_file="product_class.txt",
+        vocabulary_size=3))
+```
+
+`product_class.txt` should contain one line for each vocabulary element. In our
+case:
+
+```None
+kitchenware
+electronics
+sports
+```
+
+### Hashed Column
+
+So far, we've worked with a naively small number of categories. For example,
+our product_class example has only 3 categories. Often though, the number of
+categories can be so big that it's not possible to have individual categories
+for each vocabulary word or integer because that would consume too much memory.
+For these cases, we can instead turn the question around and ask, "How many
+categories am I willing to have for my input?" In fact, the
+@{tf.feature_column.categorical_column_with_hash_bucket} function enables you
+to specify the number of categories. For this type of feature column the model
+calculates a hash value of the input, then puts it into one of
+the `hash_bucket_size` categories using the modulo operator, as in the following
+pseudocode:
+
+```python
+# pseudocode
+feature_id = hash(raw_feature) % hash_bucket_size
+```
+
+The code to create the `feature_column` might look something like this:
+
+``` python
+hashed_feature_column = (
+    tf.feature_column.categorical_column_with_hash_bucket(
+        key="some_feature",
+        hash_bucket_size=100))  # The number of categories
+```
+
+At this point, you might rightfully think: "This is crazy!" After all, we are
+forcing the different input values to a smaller set of categories. This means
+that two probably unrelated inputs will be mapped to the same
+category, and consequently mean the same thing to the neural network. The
+following figure illustrates this dilemma, showing that kitchenware and sports
+both get assigned to category (hash bucket) 12:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/hashed_column.jpg">
+</div>
+<div style="text-align: center">
+Representing data with hash buckets.
+</div>
+
+As with many counterintuitive phenomena in machine learning, it turns out that
+hashing often works well in practice. That's because hash categories provide
+the model with some separation. The model can use additional features to further
+separate kitchenware from sports.
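+
+A quick sketch of the collision effect, using Python's built-in `hash()` in
+place of TensorFlow's stable hash function (so the actual bucket ids will
+differ, and Python randomizes string hashing between runs):
+
+```python
+hash_bucket_size = 100
+
+def feature_id(raw_feature):
+  return hash(raw_feature) % hash_bucket_size
+
+# Unrelated inputs can land in the same bucket; the model relies on
+# other features to tell them apart.
+feature_id("kitchenware")  # e.g. 12
+feature_id("sports")       # possibly also 12 -- a collision
+```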
+
+### Crossed column
+
+Combining features into a single feature, better known as
+[feature crosses](https://developers.google.com/machine-learning/glossary/#feature_cross),
+enables the model to learn separate weights for each combination of
+features.
+
+More concretely, suppose we want our model to calculate real estate prices in
+Atlanta, GA. Real-estate prices within this city vary greatly depending on
+location. Representing latitude and longitude as separate features isn't very
+useful in identifying real-estate location dependencies; however, crossing
+latitude and longitude into a single feature can pinpoint locations. Suppose we
+represent Atlanta as a grid of 100x100 rectangular sections, identifying each
+of the 10,000 sections by a feature cross of latitude and longitude. This
+feature cross enables the model to train on pricing conditions related to each
+individual section, which is a much stronger signal than latitude and longitude
+alone.
+
+The following figure shows our plan, with the latitude & longitude values for
+the corners of the city in red text:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/Atlanta.jpg">
+</div>
+<div style="text-align: center">
+Map of Atlanta. Imagine this map divided into 10,000 sections of
+equal size.
+</div>
+
+For the solution, we used a combination of the `bucketized_column` we looked at
+earlier and the @{tf.feature_column.crossed_column} function.
+
+<!--TODO(markdaoust) link to full example-->
+
+``` python
+def make_dataset(latitude, longitude, labels):
+ assert latitude.shape == longitude.shape == labels.shape
+
+ features = {'latitude': latitude.flatten(),
+ 'longitude': longitude.flatten()}
+  labels = labels.flatten()
+
+ return tf.data.Dataset.from_tensor_slices((features, labels))
+
+
+# Bucketize the latitude and longitude using the `edges`
+latitude_bucket_fc = tf.feature_column.bucketized_column(
+ tf.feature_column.numeric_column('latitude'),
+ list(atlanta.latitude.edges))
+
+longitude_bucket_fc = tf.feature_column.bucketized_column(
+ tf.feature_column.numeric_column('longitude'),
+ list(atlanta.longitude.edges))
+
+# Cross the bucketized columns, using 5000 hash bins.
+crossed_lat_lon_fc = tf.feature_column.crossed_column(
+ [latitude_bucket_fc, longitude_bucket_fc], 5000)
+
+fc = [
+ latitude_bucket_fc,
+ longitude_bucket_fc,
+ crossed_lat_lon_fc]
+
+# Build and train the Estimator.
+est = tf.estimator.LinearRegressor(fc, ...)
+```
+
+You may create a feature cross from either of the following:
+
+* Feature names; that is, names from the `dict` returned from `input_fn`.
+* Any categorical column, except `categorical_column_with_hash_bucket`
+ (since `crossed_column` hashes the input).
+
+When the feature columns `latitude_bucket_fc` and `longitude_bucket_fc` are
+crossed, TensorFlow will create `(latitude_fc, longitude_fc)` pairs for each
+example. This would produce a full grid of possibilities as follows:
+
+``` None
+ (0,0), (0,1)... (0,99)
+ (1,0), (1,1)... (1,99)
+ ... ... ...
+(99,0), (99,1)...(99, 99)
+```
+
+A full grid like this would only be tractable for inputs with limited
+vocabularies. Instead of building this potentially huge table of inputs,
+`crossed_column` only builds the number of categories requested by the
+`hash_bucket_size` argument. The feature column assigns an example to an index
+by running a hash function on the tuple of inputs, followed by a modulo
+operation with `hash_bucket_size`.
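+
+In pseudocode (mirroring the hashed-column sketch earlier; the bucket-id names
+are illustrative):
+
+```python
+# pseudocode
+crossed_id = hash((latitude_bucket_id, longitude_bucket_id)) % hash_bucket_size
+```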
+
+As discussed earlier, performing the
+hash and modulo operation limits the number of categories, but can cause category
+collisions; that is, multiple (latitude, longitude) feature crosses will end
+up in the same hash bucket. In practice though, performing feature crosses
+still adds significant value to the learning capability of your models.
+
+Somewhat counterintuitively, when creating feature crosses, you typically still
+should include the original (uncrossed) features in your model (as in the
+preceding code snippet). The independent latitude and longitude features help the
+model distinguish between examples where a hash collision has occurred in the
+crossed feature.
+
+## Indicator and embedding columns
+
+Indicator columns and embedding columns never work on features directly, but
+instead take categorical columns as input.
+
+When using an indicator column, we're telling TensorFlow to do exactly what
+we've seen in our categorical product_class example. That is, an
+**indicator column** treats each category as an element in a one-hot vector,
+where the matching category has value 1 and the rest have 0s:
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/categorical_column_with_identity.jpg">
+</div>
+<div style="text-align: center">
+Representing data in indicator columns.
+</div>
+
+Here's how you create an indicator column by calling
+@{tf.feature_column.indicator_column}:
+
+``` python
+categorical_column = ... # Create any type of categorical column.
+
+# Represent the categorical column as an indicator column.
+indicator_column = tf.feature_column.indicator_column(categorical_column)
+```
+
+Now, suppose instead of having just three possible classes, we have a million.
+Or maybe a billion. For a number of reasons, as the number of categories grows
+large, it becomes infeasible to train a neural network using indicator columns.
+
+We can use an embedding column to overcome this limitation. Instead of
+representing the data as a one-hot vector of many dimensions, an
+**embedding column** represents that data as a lower-dimensional, ordinary
+vector in which each cell can contain any number, not just 0 or 1. By
+permitting a richer palette of numbers for every cell, an embedding column
+contains far fewer cells than an indicator column.
+
+Let's look at an example comparing indicator and embedding columns. Suppose our
+input examples consist of different words from a limited palette of only 81
+words. Further suppose that the data set provides the following input
+words in 4 separate examples:
+
+* `"dog"`
+* `"spoon"`
+* `"scissors"`
+* `"guitar"`
+
+In that case, the following figure illustrates the processing path for
+embedding columns or indicator columns.
+
+<div style="width:80%; margin:auto; margin-bottom:10px; margin-top:20px;">
+<img style="width:100%" src="../images/feature_columns/embedding_vs_indicator.jpg">
+</div>
+<div style="text-align: center">
+An embedding column stores categorical data in a lower-dimensional
+vector than an indicator column. (We just placed random numbers into the
+embedding vectors; training determines the actual numbers.)
+</div>
+
+When an example is processed, one of the `categorical_column_with...` functions
+maps the example string to a numerical categorical value. For example, a
+function maps "spoon" to `[32]`. (The 32 comes from our imagination—the actual
+values depend on the mapping function.) You may then represent these numerical
+categorical values in either of the following two ways:
+
+* As an indicator column. A function converts each numeric categorical value
+ into an 81-element vector (because our palette consists of 81 words), placing
+ a 1 in the index of the categorical value (0, 32, 79, 80) and a 0 in all the
+ other positions.
+
+* As an embedding column. A function uses the numerical categorical values
+  `(0, 32, 79, 80)` as indices to a lookup table. Each slot in that lookup
+  table contains a 3-element vector, as sketched below.
+
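+Here is that lookup as a minimal numeric sketch (the table values are
+arbitrary stand-ins; training learns the real ones):
+
+```python
+# Hypothetical 81 x 3 lookup table; every value here is made up.
+embedding_table = [[0.1 * i, 0.2 * i, 0.3 * i] for i in range(81)]
+
+categorical_ids = [0, 32, 79, 80]
+embeddings = [embedding_table[i] for i in categorical_ids]  # four 3-element vectors
+```
+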
+How do the values in the embedding vectors magically get assigned? Actually,
+the assignments happen during training. That is, the model learns the best way
+to map your input numeric categorical values to the embedding vector values in
+order to solve your problem. Embedding columns increase your model's
+capabilities, since an embedding vector learns new relationships between
+categories from the training data.
+
+Why is the embedding vector size 3 in our example? Well, the following "formula"
+provides a general rule of thumb about the number of embedding dimensions:
+
+```python
+embedding_dimensions = number_of_categories**0.25
+```
+
+That is, the embedding vector dimension should be the 4th root of the number of
+categories. Since our vocabulary size in this example is 81, the recommended
+number of dimensions is 3:
+
+``` python
+3 = 81**0.25
+```
+
+Note that this is just a general guideline; you can set the number of embedding
+dimensions as you please.
+
+Call @{tf.feature_column.embedding_column} to create an `embedding_column` as
+suggested by the following snippet:
+
+``` python
+categorical_column = ... # Create any categorical column
+
+# Represent the categorical column as an embedding column.
+# This creates a lookup table containing one embedding vector per category.
+embedding_column = tf.feature_column.embedding_column(
+ categorical_column=categorical_column,
+ dimension=dimension_of_embedding_vector)
+```
+
+@{$programmers_guide/embedding$Embeddings} is a significant topic within machine
+learning. This section was just to get you started using them as feature
+columns.
+
+## Passing feature columns to Estimators
+
+As the following list indicates, not all Estimators accept all types of
+feature columns (a sketch for the `DNNClassifier` case follows the list):
+
+* @{tf.estimator.LinearClassifier$`LinearClassifier`} and
+ @{tf.estimator.LinearRegressor$`LinearRegressor`}: Accept all types of
+ feature column.
+* @{tf.estimator.DNNClassifier$`DNNClassifier`} and
+ @{tf.estimator.DNNRegressor$`DNNRegressor`}: Only accept dense columns. Other
+ column types must be wrapped in either an `indicator_column` or
+ `embedding_column`.
+* @{tf.estimator.DNNLinearCombinedClassifier$`DNNLinearCombinedClassifier`} and
+ @{tf.estimator.DNNLinearCombinedRegressor$`DNNLinearCombinedRegressor`}:
+ * The `linear_feature_columns` argument accepts any feature column type.
+ * The `dnn_feature_columns` argument only accepts dense columns.
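+
+For example, the following minimal sketch wraps a categorical column before
+handing it to a `DNNClassifier` (the key, vocabulary, and `hidden_units`
+values are placeholders):
+
+```python
+import tensorflow as tf
+
+categorical_column = tf.feature_column.categorical_column_with_vocabulary_list(
+    key="product_class",
+    vocabulary_list=["kitchenware", "electronics", "sports"])
+
+# DNNClassifier only accepts dense columns, so wrap the categorical column
+# in an indicator_column (or an embedding_column).
+classifier = tf.estimator.DNNClassifier(
+    feature_columns=[tf.feature_column.indicator_column(categorical_column)],
+    hidden_units=[10, 10],
+    n_classes=3)
+```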
+
+## Other Sources
+
+For more examples on feature columns, view the following:
+
+* The @{$wide_and_deep$Wide & Deep Tutorial}
+* [Examples](https://github.com/tensorflow/models/tree/master/samples/cookbook/regression)
+ of DNNs and linear models that use feature columns.
+
+To learn more about embeddings, see the following:
+
+* [Deep Learning, NLP, and representations](http://colah.github.io/posts/2014-07-NLP-RNNs-Representations/)
+ (Chris Olah's blog)
+* The TensorFlow [Embedding Projector](http://projector.tensorflow.org)
diff --git a/tensorflow/docs_src/performance/xla/operation_semantics.md b/tensorflow/docs_src/performance/xla/operation_semantics.md
index 217f542caa..a49973d550 100644
--- a/tensorflow/docs_src/performance/xla/operation_semantics.md
+++ b/tensorflow/docs_src/performance/xla/operation_semantics.md
@@ -511,6 +511,87 @@ contracted dimensions of `lhs` and `rhs` must be of the same size. In practice,
it can be used to perform dot products between vectors, vector/matrix
multiplications or matrix/matrix multiplications.
+## DotGeneral
+
+See also
+[`ComputationBuilder::DotGeneral`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/computation_builder.h).
+
+<b> `DotGeneral(lhs, rhs, dimension_numbers)` </b>
+
+| Arguments | Type | Semantics
+| --------- | ----------------------- | ---------------
+| `lhs` | `ComputationDataHandle` | array of type T
+| `rhs` | `ComputationDataHandle` | array of type T
+| `dimension_numbers` | `DotDimensionNumbers` | contracting and batch dimension numbers
+
+Like Dot, but allows contracting and batch dimension numbers to be specified
+for both the 'lhs' and 'rhs'.
+
+| DotDimensionNumbers Fields | Type | Semantics
+| --------- | ----------------------- | ---------------
+| 'lhs_contracting_dimensions' | repeated int64 | 'lhs' contracting dimension numbers |
+| 'rhs_contracting_dimensions' | repeated int64 | 'rhs' contracting dimension numbers |
+| 'lhs_batch_dimensions' | repeated int64 | 'lhs' batch dimension numbers |
+| 'rhs_batch_dimensions' | repeated int64 | 'rhs' batch dimension numbers |
+
+DotGeneral performs the sum of products over contracting dimensions specified
+in 'dimension_numbers'.
+
+Associated contracting dimension numbers from the 'lhs' and 'rhs' do not need
+to be the same, but must be listed in the same order in both
+'lhs/rhs_contracting_dimensions' arrays and have the same dimension sizes.
+
+Example with contracting dimension numbers:
+
+```
+lhs = { {1.0, 2.0, 3.0},
+ {4.0, 5.0, 6.0} }
+
+rhs = { {1.0, 1.0, 1.0},
+ {2.0, 2.0, 2.0} }
+
+DotDimensionNumbers dnums;
+dnums.add_lhs_contracting_dimensions(1);
+dnums.add_rhs_contracting_dimensions(1);
+
+DotGeneral(lhs, rhs, dnums) -> { {6.0, 12.0},
+ {15.0, 30.0} }
+```
+
+Associated batch dimension numbers from the 'lhs' and 'rhs' must have the same
+dimension number, must be listed in the same order in both arrays, and must
+have the same dimension sizes.
+
+Example with batch dimension numbers (batch size 2, 2x2 matrices):
+
+```
+lhs = { { {1.0, 2.0},
+ {3.0, 4.0} },
+ { {5.0, 6.0},
+ {7.0, 8.0} } }
+
+rhs = { { {1.0, 0.0},
+ {0.0, 1.0} },
+ { {1.0, 0.0},
+ {0.0, 1.0} } }
+
+DotDimensionNumbers dnums;
+dnums.add_lhs_contracting_dimensions(2);
+dnums.add_rhs_contracting_dimensions(1);
+dnums.add_lhs_batch_dimensions(0);
+dnums.add_rhs_batch_dimensions(0);
+
+DotGeneral(lhs, rhs, dnums) -> { { {1.0, 2.0},
+ {3.0, 4.0} },
+ { {5.0, 6.0},
+ {7.0, 8.0} } }
+```
+
+| Input | Output | Semantics |
+| ----------------------------------- | ----------------- | ---------------- |
+| [b0, m, k] `dot` [b0, k, n] | [b0, m, n] | batch matmul |
+| [b0, b1, m, k] `dot` [b0, b1, k, n] | [b0, b1, m, n] | batch matmul |
+
## Element-wise binary arithmetic operations
See also
diff --git a/tensorflow/docs_src/programmers_guide/datasets.md b/tensorflow/docs_src/programmers_guide/datasets.md
index 9ced56f0f5..c54b399c3a 100644
--- a/tensorflow/docs_src/programmers_guide/datasets.md
+++ b/tensorflow/docs_src/programmers_guide/datasets.md
@@ -1,16 +1,16 @@
# Importing Data
-The @{tf.data.Dataset$`Dataset`} API enables you to build complex input pipelines from
+The `tf.data` API enables you to build complex input pipelines from
simple, reusable pieces. For example, the pipeline for an image model might
aggregate data from files in a distributed file system, apply random
perturbations to each image, and merge randomly selected images into a batch
for training. The pipeline for a text model might involve extracting symbols
from raw text data, converting them to embedding identifiers with a lookup
-table, and batching together sequences of different lengths. The `Dataset` API
+table, and batching together sequences of different lengths. The `tf.data` API
makes it easy to deal with large amounts of data, different data formats, and
complicated transformations.
-The `Dataset` API introduces two new abstractions to TensorFlow:
+The `tf.data` API introduces two new abstractions to TensorFlow:
* A `tf.data.Dataset` represents a sequence of elements, in which
each element contains one or more `Tensor` objects. For example, in an image
@@ -121,7 +121,7 @@ dataset3 = dataset3.filter(lambda x, (y, z): ...)
### Creating an iterator
Once you have built a `Dataset` to represent your input data, the next step is to
-create an `Iterator` to access elements from that dataset. The `Dataset` API
+create an `Iterator` to access elements from that dataset. The `tf.data` API
currently supports the following iterators, in increasing level of
sophistication:
@@ -379,7 +379,7 @@ sess.run(iterator.initializer, feed_dict={features_placeholder: features,
### Consuming TFRecord data
-The `Dataset` API supports a variety of file formats so that you can process
+The `tf.data` API supports a variety of file formats so that you can process
large datasets that do not fit in memory. For example, the TFRecord file format
is a simple record-oriented binary format that many TensorFlow applications use
for training data. The `tf.data.TFRecordDataset` class enables you to
@@ -628,7 +628,7 @@ TODO(mrry): Add this section.
### Processing multiple epochs
-The `Dataset` API offers two main ways to process multiple epochs of the same
+The `tf.data` API offers two main ways to process multiple epochs of the same
data.
The simplest way to iterate over a dataset in multiple epochs is to use the
@@ -693,7 +693,7 @@ dataset = dataset.repeat()
The @{tf.train.MonitoredTrainingSession} API simplifies many aspects of running
TensorFlow in a distributed setting. `MonitoredTrainingSession` uses the
@{tf.errors.OutOfRangeError} to signal that training has completed, so to use it
-with the `Dataset` API, we recommend using
+with the `tf.data` API, we recommend using
`Dataset.make_one_shot_iterator()`. For example:
```python
diff --git a/tensorflow/examples/android/README.md b/tensorflow/examples/android/README.md
index 79202a38d7..881a975e60 100644
--- a/tensorflow/examples/android/README.md
+++ b/tensorflow/examples/android/README.md
@@ -126,6 +126,10 @@ the Android NDK and SDK must be installed on your system.
2. The Android NDK is required to build the native (C/C++) TensorFlow code. The
current recommended version is 14b, which may be found
[here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
+
+ * NDK 16, the revision released in November 2017, is **incompatible** with
+ Bazel. See [here](https://github.com/tensorflow/tensorflow/issues/14918).
+
3. The Android SDK and build tools may be obtained
[here](https://developer.android.com/tools/revisions/build-tools.html), or
alternatively as part of [Android
@@ -133,6 +137,10 @@ the Android NDK and SDK must be installed on your system.
23 is required to build the TF Android demo (though it will run on API >= 21
devices).
+ - The Android Studio SDK Manager's NDK installer will install the latest
+ revision of the NDK, which is **incompatible** with Bazel. You'll need
+ to download an older version manually, as (2) suggests.
+
##### Edit WORKSPACE
The Android entries in
diff --git a/tensorflow/examples/how_tos/reading_data/convert_to_records.py b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
index a402eac053..c89e839563 100644
--- a/tensorflow/examples/how_tos/reading_data/convert_to_records.py
+++ b/tensorflow/examples/how_tos/reading_data/convert_to_records.py
@@ -55,12 +55,15 @@ def convert_to(data_set, name):
with tf.python_io.TFRecordWriter(filename) as writer:
for index in range(num_examples):
image_raw = images[index].tostring()
- example = tf.train.Example(features=tf.train.Features(feature={
- 'height': _int64_feature(rows),
- 'width': _int64_feature(cols),
- 'depth': _int64_feature(depth),
- 'label': _int64_feature(int(labels[index])),
- 'image_raw': _bytes_feature(image_raw)}))
+ example = tf.train.Example(
+ features=tf.train.Features(
+ feature={
+ 'height': _int64_feature(rows),
+ 'width': _int64_feature(cols),
+ 'depth': _int64_feature(depth),
+ 'label': _int64_feature(int(labels[index])),
+ 'image_raw': _bytes_feature(image_raw)
+ }))
writer.write(example.SerializeToString())
diff --git a/tensorflow/examples/speech_commands/train.py b/tensorflow/examples/speech_commands/train.py
index f46d5e59b4..f5bf04305a 100644
--- a/tensorflow/examples/speech_commands/train.py
+++ b/tensorflow/examples/speech_commands/train.py
@@ -156,7 +156,8 @@ def main(_):
predicted_indices = tf.argmax(logits, 1)
expected_indices = tf.argmax(ground_truth_input, 1)
correct_prediction = tf.equal(predicted_indices, expected_indices)
- confusion_matrix = tf.confusion_matrix(expected_indices, predicted_indices, num_classes=label_count)
+ confusion_matrix = tf.confusion_matrix(
+ expected_indices, predicted_indices, num_classes=label_count)
evaluation_step = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
tf.summary.scalar('accuracy', evaluation_step)
diff --git a/tensorflow/go/graph.go b/tensorflow/go/graph.go
index 46c600eab1..f200a8e00a 100644
--- a/tensorflow/go/graph.go
+++ b/tensorflow/go/graph.go
@@ -20,6 +20,24 @@ package tensorflow
//
// #include <stdlib.h>
// #include <string.h>
+//
+// void TF_SetAttrShapeList_Helper(TF_OperationDescription* desc,
+// const char* attr_name,
+// const int64_t* flat_dims,
+// const int* num_dims,
+// int num_shapes) {
+// const int64_t** dims =
+// (const int64_t**)malloc(sizeof(const int64_t*) * num_shapes);
+// for (int i = 0; i < num_shapes; i++) {
+// dims[i] = flat_dims;
+// if (num_dims[i] > 0) {
+// // flat_dims will be NULL iff num_shapes is 0 or all elements in num_dims are <= 0.
+// flat_dims += num_dims[i];
+// }
+// }
+// TF_SetAttrShapeList(desc, attr_name, dims, num_dims, num_shapes);
+// free(dims);
+// }
import "C"
import (
@@ -289,41 +307,37 @@ func setAttr(cdesc *C.TF_OperationDescription, status *status, name string, valu
return fmt.Errorf("bad value for attribute %q: %v", name, err)
}
case Shape:
- ndims, dims := cshape(value)
+ ndims := C.int(value.NumDimensions())
var dimsp *C.int64_t
if ndims > 0 {
+ dims := make([]C.int64_t, ndims)
+ for i, d := range value.dims {
+ dims[i] = C.int64_t(d)
+ }
dimsp = &dims[0]
}
C.TF_SetAttrShape(cdesc, cAttrName, dimsp, ndims)
case []Shape:
- ndims := make([]C.int, len(value))
- dims := make([][]C.int64_t, len(value))
- dimsp := make([]*C.int64_t, len(value))
- for i, s := range value {
- ndims[i], dims[i] = cshape(s)
- if ndims[i] > 0 {
- dimsp[i] = &dims[i][0]
- }
- }
- if len(value) > 0 {
- C.TF_SetAttrShapeList(cdesc, cAttrName, &dimsp[0], &ndims[0], C.int(len(value)))
- } else {
+ if len(value) == 0 {
C.TF_SetAttrShapeList(cdesc, cAttrName, nil, nil, 0)
+ } else {
+ var flatDims []C.int64_t
+ ndims := make([]C.int, len(value))
+ for i, s := range value {
+ nd := s.NumDimensions()
+ ndims[i] = C.int(nd)
+ for _, d := range s.dims {
+ flatDims = append(flatDims, C.int64_t(d))
+ }
+ }
+ var flatDimsp *C.int64_t
+ if len(flatDims) > 0 {
+ flatDimsp = &flatDims[0]
+ }
+ C.TF_SetAttrShapeList_Helper(cdesc, cAttrName, flatDimsp, &ndims[0], C.int(len(value)))
}
default:
return fmt.Errorf("attribute %q has a type (%T) which is not valid for operation attributes", name, value)
}
return nil
}
-
-func cshape(s Shape) (C.int, []C.int64_t) {
- ndims := C.int(s.NumDimensions())
- if ndims < 0 {
- return -1, nil
- }
- dims := make([]C.int64_t, ndims)
- for i, s := range s.dims {
- dims[i] = C.int64_t(s)
- }
- return ndims, dims
-}
diff --git a/tensorflow/go/op/op_test.go b/tensorflow/go/op/op_test.go
index 2451ba3606..842dee9ffe 100644
--- a/tensorflow/go/op/op_test.go
+++ b/tensorflow/go/op/op_test.go
@@ -58,3 +58,76 @@ func TestAddOperationFailure(t *testing.T) {
_ = resize.Shape()
t.Errorf("resize.Shape() should have paniced since the underlying Operation was not created")
}
+
+func TestShapeAttribute(t *testing.T) {
+ s := NewScope()
+ x := Placeholder(s.SubScope("x"), tf.Int32, PlaceholderShape(tf.MakeShape(1)))
+ y := Placeholder(s.SubScope("y"), tf.Int32, PlaceholderShape(tf.Shape{}))
+ z := Add(s, x, y)
+ graph, err := s.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ value, err := tf.NewTensor([]int32{7})
+ if err != nil {
+ t.Fatal(err)
+ }
+ feeds := map[tf.Output]*tf.Tensor{
+ x: value,
+ y: value,
+ }
+ fetched, err := sess.Run(feeds, []tf.Output{z}, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if got, want := len(fetched), 1; got != want {
+ t.Fatalf("Fetched %d tensors, expected %d", got, want)
+ }
+ if got, want := fetched[0].Value().([]int32), []int32{14}; len(got) != len(want) || len(got) != 1 || got[0] != want[0] {
+ t.Fatalf("Got %v, want %v", got, want)
+ }
+}
+
+func TestDataset(t *testing.T) {
+ var (
+ s = NewScope()
+
+ // The use of a non-scalar here is inspired by
+ // https://github.com/tensorflow/tensorflow/issues/14891
+ c = Const(s, []int32{21718, 31415})
+ types = []tf.DataType{c.DataType()}
+ shapes = []tf.Shape{c.Shape()}
+ dataset = TensorDataset(s, []tf.Output{c}, shapes)
+
+ iterator = Iterator(s, "", "", types, shapes)
+ next = IteratorGetNext(s, iterator, types, shapes)
+ init = MakeIterator(s, dataset, iterator)
+ )
+ graph, err := s.Finalize()
+ if err != nil {
+ t.Fatal(err)
+ }
+ sess, err := tf.NewSession(graph, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := sess.Run(nil, nil, []*tf.Operation{init}); err != nil {
+ t.Fatal(err)
+ }
+ results, err := sess.Run(nil, next, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got := results[0].Value().([]int32)
+ if len(got) != 2 || got[0] != 21718 || got[1] != 31415 {
+ t.Errorf("Got %v, want {21718, 31415}", got)
+ }
+ if _, err := sess.Run(nil, next, nil); err == nil {
+ t.Errorf("Expected sess.Run() to fail since the iterator should have reached the end of the dataset")
+ }
+}
diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go
index cd6f4bc1f0..2d25c04dc9 100644
--- a/tensorflow/go/tensor.go
+++ b/tensorflow/go/tensor.go
@@ -270,7 +270,7 @@ func typeOf(dt DataType, shape []int64) reflect.Type {
}
}
if ret == nil {
- panic(bug("DataType %v is not supported", dt))
+ panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
}
for range shape {
ret = reflect.SliceOf(ret)
diff --git a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
index beb3635585..a24150484e 100644
--- a/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
+++ b/tensorflow/java/src/main/java/org/tensorflow/OperationBuilder.java
@@ -352,7 +352,8 @@ public final class OperationBuilder {
private static native void setAttrShape(long handle, String name, long[] shape, int numDims);
- private static native void setAttrShapeList(long handle, String name, long[] shapes, int[] numDims);
+ private static native void setAttrShapeList(
+ long handle, String name, long[] shapes, int[] numDims);
private static native void setAttrStringList(long handle, String name, Object[] value);
}
diff --git a/tensorflow/java/src/main/native/operation_builder_jni.cc b/tensorflow/java/src/main/native/operation_builder_jni.cc
index 71a451ad13..55d214a7c4 100644
--- a/tensorflow/java/src/main/native/operation_builder_jni.cc
+++ b/tensorflow/java/src/main/native/operation_builder_jni.cc
@@ -275,15 +275,15 @@ JNIEXPORT void JNICALL Java_org_tensorflow_OperationBuilder_setAttrShapeList(
if (num_dims_length > 0) {
const int shapes_length = env->GetArrayLength(shapes);
cshapes.reset(new int64_t[shapes_length]);
- cdims.reset(new int64_t* [num_dims_length]);
+ cdims.reset(new int64_t*[num_dims_length]);
cnum_dims.reset(new int[num_dims_length]);
jlong* shapes_elems =
- (jlong*) env->GetPrimitiveArrayCritical(shapes, nullptr);
+ static_cast<jlong*>(env->GetPrimitiveArrayCritical(shapes, nullptr));
std::memcpy(cshapes.get(), shapes_elems, shapes_length << 3);
env->ReleasePrimitiveArrayCritical(shapes, shapes_elems, JNI_ABORT);
int64_t* cshapes_ptr = cshapes.get();
jint* num_dims_elems =
- (jint*) env->GetPrimitiveArrayCritical(num_dims, nullptr);
+ static_cast<jint*>(env->GetPrimitiveArrayCritical(num_dims, nullptr));
for (int i = 0; i < num_dims_length; ++i) {
cnum_dims[i] = static_cast<int>(num_dims_elems[i]);
cdims[i] = cshapes_ptr;
diff --git a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
index 2430816725..0a4a8cf4e3 100644
--- a/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
+++ b/tensorflow/java/src/test/java/org/tensorflow/OperationBuilderTest.java
@@ -151,10 +151,10 @@ public class OperationBuilderTest {
@Test
public void setAttrShapeList() {
// Those shapes match tensors ones, so no exception is thrown
- testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2) });
+ testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2)});
try {
// Those shapes do not match tensors ones, exception is thrown
- testSetAttrShapeList(new Shape[] { Shape.make(2, 2), Shape.make(2, 2, 2, 2) });
+ testSetAttrShapeList(new Shape[] {Shape.make(2, 2), Shape.make(2, 2, 2, 2)});
fail("Shapes are incompatible and an exception was expected");
} catch (IllegalArgumentException e) {
// expected
@@ -189,20 +189,23 @@ public class OperationBuilderTest {
}
private static void testSetAttrShapeList(Shape[] shapes) {
- try (Graph g = new Graph(); Session s = new Session(g)) {
- int[][] matrix = new int[][] { { 0, 0 }, { 0, 0 } };
- Output<?> queue = g.opBuilder("FIFOQueue", "queue")
- .setAttr("component_types", new DataType[] { DataType.INT32, DataType.INT32 })
- .setAttr("shapes", shapes)
- .build()
- .output(0);
+ try (Graph g = new Graph();
+ Session s = new Session(g)) {
+ int[][] matrix = new int[][] {{0, 0}, {0, 0}};
+ Output<?> queue =
+ g.opBuilder("FIFOQueue", "queue")
+ .setAttr("component_types", new DataType[] {DataType.INT32, DataType.INT32})
+ .setAttr("shapes", shapes)
+ .build()
+ .output(0);
assertTrue(hasNode(g, "queue"));
Output<Integer> c1 = TestUtil.constant(g, "const1", matrix);
- Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] { matrix, matrix });
- Operation enqueue = g.opBuilder("QueueEnqueue", "enqueue")
- .addInput(queue)
- .addInputList(new Output<?>[] { c1, c2 })
- .build();
+ Output<Integer> c2 = TestUtil.constant(g, "const2", new int[][][] {matrix, matrix});
+ Operation enqueue =
+ g.opBuilder("QueueEnqueue", "enqueue")
+ .addInput(queue)
+ .addInputList(new Output<?>[] {c1, c2})
+ .build();
assertTrue(hasNode(g, "enqueue"));
s.runner().addTarget(enqueue).run();
diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD
index 23ad9bfa56..12d81c4383 100644
--- a/tensorflow/python/BUILD
+++ b/tensorflow/python/BUILD
@@ -268,6 +268,7 @@ cc_library(
deps = [
":ndarray_tensor_bridge",
":numpy_lib",
+ ":py_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -309,6 +310,7 @@ cc_library(
hdrs = ["lib/core/py_seq_tensor.h"],
deps = [
":numpy_lib",
+ ":py_util",
":safe_ptr",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
@@ -317,6 +319,17 @@ cc_library(
)
cc_library(
+ name = "py_util",
+ srcs = ["lib/core/py_util.cc"],
+ hdrs = ["lib/core/py_util.h"],
+ deps = [
+ "//tensorflow/core:lib",
+ "//tensorflow/core:script_ops_op_lib",
+ "//util/python:python_headers",
+ ],
+)
+
+cc_library(
name = "py_record_reader_lib",
srcs = ["lib/io/py_record_reader.cc"],
hdrs = ["lib/io/py_record_reader.h"],
diff --git a/tensorflow/python/client/session_test.py b/tensorflow/python/client/session_test.py
index f4b0271195..e4545d287b 100644
--- a/tensorflow/python/client/session_test.py
+++ b/tensorflow/python/client/session_test.py
@@ -28,6 +28,8 @@ import numpy as np
import six
from six.moves import xrange # pylint: disable=redefined-builtin
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.core.framework import types_pb2
from tensorflow.core.lib.core import error_codes_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
@@ -1742,5 +1744,136 @@ class SessionTest(test_util.TensorFlowTestCase):
self.runTestAddFunctionToSession(server.target)
+class GraphMutationTest(test_util.TensorFlowTestCase):
+
+ def testUpdateInputAfterRunning(self):
+ with ops.Graph().as_default() as g:
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess:
+ self.assertAllEqual(3.0, sess.run(c))
+ c.op._update_input(1, a) # pylint: disable=protected-access
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by updating input tensor after it was run'):
+ sess.run(c)
+
+ # Check that running the graph with a new session is fine
+ with session.Session(graph=g) as sess2:
+ self.assertAllEqual(2.0, sess2.run(c))
+
+ def testSetDeviceAfterRunning(self):
+ with ops.Graph().as_default() as g:
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess:
+ self.assertAllEqual(3.0, sess.run(c))
+ c.op._set_device('/cpu:0') # pylint: disable=protected-access
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by setting device after it was run'):
+ sess.run(c)
+
+ def testSetAttrAfterRunning(self):
+ with ops.Graph().as_default() as g:
+ a = constant_op.constant(1.0, dtype=dtypes.float32)
+ b = math_ops.cast(a, dtypes.float64)
+
+ with session.Session(graph=g) as sess:
+ self.assertAllEqual(1.0, sess.run(b))
+ b.op._set_attr('DstT',
+ attr_value_pb2.AttrValue(type=types_pb2.DT_FLOAT))
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'Cast.*was changed by setting attribute after it was run'):
+ sess.run(b)
+
+ def testRunModifyRun(self):
+ with ops.Graph().as_default() as g:
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess:
+ self.assertAllEqual(3.0, sess.run(c))
+
+ d = b + c
+ d.op._update_input(0, a) # pylint: disable=protected-access
+ self.assertAllEqual(3.0, sess.run(c))
+ self.assertAllEqual(4.0, sess.run(d))
+
+ def testRunModifyRunTwoSessions(self):
+ with ops.Graph().as_default() as g:
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ self.assertAllEqual(3.0, sess1.run(c))
+ self.assertAllEqual(3.0, sess2.run(c))
+
+ d = b + c
+ d.op._update_input(0, a) # pylint: disable=protected-access
+ self.assertAllEqual(3.0, sess2.run(c))
+ self.assertAllEqual(4.0, sess2.run(d))
+
+ d.op._update_input(0, b) # pylint: disable=protected-access
+ self.assertAllEqual(3.0, sess1.run(c))
+ self.assertAllEqual(5.0, sess1.run(d))
+
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by updating input tensor after it was run'):
+ sess2.run(c)
+
+ def testTwoSessionsOneRunBeforeModification(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ sess1.run(c)
+
+ c.op._set_device('/cpu:0') # pylint: disable=protected-access
+
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by setting device after it was run'):
+ sess1.run(c)
+
+ # sess2 was not run before modification
+ self.assertAllEqual(3.0, sess2.run(c))
+
+ def testTwoSessionsBothRunBeforeModification(self):
+ with ops.Graph().as_default() as g, ops.device('/cpu:0'):
+ a = constant_op.constant(1.0)
+ b = constant_op.constant(2.0)
+ c = a + b
+
+ with session.Session(graph=g) as sess1:
+ with session.Session(graph=g) as sess2:
+ sess1.run(c)
+ sess2.run(c)
+
+ c.op._set_device('/cpu:0') # pylint: disable=protected-access
+
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by setting device after it was run'):
+ sess1.run(c)
+
+ with self.assertRaisesRegexp(
+ errors.FailedPreconditionError,
+ 'add.*was changed by setting device after it was run'):
+ sess2.run(c)
+
+
if __name__ == '__main__':
googletest.main()
diff --git a/tensorflow/python/client/tf_session.i b/tensorflow/python/client/tf_session.i
index 5fa1a7e8fc..d471a39b69 100644
--- a/tensorflow/python/client/tf_session.i
+++ b/tensorflow/python/client/tf_session.i
@@ -532,6 +532,49 @@ def TF_Reset(target, containers=None, config=None):
%unignore TF_GraphGetTensorShapeHelper;
%ignore TF_GraphGetTensorShape;
+// We use TF_GraphSetTensorShape_wrapper instead of
+// TF_GraphSetTensorShape
+%ignore TF_GraphSetTensorShape;
+%unignore tensorflow;
+%unignore TF_GraphSetTensorShape_wrapper;
+
+// Convert $input, a Python list of ints, to a vector<int64_t> for TF_GraphSetTensorShape_wrapper
+%typemap(in) (const std::vector<int64_t>& dims)
+ (std::vector<int64_t> dims_local){
+ if ($input != Py_None) {
+ if (!PyList_Check($input)) {
+ SWIG_exception_fail(SWIG_TypeError, tensorflow::strings::Printf(
+ "$symname: expected list but got %s ", Py_TYPE($input)->tp_name).c_str());
+ }
+ size_t size = PyList_Size($input);
+ for (int i = 0; i < size; ++i) {
+ PyObject* item = PyList_GetItem($input, i);
+ dims_local.push_back(PyInt_AsLong(item));
+ }
+ $1 = &dims_local;
+ } else {
+ $1 = nullptr;
+ }
+}
+
+// We use TF_GraphGetTensorShape_wrapper instead of
+// TF_GraphGetTensorShape
+%ignore TF_GraphGetTensorShape;
+%unignore tensorflow;
+%unignore TF_GraphGetTensorShape_wrapper;
+
+// Build a Python list of ints and return it.
+%typemap(out) std::vector<int64_t> tensorflow::TF_GraphGetTensorShape_wrapper {
+ $result = PyList_New($1.size());
+ if (!$result) {
+ SWIG_exception_fail(SWIG_MemoryError, "$symname: couldn't create list");
+ }
+
+ for (size_t i = 0; i < $1.size(); ++i) {
+ PyList_SET_ITEM($result, i, PyInt_FromLong($1[i]));
+ }
+}
+
%include "tensorflow/python/client/tf_session_helper.h"
%unignoreall
diff --git a/tensorflow/python/client/tf_session_helper.cc b/tensorflow/python/client/tf_session_helper.cc
index ad982e5dd8..e4bf09a0ca 100644
--- a/tensorflow/python/client/tf_session_helper.cc
+++ b/tensorflow/python/client/tf_session_helper.cc
@@ -407,4 +407,23 @@ TF_Function* TF_GraphToFunction_wrapper(
opts, description, out_status);
}
+void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
+ const std::vector<int64_t>& dims,
+ bool unknown_shape, TF_Status* status) {
+ if (unknown_shape) {
+ TF_GraphSetTensorShape(graph, output, nullptr, -1, status);
+ return;
+ }
+ TF_GraphSetTensorShape(graph, output, dims.data(), dims.size(), status);
+}
+
+std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
+ TF_Output output,
+ int num_dims,
+ TF_Status* status) {
+ std::vector<int64_t> dims(num_dims);
+ TF_GraphGetTensorShape(graph, output, dims.data(), num_dims, status);
+ return dims;
+}
+
} // namespace tensorflow
diff --git a/tensorflow/python/client/tf_session_helper.h b/tensorflow/python/client/tf_session_helper.h
index 6ed08d3a58..bb7171db31 100644
--- a/tensorflow/python/client/tf_session_helper.h
+++ b/tensorflow/python/client/tf_session_helper.h
@@ -168,6 +168,20 @@ TF_Function* TF_GraphToFunction_wrapper(
const std::vector<TF_Output>& inputs, const std::vector<TF_Output>& outputs,
const NameVector& output_names, const TF_FunctionOptions* opts,
const char* description, TF_Status* out_status);
+
+// Set the shape of `output`. If `unknown_shape` is true, `dims` is ignored
+// and the shape is set to unknown.
+void TF_GraphSetTensorShape_wrapper(TF_Graph* graph, TF_Output output,
+ const std::vector<int64_t>& dims,
+ bool unknown_shape, TF_Status* status);
+
+// Return the shape of output. `num_dims` should be the output of
+// TF_GraphGetTensorNumDims. If `num_dims = -1`, this should not be called.
+std::vector<int64_t> TF_GraphGetTensorShape_wrapper(TF_Graph* graph,
+ TF_Output output,
+ int num_dims,
+ TF_Status* status);
+
} // namespace tensorflow
#endif // TENSORFLOW_PYTHON_CLIENT_TF_SESSION_HELPER_H_
diff --git a/tensorflow/python/data/ops/BUILD b/tensorflow/python/data/ops/BUILD
index 05acfe4de7..695d3ef790 100644
--- a/tensorflow/python/data/ops/BUILD
+++ b/tensorflow/python/data/ops/BUILD
@@ -21,6 +21,7 @@ py_library(
"//tensorflow/python:sparse_tensor",
"//tensorflow/python:tensor_shape",
"//tensorflow/python:tensor_util",
+ "//tensorflow/python:util",
"//tensorflow/python/data/util:nest",
"//tensorflow/python/data/util:sparse",
"//third_party/py/numpy",
diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py
index dbe29c087a..927c6d5c02 100644
--- a/tensorflow/python/data/ops/dataset_ops.py
+++ b/tensorflow/python/data/ops/dataset_ops.py
@@ -41,6 +41,7 @@ from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.ops import sparse_ops
+from tensorflow.python.util import deprecation
class Dataset(object):
@@ -219,6 +220,7 @@ class Dataset(object):
return TensorSliceDataset(tensors)
@staticmethod
+ @deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
def from_sparse_tensor_slices(sparse_tensor):
"""Splits each rank-N `tf.SparseTensor` in this dataset row-wise.
@@ -1232,13 +1234,40 @@ class ShuffleDataset(Dataset):
input_dataset,
buffer_size,
seed=None,
- reshuffle_each_iteration=None):
- """See `Dataset.shuffle()` for details."""
+ reshuffle_each_iteration=None,
+ seed2=None):
+ """Randomly shuffles the elements of this dataset.
+
+ Args:
+ input_dataset: The input dataset.
+ buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
+ number of elements from this dataset from which the new
+ dataset will sample.
+ seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
+ random seed that will be used to create the distribution. See
+ @{tf.set_random_seed} for behavior.
+ reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
+ that the dataset should be pseudorandomly reshuffled each time it is
+ iterated over. (Defaults to `True`.)
+ seed2: (Optional.) A `tf.int64` scalar `tf.Tensor` used to avoid seed
+      collision. Users should generally not need to specify this. Use it only
+      when both seeds for the Dataset op must be specified manually. If
+      `seed2` is not None, `seed` must also be non-None.
+
+ Returns:
+ A `Dataset`.
+
+ Raises:
+ ValueError: if invalid arguments are provided.
+ """
super(ShuffleDataset, self).__init__()
self._input_dataset = input_dataset
self._buffer_size = ops.convert_to_tensor(
buffer_size, dtype=dtypes.int64, name="buffer_size")
- seed, seed2 = random_seed.get_seed(seed)
+ if seed2 is None:
+ seed, seed2 = random_seed.get_seed(seed)
+ elif seed is None:
+ raise ValueError("seed must be non-None if seed2 is non-None.")
if seed is None:
self._seed = constant_op.constant(0, dtype=dtypes.int64, name="seed")
else:
diff --git a/tensorflow/python/data/util/nest.py b/tensorflow/python/data/util/nest.py
index bd7ab3d34f..2455395635 100644
--- a/tensorflow/python/data/util/nest.py
+++ b/tensorflow/python/data/util/nest.py
@@ -379,9 +379,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
if check_types and isinstance(shallow_tree, dict):
if set(input_tree) != set(shallow_tree):
raise ValueError(
- "The two structures don't have the same keys. Input "
- "structure has keys %s, while shallow structure has keys %s."
- % (list(_six.iterkeys(input_tree)),
+ "The two structures don't have the same keys. Input "
+ "structure has keys %s, while shallow structure has keys %s." %
+ (list(_six.iterkeys(input_tree)),
list(_six.iterkeys(shallow_tree))))
input_tree = list(_six.iteritems(input_tree))
shallow_tree = list(_six.iteritems(shallow_tree))
diff --git a/tensorflow/python/data/util/nest_test.py b/tensorflow/python/data/util/nest_test.py
index 8c84d9d1df..90dd7dfe77 100644
--- a/tensorflow/python/data/util/nest_test.py
+++ b/tensorflow/python/data/util/nest_test.py
@@ -271,8 +271,9 @@ class NestTest(test.TestCase):
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
expected_message = (
- "The two structures don't have the same keys. Input "
- "structure has keys \['c'\], while shallow structure has keys \['d'\].")
+ r"The two structures don't have the same keys. Input "
+ r"structure has keys \['c'\], while shallow structure has "
+ r"keys \['d'\].")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
diff --git a/tensorflow/python/eager/backprop.py b/tensorflow/python/eager/backprop.py
index 0144f3b1e5..dc1142705a 100644
--- a/tensorflow/python/eager/backprop.py
+++ b/tensorflow/python/eager/backprop.py
@@ -540,7 +540,7 @@ def _ensure_unique_tensor_objects(parameter_positions, args):
if i in parameter_positions:
tid = ops.tensor_id(t)
if tid in s:
- args[i] = args[i]._dup() # pylint: disable=protected-access
+ args[i] = gen_array_ops.identity(args[i])
else:
s.add(tid)
return args
diff --git a/tensorflow/python/eager/context.py b/tensorflow/python/eager/context.py
index 92f4e15c05..415416cfae 100644
--- a/tensorflow/python/eager/context.py
+++ b/tensorflow/python/eager/context.py
@@ -288,6 +288,21 @@ class Context(object):
self._initialize_handle_and_devices()
return self._num_gpus
+ def add_function(self, fn):
+ """Add a function definition to the context.
+
+ Once added, the function (identified by its name) can be executed like any
+ other operation.
+
+ Args:
+ fn: A wrapped TF_Function (returned from TF_GraphToFunction_wrapper).
+ """
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TFE_ContextAddFunction(
+ self._handle, # pylint: disable=protected-access
+ fn,
+ status)
+
def add_function_def(self, fdef):
"""Add a function definition to the context.
diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py
index 9bcd9c23c7..cadabb3a24 100644
--- a/tensorflow/python/eager/function.py
+++ b/tensorflow/python/eager/function.py
@@ -25,15 +25,19 @@ import threading
import numpy as np
+from tensorflow.core.framework import function_pb2
+from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.eager import tape
from tensorflow.python.eager.graph_only_ops import graph_placeholder
+from tensorflow.python.framework import c_api_util
from tensorflow.python.framework import constant_op
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import graph_to_function_def
+from tensorflow.python.framework import dtypes as dtypes_module
+from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
@@ -47,26 +51,41 @@ _scoped_captures = threading.local()
_scoped_captures.tensors = None
-def make_function_def(graph, operations, inputs, outputs):
- """Makes function def where accesses to resources are serialized."""
- last_op_using_resource_tensor = {}
-
- # TODO(apassos) probably control flow has to be handled delicately here as in
- # if a resource is accessed inside a control flow context we need the control
- # dependency to point to something outside the context which is guaranteed to
- # happen after the access.
- #
- # TODO(apassos) this should do some form of alias analysis as ops which
- # forward the resources such as Identity and Switch can cause serialization to
- # fail.
- for op in operations:
- for t in op.inputs:
- if t.dtype == dtypes.resource:
- if t.name in last_op_using_resource_tensor:
- op._add_control_input(last_op_using_resource_tensor[t.name]) # pylint: disable=protected-access
- last_op_using_resource_tensor[t.name] = op
- return graph_to_function_def.graph_to_function_def(
- graph, operations, inputs, outputs)
+def make_function_def(name, graph, operations, inputs, outputs):
+ """Makes FunctionDef proto and defined function.
+
+ Args:
+ name: the function name
+ graph: the graph from which to build the function
+ operations: the operations in the function body
+ inputs: tensors to be used as function arguments
+ outputs: tensors to be returned from the function
+
+ Returns:
+ fdef: a FunctionDef protocol buffer for the function
+ fn: a wrapped TF_Function for the function
+ """
+ with errors.raise_exception_on_not_ok_status() as status:
+ fn = pywrap_tensorflow.TF_GraphToFunction_wrapper(
+ graph._c_graph, # pylint: disable=protected-access
+ compat.as_str(name),
+ False,
+ [o._c_op for o in operations], # pylint: disable=protected-access
+ [t._as_tf_output() for t in inputs], # pylint: disable=protected-access
+ [t._as_tf_output() for t in outputs], # pylint: disable=protected-access
+ [],
+ None,
+ compat.as_str(""),
+ status)
+ # TODO(apassos) avoid creating a FunctionDef (especially just to grab the
+ # signature), and in general avoid depending on it.
+ with c_api_util.tf_buffer() as buffer_:
+ with errors.raise_exception_on_not_ok_status() as status:
+ pywrap_tensorflow.TF_FunctionToFunctionDef(fn, buffer_, status)
+ proto_data = pywrap_tensorflow.TF_GetBuffer(buffer_)
+ fdef = function_pb2.FunctionDef()
+ fdef.ParseFromString(compat.as_bytes(proto_data))
+ return fdef, fn
@contextlib.contextmanager
@@ -85,7 +104,7 @@ def capture_value(tensor_map, value, dtype, name):
if captured_value is None:
captured_value = graph_placeholder(
dtype=dtype or value.dtype, shape=value.shape, name=name)
- if captured_value.dtype == dtypes.resource:
+ if captured_value.dtype == dtypes_module.resource:
captured_value._handle_data = value._handle_data # pylint: disable=protected-access
tensor_map[ops.tensor_id(value)] = (value, captured_value)
else:
@@ -120,11 +139,23 @@ def _convert_to_graph_tensor(value, dtype=None, name=None, as_ref=False):
class CapturingGraph(ops.Graph):
+ """Graph used when constructing eager functions."""
def __init__(self, captures):
super(CapturingGraph, self).__init__()
self._building_function = True
self.captures = captures
+ # Map from resource tensor name to last op (in program order) which uses
+ # this tensor. Used to enforce that execution order matches program order
+ # for resource tensors.
+ self._last_op_using_resource_tensor = {}
+
+ # TODO(apassos) remove once the C API is used by default.
+ def _use_c_api_hack(self):
+ return True
+
+ def clear_resource_control_flow_state(self):
+ self._last_op_using_resource_tensor = {}
def create_op(
self,
@@ -137,12 +168,31 @@ class CapturingGraph(ops.Graph):
op_def=None,
compute_shapes=True,
compute_device=True):
+ # TODO(apassos) control flow probably has to be handled delicately here:
+ # if a resource is accessed inside a control flow context, the control
+ # dependency needs to point to something outside the context that is
+ # guaranteed to happen after the access.
+ #
+ # TODO(apassos) this should do some form of alias analysis as ops which
+ # forward the resources such as Identity and Switch can cause serialization
+ # to fail.
+ resource_inputs = set()
+ control_inputs = set()
for i, inp in enumerate(inputs):
if inp.graph is not self:
inputs[i] = capture_value(self.captures, inp, inp.dtype, inp.op.name)
- return super(CapturingGraph, self).create_op(
- op_type, inputs, dtypes, input_types, name, attrs, op_def,
- compute_shapes, compute_device)
+ inp = inputs[i]
+ if inp.dtype == dtypes_module.resource:
+ if inp.name in self._last_op_using_resource_tensor:
+ control_inputs.add(self._last_op_using_resource_tensor[inp.name])
+ resource_inputs.add(inp.name)
+ with self.control_dependencies(list(control_inputs)):
+ op = super(CapturingGraph, self).create_op(
+ op_type, inputs, dtypes, input_types, name, attrs, op_def,
+ compute_shapes, compute_device)
+ for name in resource_inputs:
+ self._last_op_using_resource_tensor[name] = op
+ return op
# TODO(apassos): it'd be really nice if we could scope this registration.
@@ -196,14 +246,20 @@ def _inference_name(n):
return "__inference_%s_%s" % (n, ops.uid())
+# TODO(apassos) get rid of this by splitting framework.function._DefinedFunction
+# so it doesn't have the definition-generating logic and is just a container for
+# an already-defined function.
class _DefinedFunction(object):
"""Mocks the interface of tf _DefinedFunction."""
- def __init__(self, fdef):
+ def __init__(self, fdef, fn):
self.definition = fdef
self.name = fdef.signature.name
+ self.signature = fdef.signature
self.grad_func_name = None
self.python_grad_func = None
+ self._c_func = fn
+ self._grad_func = None
def _map_sequence_obj_to_idx(sequence):
@@ -239,6 +295,7 @@ class GraphModeFunction(object):
input_placeholders,
extra_inputs,
fdef,
+ fn,
graph,
operations,
func_outputs,
@@ -252,7 +309,7 @@ class GraphModeFunction(object):
self._graph = graph
self._has_backprop = False
self._func_name = fdef.signature.name
- self._fdef = _DefinedFunction(fdef)
+ self._fdef = _DefinedFunction(fdef, fn)
self._num_outputs = len(fdef.signature.output_arg)
self._ops = operations
self._func_outputs = func_outputs
@@ -272,38 +329,45 @@ class GraphModeFunction(object):
with self._graph.as_default(), context.graph_mode():
c = _CapturingContext()
with c:
- filtered_outputs = [
- x for x in self._returns if x is not None
- ]
+ filtered_outputs = [x for x in self._returns if x is not None]
self._out_grad_placeholders = [
- graph_placeholder(x.dtype, x.shape) for x in filtered_outputs
- ]
+ graph_placeholder(x.dtype, x.shape) for x in filtered_outputs]
in_gradients = gradients_impl.gradients(
filtered_outputs,
self._input_placeholders,
grad_ys=self._out_grad_placeholders)
- shapes = [x.shape for x in in_gradients if x is not None]
+ shapes = tuple(x.shape for x in in_gradients if x is not None)
captures = list(sorted(c.captured_tensors, key=lambda x: x.name))
- forward_function_def = make_function_def(
- self._graph, self._ops, self._input_placeholders,
+ forward_name = _forward_name(self._func_name)
+ forward_function_def, forward_fn = make_function_def(
+ forward_name, self._graph, self._ops, self._input_placeholders,
filtered_outputs + captures)
- self._forward_fdef = _DefinedFunction(forward_function_def)
- _register_with_name(_forward_name(self._func_name), forward_function_def)
- backward_outputs = [x for x in in_gradients if x is not None]
+ self._forward_fdef = _DefinedFunction(forward_function_def, forward_fn)
+ _register(forward_fn)
+ backward_outputs = tuple(x for x in in_gradients if x is not None)
all_inputs = self._out_grad_placeholders + captures
- backward_function_def = make_function_def(
- self._graph, [x.op for x in self._out_grad_placeholders
- ] + list(sorted(c.known_ops, key=lambda x: x.name)),
+ # Exclude input ops from the body, since we do not intend to execute
+ # these operations when the function is executed.
+ all_ignored_ops = frozenset(x.op for x in all_inputs)
+ # Enforce a deterministic order of operations in the generated graph. This
+ # means rerunning the function-defining code will always define the same
+ # function, which is useful if we serialize this etc.
+ fdef_ops = tuple(x for x in sorted(c.known_ops, key=lambda x: x.name)
+ if x not in all_ignored_ops)
+ bname = _backward_name(self._func_name)
+ backward_function_def, backward_fn = make_function_def(
+ bname, self._graph, fdef_ops,
all_inputs, backward_outputs)
- _register_with_name(_backward_name(self._func_name), backward_function_def)
+ _register(backward_fn)
self._backward_function = GraphModeFunction(
- all_inputs, [], backward_function_def, self._graph, c.known_ops,
- in_gradients, _map_sequence_obj_to_idx(backward_outputs), shapes)
+ all_inputs, [], backward_function_def, backward_fn, self._graph,
+ c.known_ops, in_gradients, _map_sequence_obj_to_idx(backward_outputs),
+ shapes)
def _backprop_call(self, args):
"""Calls the wrapped function and records the result on a tape."""
all_args = args + self._extra_inputs
- signature = self._forward_fdef.definition.signature
+ signature = self._forward_fdef.signature
ctx = context.context()
if ctx.in_graph_mode():
g = ops.get_default_graph()
@@ -314,7 +378,7 @@ class GraphModeFunction(object):
return ops.internal_convert_to_tensor(x, ctx=ctx)
op = g.create_op(
signature.name, [make_tensor(x) for x in all_args],
- [dtypes.DType(x.type) for x in signature.output_arg],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
op_def=signature,
name="FunctionCall",
compute_shapes=False)
@@ -350,11 +414,8 @@ class GraphModeFunction(object):
if v._trainable: # pylint: disable=protected-access
tape.watch_variable(v)
- tensor_inputs = [
- x for x in nest.flatten(args)
- if isinstance(x, ops.Tensor)
- ]
-
+ tensor_inputs = [x for x in nest.flatten(args)
+ if isinstance(x, ops.Tensor)]
if tape.should_record(tensor_inputs) or tape.should_record(
self._extra_inputs):
if not self._has_backprop:
@@ -373,7 +434,7 @@ class GraphModeFunction(object):
args = list(tensor_inputs) + self._extra_inputs
op = g.create_op(
signature.name, [ops.convert_to_tensor(x) for x in args],
- [dtypes.DType(x.type) for x in signature.output_arg],
+ tuple(dtypes_module.DType(x.type) for x in signature.output_arg),
op_def=signature,
name="FunctionCall",
compute_shapes=False)
@@ -458,29 +519,32 @@ def _defun_internal(name, func, args, kwds):
extra_inputs = []
extra_placeholders = []
outputs_list = nest.flatten(func_outputs)
- output_shapes = [x.shape for x in outputs_list if x is not None]
+ output_shapes = tuple(x.shape for x in outputs_list if x is not None)
- flat_inputs = [
- x for x in nest.flatten(func_inputs) if isinstance(x, ops.Tensor)
- ]
+ flat_inputs = [x for x in nest.flatten(func_inputs)
+ if isinstance(x, ops.Tensor)]
all_inputs = flat_inputs + list(extra_placeholders)
-
+ all_ignored_ops = frozenset(x.op for x in all_inputs)
func_def_outputs = [x for x in outputs_list if x is not None]
- inference_function_def = make_function_def(
- tmp_graph, tmp_graph.get_operations(), all_inputs, func_def_outputs)
+ fname = _inference_name(name)
+ operations = tuple(x for x in tmp_graph.get_operations()
+ if x not in all_ignored_ops)
+ inference_function_def, fn = make_function_def(
+ fname, tmp_graph, operations, all_inputs, func_def_outputs)
# Register any other functions defined in the graph
# TODO(ashankar): Oh lord, forgive me for this lint travesty.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
# TODO(ashankar): What about the gradient registry?
- _register_with_name(f.name, f.definition)
- _register_with_name(_inference_name(name), inference_function_def)
+ _register(f._c_func) # pylint: disable=protected-access
+ _register(fn)
return GraphModeFunction(
all_inputs,
extra_inputs,
inference_function_def,
+ fn,
tmp_graph,
- tmp_graph.get_operations(),
+ operations,
func_outputs,
_map_sequence_obj_to_idx(func_def_outputs),
output_shapes,
@@ -506,10 +570,9 @@ def _cache_key(x):
return x
-def _register_with_name(name, fdef):
- """Registers the function `fdef` with the name `name`."""
- fdef.signature.name = name
- context.context().add_function_def(fdef)
+def _register(fn):
+ """Registers the function `fn`."""
+ context.context().add_function(fn)
# TODO(apassos): better error messages for non-hashable arguments.
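The CapturingGraph.create_op override earlier in this file serializes resource access by adding a control edge from the last user of each resource tensor. A framework-free sketch of that ordering rule (op and resource names are illustrative):

last_user = {}  # resource name -> name of the last op that used it

def create_op(op_name, resource_inputs):
    # Each new op control-depends on the previous user of every resource it
    # touches, so execution order matches program order for resources.
    control_inputs = {last_user[r] for r in resource_inputs if r in last_user}
    for r in resource_inputs:
        last_user[r] = op_name
    return op_name, control_inputs

_, deps = create_op("assign_1", ["var"])   # deps == set()
_, deps = create_op("assign_2", ["var"])   # deps == {"assign_1"}
assert deps == {"assign_1"}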
diff --git a/tensorflow/python/eager/graph_callable.py b/tensorflow/python/eager/graph_callable.py
index 837a75c808..3da100d800 100644
--- a/tensorflow/python/eager/graph_callable.py
+++ b/tensorflow/python/eager/graph_callable.py
@@ -296,6 +296,7 @@ def _graph_callable_internal(func, shape_and_dtypes):
# Call the function again, now replacing usages of variables with
# placeholders. This assumes the variable capturing scope created above
# knows about all variables.
+ tmp_graph.clear_resource_control_flow_state()
with variable_captures.capturing_scope(), function.capture_tensors(
captures):
captured_outputs = func(*func_inputs)
@@ -317,7 +318,9 @@ def _graph_callable_internal(func, shape_and_dtypes):
placeholder_inputs = flat_inputs + list(extra_placeholders)
func_def_outputs = [x for x in outputs_list if isinstance(x, tf_ops.Tensor)]
- initializer_function_def = function.make_function_def(
+ initialization_name = function._inference_name(func.__name__) # pylint: disable=protected-access
+ initializer_function_def, initializer_fn = function.make_function_def(
+ initialization_name,
tmp_graph,
initializing_operations,
placeholder_inputs,
@@ -326,13 +329,13 @@ def _graph_callable_internal(func, shape_and_dtypes):
# Also, what about the gradient registry of these functions? Those need to be
# addressed as well.
for f in tmp_graph._functions.values(): # pylint: disable=protected-access
- function._register_with_name(f.name, f.definition) # pylint: disable=protected-access
- function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
- initializer_function_def)
+ function._register(f._c_func) # pylint: disable=protected-access
+ function._register(initializer_fn) # pylint: disable=protected-access
initializer_function = function.GraphModeFunction(
placeholder_inputs,
extra_inputs,
initializer_function_def,
+ initializer_fn,
tmp_graph,
initializing_operations,
func_outputs,
@@ -341,18 +344,20 @@ def _graph_callable_internal(func, shape_and_dtypes):
capture_func_def_outputs = [
x for x in captured_outlist if isinstance(x, tf_ops.Tensor)]
- captured_function_def = function.make_function_def(
+ captured_function_name = function._inference_name(func.__name__) # pylint: disable=protected-access
+ captured_function_def, capturing_fn = function.make_function_def(
+ captured_function_name,
tmp_graph,
capturing_operations,
placeholder_inputs,
capture_func_def_outputs)
- function._register_with_name(function._inference_name(func.__name__), # pylint: disable=protected-access
- captured_function_def)
+ function._register(capturing_fn) # pylint: disable=protected-access
captured_function = function.GraphModeFunction(
placeholder_inputs,
extra_inputs,
captured_function_def,
+ capturing_fn,
tmp_graph,
capturing_operations,
captured_outputs,
diff --git a/tensorflow/python/eager/graph_callable_test.py b/tensorflow/python/eager/graph_callable_test.py
index 548e16a909..b9e6ca2a93 100644
--- a/tensorflow/python/eager/graph_callable_test.py
+++ b/tensorflow/python/eager/graph_callable_test.py
@@ -152,7 +152,6 @@ class GraphCallableTest(test.TestCase):
self.assertAllEqual(5, f(constant_op.constant(2)))
def testNestedFunction(self):
-
# TensorFlow function (which is what would be used in TensorFlow graph
# construction).
@function.Defun(dtypes.int32, dtypes.int32)
diff --git a/tensorflow/python/eager/pywrap_tfe_src.cc b/tensorflow/python/eager/pywrap_tfe_src.cc
index ce823cb567..b52d71dc6c 100644
--- a/tensorflow/python/eager/pywrap_tfe_src.cc
+++ b/tensorflow/python/eager/pywrap_tfe_src.cc
@@ -531,12 +531,9 @@ static PyTypeObject TFE_Py_Tape_Type = {
// xcode 7 doesn't define thread_local, so for compatibility we implement our
// own. TODO(apassos) remove once we can deprecate xcode 7.
#ifndef __APPLE__
-thread_local std::vector<TFE_Py_Tape*>* tape_stack = nullptr;
std::vector<TFE_Py_Tape*>* GetTapeStack() {
- if (tape_stack == nullptr) {
- tape_stack = new std::vector<TFE_Py_Tape*>;
- }
- return tape_stack;
+ thread_local std::vector<TFE_Py_Tape*> tape_stack;
+ return &tape_stack;
}
#else
static tensorflow::mutex stack_mu(tensorflow::LINKER_INITIALIZED);
diff --git a/tensorflow/python/estimator/BUILD b/tensorflow/python/estimator/BUILD
index 03f386e9cf..e062e1fbfe 100644
--- a/tensorflow/python/estimator/BUILD
+++ b/tensorflow/python/estimator/BUILD
@@ -215,6 +215,7 @@ py_test(
srcs_version = "PY2AND3",
tags = [
"no_pip",
+ "noasan", # test flakily times out in asan mode.
"notsan", # b/67510291
],
deps = [
@@ -433,6 +434,7 @@ py_library(
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python:util",
+ "//tensorflow/python/data",
"//tensorflow/python/saved_model:builder",
"//tensorflow/python/saved_model:tag_constants",
"//third_party/py/numpy",
diff --git a/tensorflow/python/estimator/estimator.py b/tensorflow/python/estimator/estimator.py
index f267f4a54e..63103ef4c1 100644
--- a/tensorflow/python/estimator/estimator.py
+++ b/tensorflow/python/estimator/estimator.py
@@ -30,6 +30,7 @@ from google.protobuf import message
from tensorflow.core.framework import summary_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as tf_session
+from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.eager import context
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator import run_config
@@ -416,7 +417,7 @@ class Estimator(object):
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
self._create_and_assert_global_step(g)
- features = self._get_features_from_input_fn(
+ features, input_hooks = self._get_features_from_input_fn(
input_fn, model_fn_lib.ModeKeys.PREDICT)
estimator_spec = self._call_model_fn(
features, None, model_fn_lib.ModeKeys.PREDICT, self.config)
@@ -426,7 +427,7 @@ class Estimator(object):
checkpoint_filename_with_path=checkpoint_path,
scaffold=estimator_spec.scaffold,
config=self._session_config),
- hooks=hooks) as mon_sess:
+ hooks=input_hooks + hooks) as mon_sess:
while not mon_sess.should_stop():
preds_evaluated = mon_sess.run(predictions)
if not isinstance(predictions, dict):
@@ -582,6 +583,11 @@ class Estimator(object):
def _get_features_from_input_fn(self, input_fn, mode):
"""Extracts the `features` from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
+ input_hooks = []
+ if isinstance(result, dataset_ops.Dataset):
+ iterator = result.make_initializable_iterator()
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
if isinstance(result, (list, tuple)):
# Unconditionally drop the label (the second element of result).
result = result[0]
@@ -590,16 +596,22 @@ class Estimator(object):
logging.warning('Input graph does not use tf.data.Dataset or contain a '
'QueueRunner. That means predict yields forever. '
'This is probably a mistake.')
- return result
+ return result, input_hooks
def _get_features_and_labels_from_input_fn(self, input_fn, mode):
+ """Extracts the `features` and labels from return values of `input_fn`."""
result = self._call_input_fn(input_fn, mode)
+ input_hooks = []
+ if isinstance(result, dataset_ops.Dataset):
+ iterator = result.make_initializable_iterator()
+ input_hooks.append(_DatasetInitializerHook(iterator))
+ result = iterator.get_next()
if isinstance(result, (list, tuple)):
if len(result) != 2:
raise ValueError(
'input_fn should return (features, labels) as a len 2 tuple.')
- return result
- return result, None
+ return result[0], result[1], input_hooks
+ return result, None, input_hooks
def _extract_batch_length(self, preds_evaluated):
"""Extracts batch length of predictions."""
@@ -723,8 +735,10 @@ class Estimator(object):
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
training_util._get_or_create_global_step_read() # pylint: disable=protected-access
- features, labels = self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.TRAIN)
+ features, labels, input_hooks = (
+ self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.TRAIN))
+ worker_hooks.extend(input_hooks)
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.TRAIN, self.config)
# Check if the user created a loss summary, and add one if they didn't.
@@ -822,8 +836,9 @@ class Estimator(object):
with ops.Graph().as_default() as g:
random_seed.set_random_seed(self._config.tf_random_seed)
global_step_tensor = self._create_and_assert_global_step(g)
- features, labels = self._get_features_and_labels_from_input_fn(
- input_fn, model_fn_lib.ModeKeys.EVAL)
+ features, labels, input_hooks = (
+ self._get_features_and_labels_from_input_fn(
+ input_fn, model_fn_lib.ModeKeys.EVAL))
estimator_spec = self._call_model_fn(
features, labels, model_fn_lib.ModeKeys.EVAL, self.config)
@@ -844,7 +859,8 @@ class Estimator(object):
'already defines a default metric with the same name.')
eval_dict[ops.GraphKeys.GLOBAL_STEP] = global_step_tensor
- all_hooks = list(hooks or [])
+ all_hooks = list(input_hooks)
+ all_hooks.extend(hooks)
all_hooks.extend(list(estimator_spec.evaluation_hooks or []))
eval_results = evaluation._evaluate_once( # pylint: disable=protected-access
@@ -1039,3 +1055,16 @@ def _has_dataset_or_queue_runner(maybe_tensor):
# Now, check queue.
return ops.get_default_graph().get_collection(ops.GraphKeys.QUEUE_RUNNERS)
+
+
+class _DatasetInitializerHook(training.SessionRunHook):
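+ """SessionRunHook that initializes a tf.data iterator after session creation."""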
+
+ def __init__(self, iterator):
+ self._iterator = iterator
+
+ def begin(self):
+ self._initializer = self._iterator.initializer
+
+ def after_create_session(self, session, coord):
+ del coord
+ session.run(self._initializer)
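With this change an input_fn may return a tf.data.Dataset directly; the Estimator builds an initializable iterator and runs its initializer through the hook above. A minimal usage sketch (the model_fn is elided and `my_model_fn` is a hypothetical user function; the values mirror the tests below):

import tensorflow as tf

def input_fn():
    # Each element is a (features, labels) pair; the Estimator calls
    # make_initializable_iterator() and initializes it via the new hook.
    return tf.data.Dataset.from_tensors(([1.0], [2.0]))

# est = tf.estimator.Estimator(model_fn=my_model_fn)  # my_model_fn: user code
# est.train(input_fn, steps=1)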
diff --git a/tensorflow/python/estimator/estimator_test.py b/tensorflow/python/estimator/estimator_test.py
index c1b773b8c4..db64fbc9cc 100644
--- a/tensorflow/python/estimator/estimator_test.py
+++ b/tensorflow/python/estimator/estimator_test.py
@@ -913,6 +913,80 @@ class EstimatorGetVariablesTest(test.TestCase):
self.assertEqual(3., est.get_variable_value('three'))
+class EstimatorDatasetIntegrationTest(test.TestCase):
+ """Tests dataset integration."""
+
+ def test_returned_by_input_fn(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors(([1.], [2.]))
+
+ def _model_fn(features, labels, mode):
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=features + labels, # 1 + 2
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ scores = est.evaluate(_input_fn, steps=1)
+ self.assertEqual(3., scores[model_fn_lib.LOSS_METRIC_KEY])
+
+ def test_with_none_labels(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors([7.])
+
+ def _model_fn(features, labels, mode):
+ self.assertIsNone(labels)
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ loss=features, # 7
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ scores = est.evaluate(_input_fn, steps=1)
+ self.assertEqual(7., scores[model_fn_lib.LOSS_METRIC_KEY])
+
+ def test_with_predict(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensors([10.])
+
+ def _model_fn(features, labels, mode):
+ _ = labels
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=features, # 10
+ loss=features, # 10
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn, steps=1)
+ self.assertEqual([10.], next(est.predict(input_fn=_input_fn)))
+
+ def test_batching(self):
+
+ def _input_fn():
+ return dataset_ops.Dataset.from_tensor_slices(([[1.], [2.]],
+ [[10.], [20.]])).batch(1)
+
+ def _model_fn(features, labels, mode):
+ return model_fn_lib.EstimatorSpec(
+ mode,
+ predictions=features,
+ loss=features + (0 if labels is None else labels), # 11, 22
+ train_op=state_ops.assign_add(training.get_global_step(), 1))
+
+ est = estimator.Estimator(model_fn=_model_fn)
+ est.train(_input_fn)
+ scores = est.evaluate(_input_fn)
+ # (11 + 22)/2 = 16.5
+ self.assertEqual(16.5, scores[model_fn_lib.LOSS_METRIC_KEY])
+ self.assertEqual([1., 2.], list(est.predict(_input_fn)))
+
+
class EstimatorEvaluateTest(test.TestCase):
def test_input_fn_args(self):
diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py
index 29cf223724..366025a0d8 100644
--- a/tensorflow/python/framework/function.py
+++ b/tensorflow/python/framework/function.py
@@ -692,7 +692,10 @@ class _FuncGraph(ops.Graph):
else:
# Substitute with a placeholder.
self.extra_inputs.append(x)
- ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
+ # Hoist the new input placeholder out of any control flow context
+ # we're currently in.
+ with ops.control_dependencies(None):
+ ph = array_ops.placeholder(x.dtype, shape=x.get_shape())
# pylint: disable=protected-access
ph._handle_data = x._handle_data
# pylint: enable=protected-access
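The control_dependencies(None) wrapper added above clears the current control-flow context, so a capture placeholder requested inside a while_loop or cond body is created outside that body. A condensed sketch of the hoist:

import tensorflow as tf
from tensorflow.python.framework import ops

def make_capture_placeholder(x):
    # Hoist the placeholder out of any control-flow context, so capturing a
    # tensor inside a while_loop/cond body (see the tests below) is legal.
    with ops.control_dependencies(None):
        return tf.placeholder(x.dtype, shape=x.get_shape())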
diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py
index ba43e9199b..11f343c579 100644
--- a/tensorflow/python/framework/function_test.py
+++ b/tensorflow/python/framework/function_test.py
@@ -724,6 +724,38 @@ class FunctionTest(test.TestCase):
# NOTE: We still do not support capturing control deps.
_ = Foo(x)
+ def testCaptureInWhileLoop(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(1)
+
+ @function.Defun()
+ def Foo():
+ return control_flow_ops.while_loop(lambda i: i < 10,
+ lambda i: i + x,
+ [0])
+ y = Foo()
+
+ with self.test_session(graph=g) as sess:
+ self.assertEqual(sess.run(y), 10)
+
+ def testCaptureInCond(self):
+ g = ops.Graph()
+ with g.as_default():
+ x = constant_op.constant(1)
+
+ @function.Defun(dtypes.bool)
+ def Foo(pred):
+ return control_flow_ops.cond(pred,
+ lambda: x,
+ lambda: x + 1)
+ y = Foo(True)
+ z = Foo(False)
+
+ with self.test_session(graph=g) as sess:
+ self.assertEqual(sess.run(y), 1)
+ self.assertEqual(sess.run(z), 2)
+
def testStableName(self):
@function.Defun()
diff --git a/tensorflow/python/framework/importer.py b/tensorflow/python/framework/importer.py
index 434cbda7ad..ada8c30fab 100644
--- a/tensorflow/python/framework/importer.py
+++ b/tensorflow/python/framework/importer.py
@@ -179,12 +179,11 @@ def _ProcessInputMapParam(input_map):
def _ProcessReturnElementsParam(return_elements):
"""Type-checks and possibly canonicalizes `return_elements`."""
- if return_elements is not None:
- return_elements = tuple(return_elements)
- if not all(isinstance(x, compat.bytes_or_text_types)
- for x in return_elements):
- raise TypeError('return_elements must be a list of strings.')
- return return_elements
+ if return_elements is None: return None
+ if not all(isinstance(x, compat.bytes_or_text_types)
+ for x in return_elements):
+ raise TypeError('return_elements must be a list of strings.')
+ return tuple(compat.as_str(x) for x in return_elements)
def _FindAttrInOpDef(attr_name, op_def):
@@ -194,24 +193,125 @@ def _FindAttrInOpDef(attr_name, op_def):
return None
-def _PopulateTFImportGraphDefOptions(options, prefix, return_elements):
+def _ConvertInputMapValues(name, input_map):
+ """Ensures all input map values are tensors.
+
+ This should be called from inside the import name scope.
+
+ Args:
+ name: the `name` argument passed to import_graph_def
+ input_map: the `input_map` argument passed to import_graph_def.
+
+ Returns:
+ A possibly-updated version of `input_map`.
+
+ Raises:
+ ValueError: if input map values cannot be converted due to an empty name scope.
+ """
+ if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
+ if name == '': # pylint: disable=g-explicit-bool-comparison
+ raise ValueError(
+ 'tf.import_graph_def() requires a non-empty `name` if `input_map` '
+ 'contains non-Tensor values. Try calling tf.convert_to_tensor() on '
+ '`input_map` values before calling tf.import_graph_def().')
+ with ops.name_scope('_inputs'):
+ input_map = {k: ops.convert_to_tensor(v) for k, v in input_map.items()}
+ return input_map
+
+
+def _PopulateTFImportGraphDefOptions(options, prefix, input_map,
+ return_elements):
"""Populates the TF_ImportGraphDefOptions `options`."""
c_api.TF_ImportGraphDefOptionsSetPrefix(options, prefix)
+ for input_src, input_dst in input_map.items():
+ input_src = compat.as_str(input_src)
+ if input_src.startswith('^'):
+ src_name = compat.as_bytes(input_src[1:])
+ dst_op = input_dst._as_tf_output().oper # pylint: disable=protected-access
+ c_api.TF_ImportGraphDefOptionsRemapControlDependency(options, src_name,
+ dst_op)
+ else:
+ src_name, src_idx = _ParseTensorName(input_src)
+ src_name = compat.as_str(src_name)
+ dst_output = input_dst._as_tf_output() # pylint: disable=protected-access
+ c_api.TF_ImportGraphDefOptionsAddInputMapping(options, src_name,
+ src_idx, dst_output)
for name in return_elements or []:
if ':' in name:
op_name, index = _ParseTensorName(name)
+ op_name = compat.as_str(op_name)
c_api.TF_ImportGraphDefOptionsAddReturnOutput(options, op_name, index)
else:
- c_api.TF_ImportGraphDefOptionsAddReturnOperation(options, name)
+ c_api.TF_ImportGraphDefOptionsAddReturnOperation(options,
+ compat.as_str(name))
+
+ # TODO(skyewm): control dependencies
def _ProcessNewOps(graph):
"""Processes the newly-added TF_Operations in `graph`."""
- for c_op in c_api_util.new_tf_operations(graph):
- graph._create_op_from_tf_operation(c_op) # pylint: disable=protected-access
+ # Maps from a node to the names of the ops it's colocated with, if colocation
+ # is specified in the attributes.
+ colocation_pairs = {}
- # TODO(skyewm): colocation logic
+ for c_op in c_api_util.new_tf_operations(graph):
+ # pylint: disable=protected-access
+ new_op = graph._create_op_from_tf_operation(c_op, compute_device=False)
+ # pylint: enable=protected-access
+
+ colocation_names = _GetColocationNames(new_op)
+ if colocation_names:
+ colocation_pairs[new_op] = colocation_names
+ # Don't apply this op's device function, since colocation constraints
+ # override device functions. Note that this op's device may still be set
+ # by the loop below.
+ else:
+ with _MaybeDevice(new_op.device):
+ graph._apply_device_functions(new_op) # pylint: disable=protected-access
+
+ # The following loop populates the device field of ops that are colocated
+ # with another op. This is implied by the colocation attribute, but we
+ # propagate the device field for completeness.
+ for op, coloc_op_list in colocation_pairs.items():
+ coloc_device = None
+ # Find any device in the list of colocated ops that has a device, if one
+ # exists. We assume that if multiple ops have devices, they all refer to
+ # the same device; otherwise a runtime error will occur, since the
+ # colocation property cannot be guaranteed.
+ #
+ # One possible improvement is to try to check for compatibility of all
+ # devices in this list at import time here, which would require
+ # implementing a compatibility function for device specs in python.
+ for coloc_op_name in coloc_op_list:
+ try:
+ coloc_op = graph._get_operation_by_name_unsafe(coloc_op_name) # pylint: disable=protected-access
+ except KeyError:
+ raise ValueError('Specified colocation to an op that '
+ 'does not exist during import: %s in %s' % (
+ coloc_op_name, op.name))
+ if coloc_op.device:
+ coloc_device = pydev.DeviceSpec.from_string(coloc_op.device)
+ break
+ if coloc_device:
+ op._set_device(coloc_device) # pylint: disable=protected-access
+
+
+def _GetColocationNames(op):
+ """Returns names of the ops that `op` should be colocated with."""
+ colocation_names = []
+ try:
+ class_values = op.get_attr('_class')
+ except ValueError:
+ # No _class attr; nothing to colocate with.
+ return []
+ for val in class_values:
+ val = compat.as_str(val)
+ if val.startswith('loc:@'):
+ colocation_node_name = val[len('loc:@'):]
+ if colocation_node_name != op.name:
+ colocation_names.append(colocation_node_name)
+ return colocation_names
def _GatherReturnElements(requested_return_elements, graph, results):
@@ -312,17 +412,27 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
else:
prefix = ''
+ # Generate any input map tensors inside name scope
+ input_map = _ConvertInputMapValues(name, input_map)
+
scoped_options = c_api_util.ScopedTFImportGraphDefOptions()
options = scoped_options.options
- _PopulateTFImportGraphDefOptions(options, prefix, return_elements)
+ _PopulateTFImportGraphDefOptions(options, prefix, input_map,
+ return_elements)
with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
- with errors.raise_exception_on_not_ok_status() as status:
- results = c_api.TF_GraphImportGraphDefWithResults(
- graph._c_graph, serialized, options, status) # pylint: disable=protected-access
+ try:
+ with errors.raise_exception_on_not_ok_status() as status:
+ results = c_api.TF_GraphImportGraphDefWithResults(
+ graph._c_graph, serialized, options, status) # pylint: disable=protected-access
+ except errors.InvalidArgumentError as e:
+ # Convert to ValueError for backwards compatibility.
+ raise ValueError(str(e))
_ProcessNewOps(graph)
+ # TODO(skyewm): error if unused input map key
+
if return_elements is None:
return None
else:
@@ -359,16 +469,7 @@ def import_graph_def(graph_def, input_map=None, return_elements=None,
# more nuanced.
g.graph_def_versions.CopyFrom(graph_def.versions)
- if not all(isinstance(v, ops.Tensor) for v in input_map.values()):
- if not scope:
- # The caller must have passed `name=''`.
- raise ValueError(
- 'tf.import_graph_def() requires a non-empty `name` if `input_map`'
- ' contains non-Tensor values. Try calling tf.convert_to_tensor() '
- 'on `input_map` values before calling tf.import_graph_def().')
- with ops.name_scope('_inputs'):
- input_map = {k: ops.convert_to_tensor(v)
- for k, v in input_map.items()}
+ input_map = _ConvertInputMapValues(name, input_map)
# NOTE(mrry): We do this in two passes, because there may be a cycle in
# `graph_def`.
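Behavior at the public API is unchanged by the C-API path above: input_map and return_elements are now applied through TF_ImportGraphDefOptions instead of Python-side rewriting. For reference, a typical call looks like this (node names are illustrative, and `graph_def` is assumed to contain nodes named A and B, so the import itself is left commented):

import tensorflow as tf

# Typical import_graph_def call exercised by this change: remap tensor 'A:0'
# to a local constant and fetch op 'B' back out.
with tf.Graph().as_default():
    feed = tf.constant(0, dtype=tf.int32)
    # b, = tf.import_graph_def(graph_def, input_map={"A:0": feed},
    #                          return_elements=["B"], name="import")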
diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py
index 5a6187c8a6..4a215abd2e 100644
--- a/tensorflow/python/framework/importer_test.py
+++ b/tensorflow/python/framework/importer_test.py
@@ -201,8 +201,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(outer_inner_c.name, "outer/inner/c_1")
def testInputMap(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
@@ -230,8 +228,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(d.inputs[1], feed_b_1)
def testInputMapBytes(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
@@ -259,8 +255,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(d.inputs[1], feed_b_1)
def testInputMapUnicode(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
feed_b_1 = constant_op.constant(1, dtype=dtypes.int32)
@@ -299,8 +293,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(b.inputs[0], a.outputs[0])
def testInputMapImplicitZerothOutput(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(0, dtype=dtypes.int32)
b, = importer.import_graph_def(
@@ -453,8 +445,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertTrue("Input tensor 'A:0' not found" in str(e.exception))
def testMissingInputOpInGraphDefButAppearsInInputMap(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
feed_a_0 = constant_op.constant(5.0)
b, = importer.import_graph_def(
@@ -589,19 +579,20 @@ class ImportGraphDefTest(test.TestCase):
self.assertTrue("not found in graph_def: [A:2]" in str(e.exception))
def testInputMapTypeMismatch(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
+ if ops._USE_C_API:
+ error_msg = ("Input 0 of node import/B was passed float from Const:0 "
+ "incompatible with expected int32.")
+ else:
+ error_msg = ("Cannot convert a tensor of type float32 to an input of "
+ "type int32.")
with ops.Graph().as_default():
- with self.assertRaises(ValueError) as e:
+ with self.assertRaisesRegexp(ValueError, error_msg):
importer.import_graph_def(
self._MakeGraphDef("""
node { name: 'A' op: 'IntOutput' }
node { name: 'B' op: 'IntInput' input: 'A:0' }
"""),
input_map={"A:0": constant_op.constant(5.0)})
- self.assertTrue(
- "Cannot convert a tensor of type float32 to an input of type int32."
- in str(e.exception))
def testNoReturns(self):
with ops.Graph().as_default() as g:
@@ -651,8 +642,6 @@ class ImportGraphDefTest(test.TestCase):
b.node_def.attr["_class"])
def testColocationWithDeviceFn(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
original_graph_def = self._MakeGraphDef("""
node { name: 'A' op: 'None' attr {
key: '_class'
@@ -674,23 +663,17 @@ class ImportGraphDefTest(test.TestCase):
with ops.Graph().as_default():
with ops.device(CustomDeviceFn):
- b, = importer.import_graph_def(
- original_graph_def, return_elements=["B"], name="imported_graph")
-
- self.assertProtoEqualsVersion("""
- node { name: 'imported_graph/A' op: 'None' device: "/device:A:0"
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- }
- }
- node { name: 'imported_graph/B' op: 'None' device: "/device:A:0"
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- } }""", b.graph.as_graph_def())
-
- # Test a scenario where 'A' doesn't get a device; 'A' should
- # not have a device, but during runtime will get colocated with
- # 'B' because of the colocation attribute.
+ a, b = importer.import_graph_def(original_graph_def,
+ return_elements=["A", "B"],
+ name="imported_graph")
+ self.assertEqual(a.device, "/device:A:0")
+ self.assertEqual(b.device, "/device:A:0")
+ self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
+ self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])
+
+ # Test a scenario where 'A' doesn't get a device; 'A' should not have a
+ # device, but during runtime will get colocated with 'B' because of the
+ # colocation attribute. B's device function is still overridden by A.
def BDeviceFn(op):
if "B" in op.name:
return "/device:B:0"
@@ -698,19 +681,13 @@ class ImportGraphDefTest(test.TestCase):
with ops.Graph().as_default():
with ops.device(BDeviceFn):
- b, = importer.import_graph_def(
- original_graph_def, return_elements=["B"], name="imported_graph")
-
- self.assertProtoEqualsVersion("""
- node { name: 'imported_graph/A' op: 'None'
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- }
- }
- node { name: 'imported_graph/B' op: 'None'
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- } }""", b.graph.as_graph_def())
+ a, b = importer.import_graph_def(original_graph_def,
+ return_elements=["A", "B"],
+ name="imported_graph")
+ self.assertEqual(a.device, "")
+ self.assertEqual(b.device, "")
+ self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
+ self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])
# Only A gets a device, so B inherits it implicitly.
def ADeviceFn(op):
@@ -720,23 +697,15 @@ class ImportGraphDefTest(test.TestCase):
with ops.Graph().as_default():
with ops.device(ADeviceFn):
- b, = importer.import_graph_def(
- original_graph_def, return_elements=["B"], name="imported_graph")
-
- self.assertProtoEqualsVersion("""
- node { name: 'imported_graph/A' op: 'None' device: "/device:A:0"
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- }
- }
- node { name: 'imported_graph/B' op: 'None' device: "/device:A:0"
- attr {
- key: '_class' value { list { s: 'loc:@imported_graph/A' } }
- } }""", b.graph.as_graph_def())
+ a, b = importer.import_graph_def(original_graph_def,
+ return_elements=["A", "B"],
+ name="imported_graph")
+ self.assertEqual(a.device, "/device:A:0")
+ self.assertEqual(b.device, "/device:A:0")
+ self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
+ self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/A"])
def testMultipleColocationWithDeviceFn(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
original_graph_def = self._MakeGraphDef("""
node { name: 'A' op: 'None'}
node { name: 'B' op: 'None'}
@@ -757,23 +726,19 @@ class ImportGraphDefTest(test.TestCase):
with ops.Graph().as_default():
with ops.device(CustomDeviceFn):
- c, = importer.import_graph_def(
- original_graph_def, return_elements=["C"], name="imported_graph")
-
- self.assertProtoEqualsVersion("""
- node { name: 'imported_graph/A' op: 'None' }
- node { name: 'imported_graph/B' op: 'None' device: "/device:B:0" }
- node { name: 'imported_graph/C' op: 'None' device: "/device:B:0"
- attr {
- key: '_class' value {
- list { s: 'loc:@imported_graph/A'
- s: 'loc:@imported_graph/B' }
- }
- }
- }""", c.graph.as_graph_def())
+ a, b, c = importer.import_graph_def(original_graph_def,
+ return_elements=["A", "B", "C"],
+ name="imported_graph")
+ self.assertEqual(a.device, "")
+ self.assertEqual(b.device, "/device:B:0")
+ self.assertEqual(c.device, "/device:B:0")
+ self.assertEqual(a.colocation_groups(), [b"loc:@imported_graph/A"])
+ self.assertEqual(b.colocation_groups(), [b"loc:@imported_graph/B"])
+ self.assertEqual(c.colocation_groups(),
+ [b"loc:@imported_graph/A", b"loc:@imported_graph/B"])
def testNamePrefixColocationAttrsMultipleImport(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
+ if ops._USE_C_API: return # TODO(skyewm): set uniquify_names
original_graph_def = self._MakeGraphDef("""
node { name: 'A' op: 'None' }
@@ -800,15 +765,19 @@ class ImportGraphDefTest(test.TestCase):
} }""", b.graph.as_graph_def())
def testNamePrefixColocationAttrsNotFound(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
original_graph_def = self._MakeGraphDef("""
node { name: 'B' op: 'None' attr {
key: '_class'
value { list { s: 'loc:@A' } }
} }""")
+
+ if ops._USE_C_API:
+ error_msg = "Node 'B' expects to be colocated with unknown node 'A'"
+ else:
+ error_msg = "does not exist during import"
+
with ops.Graph().as_default():
- with self.assertRaisesRegexp(ValueError, "does not exist during import"):
+ with self.assertRaisesRegexp(ValueError, error_msg):
importer.import_graph_def(
original_graph_def, return_elements=["B"], name="imported_graph")
@@ -825,8 +794,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual("graph_def must be a GraphDef proto.", str(e.exception))
def testInvalidInputForInputMap(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
with ops.Graph().as_default():
with self.assertRaises(TypeError) as e:
importer.import_graph_def(
@@ -967,7 +934,7 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(2, len(ops_with_two_inputs))
def testGradient(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
+ if ops._USE_C_API: return # TODO(skyewm): get_shape() doesn't work
with ops.Graph().as_default() as g:
inputs = array_ops.placeholder(
@@ -1226,8 +1193,6 @@ class ImportGraphDefTest(test.TestCase):
self.assertEqual(z_val, -2.0)
def testImportGraphWithFunctionTwice(self):
- if ops._USE_C_API: return # TODO(skyewm): make this work with C API
-
g = ops.Graph()
with g.as_default():
@function.Defun()
diff --git a/tensorflow/python/framework/ops.py b/tensorflow/python/framework/ops.py
index 60df8f82f0..13e6426447 100644
--- a/tensorflow/python/framework/ops.py
+++ b/tensorflow/python/framework/ops.py
@@ -35,6 +35,7 @@ from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.core.framework import versions_pb2
+from tensorflow.core.protobuf import config_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.eager import context
from tensorflow.python.eager import core
@@ -373,6 +374,19 @@ class Tensor(_TensorLike):
A `TensorShape` representing the shape of this tensor.
"""
+ if _USE_C_API:
+ graph = self._op._graph._c_graph # pylint: disable=protected-access
+ with errors.raise_exception_on_not_ok_status() as status:
+ num_dims = c_api.TF_GraphGetTensorNumDims(graph, self._as_tf_output(),
+ status)
+ if num_dims == -1:
+ dim_list = None
+ else:
+ with errors.raise_exception_on_not_ok_status() as status:
+ dim_list = c_api.TF_GraphGetTensorShape_wrapper(
+ graph, self._as_tf_output(), num_dims, status)
+ dim_list = [None if i == -1 else i for i in dim_list]
+ return tensor_shape.TensorShape(dim_list)
return self._shape
def __iter__(self):
@@ -392,8 +406,8 @@ class Tensor(_TensorLike):
yield self[i]
def _shape_as_list(self):
- if self._shape.ndims is not None:
- return [dim.value for dim in self._shape.dims]
+ if self.shape.ndims is not None:
+ return [dim.value for dim in self.shape.dims]
else:
return None
@@ -409,7 +423,7 @@ class Tensor(_TensorLike):
Returns:
Integer rank or None
"""
- return self._shape.ndims
+ return self.shape.ndims
def get_shape(self):
"""Alias of Tensor.shape."""
@@ -440,14 +454,35 @@ class Tensor(_TensorLike):
```
Args:
- shape: A `TensorShape` representing the shape of this tensor.
+ shape: A `TensorShape` representing the shape of this tensor, a
+ `TensorShapeProto`, a list, a tuple, or None.
Raises:
ValueError: If `shape` is not compatible with the current shape of
this tensor.
"""
- # TODO(skyewm): call C API
- self._shape = self._shape.merge_with(shape)
+ if not _USE_C_API:
+ self._shape = self._shape.merge_with(shape) # pylint: disable=protected-access
+ return
+ if not isinstance(shape, tensor_shape.TensorShape):
+ shape = tensor_shape.TensorShape(shape)
+ dim_list = []
+ if shape.dims is None:
+ unknown_shape = True
+ else:
+ unknown_shape = False
+ for dim in shape.dims:
+ if dim.value is None:
+ dim_list.append(-1)
+ else:
+ dim_list.append(dim.value)
+ with errors.raise_exception_on_not_ok_status() as status:
+ c_api.TF_GraphSetTensorShape_wrapper(
+ self._op._graph._c_graph, # pylint: disable=protected-access
+ self._as_tf_output(),
+ dim_list,
+ unknown_shape,
+ status)
@property
def value_index(self):
@@ -598,11 +633,6 @@ class Tensor(_TensorLike):
"""
return _eval_using_default_session(self, feed_dict, self.graph, session)
- def _dup(self):
- ret = copy.copy(self)
- ret._id = uid() # pylint: disable=protected-access
- return ret
-
# TODO(agarwal): consider getting rid of this.
class _EagerTensorBase(Tensor):
@@ -728,9 +758,6 @@ class _EagerTensorBase(Tensor):
return new_tensor
# pylint: enable=protected-access
- def _dup(self):
- return self._copy(device_name=self.device)
-
@property
def shape(self):
return tensor_shape.TensorShape(self._shape_tuple())
@@ -1634,8 +1661,6 @@ class Operation(object):
self._id_value = self._graph._next_id() # pylint: disable=protected-access
self._recompute_node_def()
- self._graph._add_op(self) # pylint: disable=protected-access
-
def _reconstruct_sequence_inputs(self, op_def, inputs, attrs):
"""Regroups a flat list of input tensors into scalar and sequence inputs.
@@ -1795,7 +1820,7 @@ class Operation(object):
c_api.SetRequestedDevice(
self._graph._c_graph, # pylint: disable=protected-access
self._c_op, # pylint: disable=protected-access
- _device_string(device))
+ compat.as_str(_device_string(device)))
else:
self._node_def.device = _device_string(device)
@@ -2084,7 +2109,7 @@ class Operation(object):
def _set_attr(self, attr_name, attr_value):
"""Private method used to set an attribute in the node_def."""
- if _USE_C_API:
+ if self._c_op:
buf = c_api.TF_NewBufferFromString(
compat.as_bytes(attr_value.SerializeToString()))
try:
@@ -2653,11 +2678,16 @@ class Graph(object):
# TODO(skyewm): fold as much of the above as possible into the C
# implementation
- if _USE_C_API:
+ if _USE_C_API or self._use_c_api_hack():
self._scoped_c_graph = c_api_util.ScopedTFGraph()
else:
self._scoped_c_graph = None
+ # TODO(apassos) remove once the C API is used by default.
+ def _use_c_api_hack(self):
+ """Temporary hack; can be overridden to force C API usage."""
+ return False
+
def _convert_stack(self, stack, include_func_start_lineno=False):
"""Converts a stack extracted using _extract_stack() to a traceback stack.
@@ -2986,9 +3016,14 @@ class Graph(object):
# Add function to graph
# pylint: disable=protected-access
if self._c_graph:
- assert function._c_func, (
- "Cannot add function created without C API support to graph "
- "created with C API support")
+ # Handle functions created without using the C API.
+ # TODO(apassos,skyewm): remove this once all functions are generated with
+ # the C API by default, as it will then be unnecessary.
+ if not function._c_func:
+ with errors.raise_exception_on_not_ok_status() as status:
+ serialized = function.definition.SerializeToString()
+ function._c_func = c_api.TF_FunctionImportFunctionDef(
+ serialized, status)
with errors.raise_exception_on_not_ok_status() as status:
gradient = function._grad_func._c_func if function._grad_func else None
c_api.TF_GraphCopyFunction(self._c_graph, function._c_func, gradient,
@@ -3099,12 +3134,11 @@ class Graph(object):
input_types=input_types,
original_op=self._default_original_op,
op_def=op_def)
-
self._create_op_helper(ret, compute_shapes=compute_shapes,
compute_device=compute_device)
return ret
- def _create_op_from_tf_operation(self, c_op):
+ def _create_op_from_tf_operation(self, c_op, compute_device=True):
"""Creates an `Operation` in this graph from the supplied TF_Operation.
This method is like create_op() except the new Operation is constructed
@@ -3114,6 +3148,8 @@ class Graph(object):
Args:
c_op: a wrapped TF_Operation
+ compute_device: (Optional.) If True, device functions will be executed
+ to compute the device property of the Operation.
Returns:
An `Operation` object.
@@ -3124,7 +3160,7 @@ class Graph(object):
for output in tf_outputs)
control_inputs = self._control_dependencies_for_inputs(input_ops)
ret = Operation(c_op, self, control_inputs=control_inputs)
- self._create_op_helper(ret)
+ self._create_op_helper(ret, compute_device=compute_device)
return ret
def _create_op_helper(self, op, compute_shapes=True, compute_device=True):
@@ -3138,6 +3174,8 @@ class Graph(object):
# compute_shapes argument.
if op._c_op or compute_shapes: # pylint: disable=protected-access
set_shapes_for_outputs(op)
+ # TODO(b/XXXX): move to Operation.__init__ once _USE_C_API flag is removed.
+ self._add_op(op)
# Apply any additional attributes requested. Do not overwrite any existing
# attributes.
@@ -4517,15 +4555,11 @@ def control_dependencies(control_inputs):
See @{tf.Graph.control_dependencies}
for more details.
- When eager execution is enabled, any callable object in the `control_inputs`
- list will be called.
-
Args:
control_inputs: A list of `Operation` or `Tensor` objects which
must be executed or computed before running the operations
defined in the context. Can also be `None` to clear the control
- dependencies. If eager execution is enabled, any callable object in the
- `control_inputs` list will be called.
+ dependencies.
Returns:
A context manager that specifies control dependencies for all
@@ -4534,11 +4568,6 @@ def control_dependencies(control_inputs):
if context.in_graph_mode():
return get_default_graph().control_dependencies(control_inputs)
else:
- if control_inputs:
- # Excute any pending callables.
- for control in control_inputs:
- if callable(control):
- control()
return _NullContextmanager()
@@ -4794,6 +4823,16 @@ def enable_eager_execution(config=None, device_policy=None):
or if trying to create a context with nontrivial options which differ
from those of the existing context.
"""
+ if config is not None and not isinstance(config, config_pb2.ConfigProto):
+ raise TypeError(
+ "config must be a tf.ConfigProto, but got %s" % type(config))
+ if device_policy not in (None, context.DEVICE_PLACEMENT_EXPLICIT,
+ context.DEVICE_PLACEMENT_WARN,
+ context.DEVICE_PLACEMENT_SILENT):
+ raise ValueError(
+ "device_policy must be one of None, tfe.DEVICE_PLACEMENT_EXPLICIT, "
+ "tfe.DEVICE_PLACEMENT_WARN, tfe.DEVICE_PLACEMENT_SILENT"
+ )
# pylint: disable=protected-access
if context._default_mode == context.GRAPH_MODE:
graph_mode_has_been_used = (
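The set_shape path above converts a TensorShape into the dim list the C API expects: unknown dimensions become -1, and a wholly-unknown shape sets a flag. A small runnable sketch of that conversion (the helper name is illustrative):

from tensorflow.python.framework import tensor_shape

def shape_to_c_api_args(shape):
    # Mirrors the conversion in Tensor.set_shape above: -1 marks an unknown
    # dimension; a None dims list means the whole shape is unknown.
    shape = tensor_shape.TensorShape(shape)
    if shape.dims is None:
        return [], True                      # unknown_shape=True
    return [-1 if d.value is None else d.value for d in shape.dims], False

assert shape_to_c_api_args([2, None]) == ([2, -1], False)
assert shape_to_c_api_args(None) == ([], True)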
diff --git a/tensorflow/python/framework/ops_test.py b/tensorflow/python/framework/ops_test.py
index cd296ccdc5..b1ad6ad744 100644
--- a/tensorflow/python/framework/ops_test.py
+++ b/tensorflow/python/framework/ops_test.py
@@ -274,6 +274,7 @@ class OperationTest(test_util.TensorFlowTestCase):
op1 = ops.Operation(
ops._NodeDef("RefOutputFloatOutput", "op1"), g, [],
[dtypes.float32_ref, dtypes.float32])
+ g._add_op(op1)
self.assertProtoEquals("op:'RefOutputFloatOutput' name:'op1'", op1.node_def)
self.assertEquals([], list(op1.inputs))
ref_t, nonref_t = op1.values()
@@ -282,12 +283,14 @@ class OperationTest(test_util.TensorFlowTestCase):
ops._NodeDef("RefInputFloatInput", "op2"),
g, [ref_t, nonref_t], [],
input_types=[dtypes.float32_ref, dtypes.float32])
+ g._add_op(op2)
self.assertProtoEquals(
"op:'RefInputFloatInput' name:'op2' input:'op1' input:'op1:1'",
op2.node_def)
self.assertEquals([ref_t, nonref_t], list(op2.inputs))
op3 = ops.Operation(
ops._NodeDef("TwoFloatInputs", "op3"), g, [ref_t, nonref_t], [])
+ g._add_op(op3)
self.assertProtoEquals(
"op:'TwoFloatInputs' name:'op3' input:'op1' input:'op1:1'",
op3.node_def)
@@ -1537,7 +1540,7 @@ class ControlDependenciesTest(test_util.TensorFlowTestCase):
self.assertEqual(future.calls, 1)
else:
a = constant_op.constant(1.0)
- b = future
+ b = future()
with ops.control_dependencies([a, b]):
c = constant_op.constant(3.0)
self.assertEqual(future.calls, 1)
@@ -1876,6 +1879,24 @@ class GraphTest(test_util.TensorFlowTestCase):
gc.collect()
self.assertIsNone(g_ref())
+ def testRunnableAfterInvalidShape(self):
+ with ops.Graph().as_default():
+ with self.assertRaises(ValueError):
+ math_ops.add([1, 2], [1, 2, 3])
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ sess.run(a)
+
+ def testRunnableAfterInvalidShapeWithKernelLabelMap(self):
+ g = ops.Graph()
+ with g.as_default():
+ with g._kernel_label_map({"KernelLabelRequired": "overload_1"}):
+ with self.assertRaises(ValueError):
+ test_ops.kernel_label_required(1)
+ a = constant_op.constant(1)
+ with session.Session() as sess:
+ sess.run(a)
+
@test_util.with_c_api
class AttrScopeTest(test_util.TensorFlowTestCase):
@@ -2395,6 +2416,13 @@ class InputTypesTest(test_util.TensorFlowTestCase):
self.assertEqual([dtypes.double, dtypes.double], z.op._input_dtypes)
# pylint: enable=protected-access
+ def testBadArgumentsToEnableEagerExecution(self):
+ with self.assertRaisesRegexp(TypeError, "config must be a tf.ConfigProto"):
+ ops.enable_eager_execution(context.DEVICE_PLACEMENT_SILENT)
+ with self.assertRaisesRegexp(ValueError, "device_policy must be one of"):
+ c = config_pb2.ConfigProto()
+ ops.enable_eager_execution(c, c)
+
if __name__ == "__main__":
googletest.main()
diff --git a/tensorflow/python/framework/test_ops.cc b/tensorflow/python/framework/test_ops.cc
index 25bb7af20c..dbabce0962 100644
--- a/tensorflow/python/framework/test_ops.cc
+++ b/tensorflow/python/framework/test_ops.cc
@@ -26,6 +26,16 @@ REGISTER_OP("KernelLabel")
.Output("result: string")
.SetShapeFn(shape_inference::ScalarShape);
+REGISTER_OP("KernelLabelRequired")
+ .Input("input: int32")
+ .Output("result: string")
+ .SetShapeFn([](shape_inference::InferenceContext* c) {
+ shape_inference::ShapeHandle out;
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &out));
+ c->set_output(0, c->Scalar());
+ return Status::OK();
+ });
+
REGISTER_OP("GraphDefVersion")
.Output("version: int32")
.SetIsStateful()
@@ -104,6 +114,14 @@ REGISTER_KERNEL_BUILDER(Name("KernelLabel")
.Label("overload_2"),
KernelLabelOp<OVERLOAD_2_LABEL>);
+// All "KernelLabelRequired" kernels have labels
+REGISTER_KERNEL_BUILDER(
+ Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_1"),
+ KernelLabelOp<OVERLOAD_1_LABEL>);
+REGISTER_KERNEL_BUILDER(
+ Name("KernelLabelRequired").Device(DEVICE_CPU).Label("overload_2"),
+ KernelLabelOp<OVERLOAD_2_LABEL>);
+
class GraphDefVersionOp : public OpKernel {
public:
explicit GraphDefVersionOp(OpKernelConstruction* ctx)
diff --git a/tensorflow/python/grappler/item.i b/tensorflow/python/grappler/item.i
index 7dd79f7c82..8f72a425c3 100644
--- a/tensorflow/python/grappler/item.i
+++ b/tensorflow/python/grappler/item.i
@@ -120,7 +120,7 @@ static PyObject* TF_GetOpProperties(GItem item) {
Py_RETURN_NONE;
}
tensorflow::grappler::GraphProperties properties(*item);
- tensorflow::Status status = properties.InferStatically();
+ tensorflow::Status status = properties.InferStatically(false);
if (!status.ok()) {
Py_RETURN_NONE;
}
diff --git a/tensorflow/python/grappler/layout_optimizer_test.py b/tensorflow/python/grappler/layout_optimizer_test.py
index 626e0502cb..50735fb567 100644
--- a/tensorflow/python/grappler/layout_optimizer_test.py
+++ b/tensorflow/python/grappler/layout_optimizer_test.py
@@ -190,7 +190,7 @@ class LayoutOptimizerTest(test.TestCase):
self.assertEqual(expected_num_transposes, num_transposes)
self.assertIn('LayoutOptimizerTransposeNHWCToNCHW-Conv2D-Reshape-0',
nodes)
- self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1',
+ self.assertIn('LayoutOptimizerTransposeNCHWToNHWC-Relu_1-MaxPool_1-0',
nodes)
self.assertAllClose(output_val_ref, output_val, atol=1e-3)
diff --git a/tensorflow/python/grappler/model_analyzer.cc b/tensorflow/python/grappler/model_analyzer.cc
index 7d365c3be9..da5b03234e 100644
--- a/tensorflow/python/grappler/model_analyzer.cc
+++ b/tensorflow/python/grappler/model_analyzer.cc
@@ -27,7 +27,7 @@ ModelAnalyzer::ModelAnalyzer(const GrapplerItem& item) : item_(item) {}
Status ModelAnalyzer::GenerateReport(std::ostream& os) {
GraphProperties properties(item_);
- TF_RETURN_IF_ERROR(properties.InferStatically());
+ TF_RETURN_IF_ERROR(properties.InferStatically(false));
for (const auto& node : item_.MainOpsFanin()) {
PrintNodeInfo(node, properties, os);
diff --git a/tensorflow/python/keras/_impl/keras/callbacks_test.py b/tensorflow/python/keras/_impl/keras/callbacks_test.py
index 9c17fbb4a7..79dfcd1bb6 100644
--- a/tensorflow/python/keras/_impl/keras/callbacks_test.py
+++ b/tensorflow/python/keras/_impl/keras/callbacks_test.py
@@ -685,8 +685,8 @@ class KerasCallbacksTest(test.TestCase):
# fit w/o validation data should raise ValueError if histogram_freq > 0
cbs = callbacks_factory(histogram_freq=1)
with self.assertRaises(ValueError):
- model.fit(x_train, y_train, batch_size=BATCH_SIZE,
- callbacks=cbs, epochs=3)
+ model.fit(
+ x_train, y_train, batch_size=BATCH_SIZE, callbacks=cbs, epochs=3)
for cb in cbs:
cb.on_train_end()
@@ -695,8 +695,8 @@ class KerasCallbacksTest(test.TestCase):
# histogram_freq > 0
cbs = callbacks_factory(histogram_freq=1)
with self.assertRaises(ValueError):
- model.fit_generator(data_generator(True), len(x_train), epochs=2,
- callbacks=cbs)
+ model.fit_generator(
+ data_generator(True), len(x_train), epochs=2, callbacks=cbs)
for cb in cbs:
cb.on_train_end()
@@ -705,10 +705,13 @@ class KerasCallbacksTest(test.TestCase):
# histogram_freq > 0
cbs = callbacks_factory(histogram_freq=1)
with self.assertRaises(ValueError):
- model.fit_generator(data_generator(True), len(x_train), epochs=2,
- validation_data=data_generator(False),
- validation_steps=1,
- callbacks=cbs)
+ model.fit_generator(
+ data_generator(True),
+ len(x_train),
+ epochs=2,
+ validation_data=data_generator(False),
+ validation_steps=1,
+ callbacks=cbs)
for cb in cbs:
cb.on_train_end()
diff --git a/tensorflow/python/keras/_impl/keras/utils/io_utils.py b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
index 2003e19a0a..a8fc18c17a 100644
--- a/tensorflow/python/keras/_impl/keras/utils/io_utils.py
+++ b/tensorflow/python/keras/_impl/keras/utils/io_utils.py
@@ -78,7 +78,7 @@ class HDF5Matrix(object):
def __len__(self):
return self.end - self.start
- def __del__(self):
+ def __del__(self):
self._f.close()
def __getitem__(self, key):
diff --git a/tensorflow/python/kernel_tests/BUILD b/tensorflow/python/kernel_tests/BUILD
index 2ec162578c..f6721de32a 100644
--- a/tensorflow/python/kernel_tests/BUILD
+++ b/tensorflow/python/kernel_tests/BUILD
@@ -676,6 +676,7 @@ cuda_py_test(
"//tensorflow/python:gradients",
"//tensorflow/python:state_ops",
"//tensorflow/python:variables",
+ "//tensorflow/python:resource_variable_ops",
],
tags = ["noasan"], # http://b/32635055
)
diff --git a/tensorflow/python/kernel_tests/constant_op_eager_test.py b/tensorflow/python/kernel_tests/constant_op_eager_test.py
index 3b71586b55..8e9d75667d 100644
--- a/tensorflow/python/kernel_tests/constant_op_eager_test.py
+++ b/tensorflow/python/kernel_tests/constant_op_eager_test.py
@@ -237,6 +237,39 @@ class ConstantTest(test.TestCase):
self._testAll((1, x))
self._testAll((x, 1))
+ def testInvalidLength(self):
+
+ class BadList(list):
+
+ def __init__(self):
+        super(BadList, self).__init__([1, 2, 3])
+
+ def __len__(self):
+        return -1  # pylint: disable=invalid-length-returned
+
+ with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ constant_op.constant([BadList()])
+ with self.assertRaisesRegexp(ValueError, "mixed types"):
+ constant_op.constant([1, 2, BadList()])
+ with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ constant_op.constant(BadList())
+ with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ constant_op.constant([[BadList(), 2], 3])
+ with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ constant_op.constant([BadList(), [1, 2, 3]])
+ with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ constant_op.constant([BadList(), []])
+
+    # TODO(allenl, josh11b): These cases should raise exceptions rather than
+    # succeed (currently shape checking only inspects the first element of
+    # each sequence recursively). Maybe the first one is fine, but the
+    # second one, which silently truncates, is rather bad.
+
+ # with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ # constant_op.constant([[3, 2, 1], BadList()])
+ # with self.assertRaisesRegexp(ValueError, "should return >= 0"):
+ # constant_op.constant([[], BadList()])
+
def testSparseValuesRaiseErrors(self):
with self.assertRaisesRegexp(ValueError, "non-rectangular Python sequence"):
constant_op.constant([[1, 2], [3]], dtype=dtypes_lib.int32)
diff --git a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
index 1679857bd5..be299beee4 100644
--- a/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
+++ b/tensorflow/python/kernel_tests/conv2d_backprop_filter_grad_test.py
@@ -42,17 +42,21 @@ class Conv2DBackpropFilterGradTest(test.TestCase):
filter_shape = [3, 3, 4, 6]
# Make a convolution op with the current settings, just to easily get
# the shape of the output.
- conv_out = nn_ops.conv2d(in_val,
- array_ops.zeros(filter_shape),
- [1, stride, stride, 1], padding)
+ conv_out = nn_ops.conv2d(
+ in_val,
+ array_ops.zeros(filter_shape),
+ strides=[1, stride, stride, 1],
+ padding=padding)
out_backprop_shape = conv_out.get_shape().as_list()
out_backprop_val = constant_op.constant(
2 * np.random.random_sample(out_backprop_shape) - 1,
dtype=dtypes.float32)
- output = nn_ops.conv2d_backprop_filter(in_val, filter_shape,
- out_backprop_val,
- [1, stride, stride, 1],
- padding)
+ output = nn_ops.conv2d_backprop_filter(
+ in_val,
+ filter_shape,
+ out_backprop_val,
+ strides=[1, stride, stride, 1],
+ padding=padding)
err = gradient_checker.compute_gradient_error(
[in_val, out_backprop_val], [in_shape, out_backprop_shape],
output, filter_shape)
@@ -60,6 +64,42 @@ class Conv2DBackpropFilterGradTest(test.TestCase):
err_tolerance = 2e-3
self.assertLess(err, err_tolerance)
+ def testGradientDilatedConv(self):
+ if test.is_gpu_available(cuda_only=True):
+ with self.test_session(use_gpu=True):
+ for padding in ["SAME", "VALID"]:
+ for stride in [1, 2]:
+ np.random.seed(1)
+ in_shape = [5, 8, 6, 4]
+ in_val = constant_op.constant(
+ 2 * np.random.random_sample(in_shape) - 1, dtype=dtypes.float32)
+ filter_shape = [3, 3, 4, 6]
+ # Make a convolution op with the current settings,
+ # just to easily get the shape of the output.
+ conv_out = nn_ops.conv2d(
+ in_val,
+ array_ops.zeros(filter_shape),
+ dilations=[1, 2, 2, 1],
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ out_backprop_shape = conv_out.get_shape().as_list()
+ out_backprop_val = constant_op.constant(
+ 2 * np.random.random_sample(out_backprop_shape) - 1,
+ dtype=dtypes.float32)
+ output = nn_ops.conv2d_backprop_filter(
+ in_val,
+ filter_shape,
+ out_backprop_val,
+ dilations=[1, 2, 2, 1],
+ strides=[1, stride, stride, 1],
+ padding=padding)
+ err = gradient_checker.compute_gradient_error(
+ [in_val, out_backprop_val], [in_shape, out_backprop_shape],
+ output, filter_shape)
+ print("conv2d_backprop_filter gradient err = %g " % err)
+ err_tolerance = 2e-3
+ self.assertLess(err, err_tolerance)
+
if __name__ == "__main__":
test.main()
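
For readers skimming the test above: the essential new surface is the `dilations` keyword on `conv2d_backprop_filter`, mirroring `conv2d`. A condensed sketch (graph construction only; running it assumes a GPU, since the CPU kernels lack arbitrary-dilation support at this point in the tree):

    import numpy as np
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import nn_ops

    in_val = constant_op.constant(
        np.random.random_sample([5, 8, 6, 4]), dtype=dtypes.float32)
    # A 3x3 filter dilated by 2 has an effective extent of 5x5, so VALID
    # output spatial dims are (8 - 5 + 1) x (6 - 5 + 1) = 4 x 2.
    out_backprop = constant_op.constant(
        np.random.random_sample([5, 4, 2, 6]), dtype=dtypes.float32)
    grad_w = nn_ops.conv2d_backprop_filter(
        in_val, [3, 3, 4, 6], out_backprop,
        dilations=[1, 2, 2, 1], strides=[1, 1, 1, 1], padding="VALID")
    # grad_w has the filter's shape: [3, 3, 4, 6].
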
diff --git a/tensorflow/python/kernel_tests/conv_ops_test.py b/tensorflow/python/kernel_tests/conv_ops_test.py
index 22e5400c37..bf7245a2ae 100644
--- a/tensorflow/python/kernel_tests/conv_ops_test.py
+++ b/tensorflow/python/kernel_tests/conv_ops_test.py
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import collections
import os
import time
@@ -32,6 +33,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
+from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import nn_impl
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import random_ops
@@ -240,6 +242,77 @@ class Conv2DTest(test.TestCase):
for i in range(1, len(values)):
self.assertAllClose(values[0], values[i], rtol=1e-5, atol=1e-5)
+ def _ComputeReferenceDilatedConv(self, tensor_in_sizes, filter_in_sizes,
+ stride, dilation, padding, data_format,
+ use_gpu):
+ total_size_1 = 1
+ total_size_2 = 1
+ for s in tensor_in_sizes:
+ total_size_1 *= s
+ for s in filter_in_sizes:
+ total_size_2 *= s
+
+    # Initializes the input and filter tensors with arrays containing
+    # incrementing numbers starting from 1.
+ x1 = [f * 1.0 for f in range(1, total_size_1 + 1)]
+ x2 = [f * 1.0 for f in range(1, total_size_2 + 1)]
+ with test_util.device(use_gpu):
+ t1 = constant_op.constant(x1, shape=tensor_in_sizes)
+ t2 = constant_op.constant(x2, shape=filter_in_sizes)
+ if isinstance(stride, collections.Iterable):
+ strides = list(stride)
+ else:
+ strides = [stride, stride]
+ if data_format == "NCHW":
+ t1 = test_util.NHWCToNCHW(t1)
+ full_strides = [1, 1] + strides
+ full_dilation = [1, 1] + dilation
+ else:
+ full_strides = [1] + strides + [1]
+ full_dilation = [1] + dilation + [1]
+ expected = nn_ops.convolution(
+ t1,
+ t2,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilation,
+ data_format=data_format)
+ computed = nn_ops.conv2d(
+ t1,
+ t2,
+ strides=full_strides,
+ dilations=full_dilation,
+ padding=padding,
+ data_format=data_format)
+ if data_format == "NCHW":
+ expected = test_util.NCHWToNHWC(expected)
+ computed = test_util.NCHWToNHWC(computed)
+ return expected, computed
+
+ def _VerifyDilatedConvValues(self, tensor_in_sizes, filter_in_sizes, strides,
+ padding, dilations):
+ expected_results = []
+ computed_results = []
+ default_dilations = (dilations[0] == 1 and dilations[1] == 1)
+ for data_format, use_gpu in GetTestConfigs():
+      # If any dilation rate is larger than 1, only run the test on the GPU,
+      # since there is currently no CPU implementation for arbitrary
+      # dilation rates.
+ if default_dilations or use_gpu:
+ expected, computed = self._ComputeReferenceDilatedConv(
+ tensor_in_sizes, filter_in_sizes, strides, dilations, padding,
+ data_format, use_gpu)
+ expected_results.append(expected)
+ computed_results.append(computed)
+ tolerance = 1e-2 if use_gpu else 1e-5
+ expected_values = self.evaluate(expected_results)
+ computed_values = self.evaluate(computed_results)
+ for e_value, c_value in zip(expected_values, computed_values):
+ print("expected = ", e_value)
+ print("actual = ", c_value)
+ self.assertAllClose(
+ e_value.flatten(), c_value.flatten(), atol=tolerance, rtol=1e-6)
+
def _VerifyValues(self, tensor_in_sizes, filter_in_sizes, strides, padding,
expected):
tensors = []
@@ -280,6 +353,16 @@ class Conv2DTest(test.TestCase):
expected=expected_output)
@test_util.run_in_graph_and_eager_modes()
+ def testConv2D2x2Filter2x1Dilation(self):
+ if test.is_gpu_available(cuda_only=True):
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=[1, 4, 4, 1],
+ filter_in_sizes=[2, 2, 1, 1],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID")
+
+ @test_util.run_in_graph_and_eager_modes()
def testConv2DEmpty(self):
expected_output = []
self._VerifyValues(
@@ -290,6 +373,16 @@ class Conv2DTest(test.TestCase):
expected=expected_output)
@test_util.run_in_graph_and_eager_modes()
+ def testConv2DEmptyDilation(self):
+ if test.is_gpu_available(cuda_only=True):
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=[0, 2, 3, 3],
+ filter_in_sizes=[1, 1, 3, 3],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID")
+
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [2271.0, 2367.0, 2463.0, 2901.0, 3033.0, 3165.0]
@@ -301,6 +394,16 @@ class Conv2DTest(test.TestCase):
expected=expected_output)
@test_util.run_in_graph_and_eager_modes()
+ def testConv2D2x2FilterDilation(self):
+ if test.is_gpu_available(cuda_only=True):
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[2, 2, 3, 3],
+ strides=[1, 1],
+ dilations=[1, 2],
+ padding="VALID")
+
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D1x2Filter(self):
# The outputs are computed using third_party/py/IPython/notebook.
expected_output = [
@@ -315,6 +418,16 @@ class Conv2DTest(test.TestCase):
expected=expected_output)
@test_util.run_in_graph_and_eager_modes()
+ def testConv2D1x2FilterDilation(self):
+ if test.is_gpu_available(cuda_only=True):
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=[1, 2, 3, 3],
+ filter_in_sizes=[1, 2, 3, 3],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID")
+
+ @test_util.run_in_graph_and_eager_modes()
def testConv2D2x2FilterStride2(self):
expected_output = [2271.0, 2367.0, 2463.0]
self._VerifyValues(
@@ -386,13 +499,23 @@ class Conv2DTest(test.TestCase):
padding="VALID",
expected=[50, 60])
- # TODO this currently fails.
- # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
- # filter_in_sizes=[2, 2, 1, 1],
- # strides=[4, 4], padding="SAME",
- # expected=[72, 112, 392, 432])
+ @test_util.run_in_graph_and_eager_modes()
+ def testConv2DKernelSizeMatchesInputSizeDilation(self):
+ if test.is_gpu_available(cuda_only=True):
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=[1, 3, 3, 1],
+ filter_in_sizes=[2, 2, 1, 2],
+ strides=[1, 1],
+ dilations=[2, 2],
+ padding="VALID")
+
+ # TODO this currently fails.
+ # self._VerifyValues(tensor_in_sizes=[1, 8, 8, 1],
+ # filter_in_sizes=[2, 2, 1, 1],
+ # strides=[4, 4], padding="SAME",
+ # expected=[72, 112, 392, 432])
- # Testing for backprops
+ # Testing for backprops
def _RunAndVerifyBackpropInput(self, input_sizes, filter_sizes, output_sizes,
strides, padding, expected, data_format,
use_gpu, err):
@@ -724,6 +847,255 @@ class Conv2DTest(test.TestCase):
data_format=data_format,
use_gpu=use_gpu)
+ # Testing for backprops
+ def _RunAndVerifyBackpropInputDilation(self, input_sizes, filter_sizes,
+ output_sizes, strides, dilations,
+ padding, data_format, use_gpu, err):
+ total_input_size = 1
+ total_filter_size = 1
+ for s in input_sizes:
+ total_input_size *= s
+ for s in filter_sizes:
+ total_filter_size *= s
+    # Initializes the input and filter tensors with arrays containing
+    # incrementing numbers starting from 1.
+ x1 = [f * 1.0 for f in range(1, total_input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, total_filter_size + 1)]
+ default_dilations = (dilations[0] == 1 and dilations[1] == 1)
+ if default_dilations or use_gpu:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ if data_format == "NCHW":
+ input_sizes = test_util.NHWCToNCHW(input_sizes)
+ t1 = constant_op.constant(x1, shape=input_sizes)
+ t2 = constant_op.constant(x2, shape=filter_sizes)
+ full_strides = [1] + strides + [1]
+ full_dilations = [1] + dilations + [1]
+ if data_format == "NCHW":
+ full_strides = test_util.NHWCToNCHW(full_strides)
+ full_dilations = test_util.NHWCToNCHW(full_dilations)
+ conv_forward = nn_ops.conv2d(
+ t1,
+ t2,
+ strides=full_strides,
+ dilations=full_dilations,
+ padding=padding,
+ data_format=data_format)
+ conv_forward_2 = nn_ops.convolution(
+ t1,
+ t2,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilations,
+ data_format=data_format)
+ if data_format == "NCHW":
+ conv_forward = test_util.NCHWToNHWC(conv_forward)
+ conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2)
+ conv = gradients_impl.gradients(conv_forward, t1)[0]
+ conv_2 = gradients_impl.gradients(conv_forward_2, t1)[0]
+        # Evaluate the two backprop tensors so their values can be compared.
+ value = sess.run(conv)
+ value_2 = sess.run(conv_2)
+ self.assertShapeEqual(value, conv)
+ self.assertShapeEqual(value_2, conv_2)
+ print("expected = ", value_2)
+ print("actual = ", value)
+ self.assertArrayNear(value_2.flatten(), value.flatten(), err)
+
+ # Testing for backprops
+ def _RunAndVerifyBackpropFilterDilation(self, input_sizes, filter_sizes,
+ output_sizes, strides, dilations,
+ padding, data_format, use_gpu, err):
+ total_input_size = 1
+ total_filter_size = 1
+ for s in input_sizes:
+ total_input_size *= s
+ for s in filter_sizes:
+ total_filter_size *= s
+    # Initializes the input and filter tensors with arrays containing
+    # incrementing numbers starting from 1.
+ x1 = [f * 1.0 for f in range(1, total_input_size + 1)]
+ x2 = [f * 1.0 for f in range(1, total_filter_size + 1)]
+ default_dilations = (dilations[0] == 1 and dilations[1] == 1)
+ if default_dilations or use_gpu:
+ with self.test_session(use_gpu=use_gpu) as sess:
+ if data_format == "NCHW":
+ input_sizes = test_util.NHWCToNCHW(input_sizes)
+ t1 = constant_op.constant(x1, shape=input_sizes)
+ t2 = constant_op.constant(x2, shape=filter_sizes)
+ full_strides = [1] + strides + [1]
+ full_dilations = [1] + dilations + [1]
+ if data_format == "NCHW":
+ full_strides = test_util.NHWCToNCHW(full_strides)
+ full_dilations = test_util.NHWCToNCHW(full_dilations)
+ conv_forward = nn_ops.conv2d(
+ t1,
+ t2,
+ strides=full_strides,
+ dilations=full_dilations,
+ padding=padding,
+ data_format=data_format)
+ conv_forward_2 = nn_ops.convolution(
+ t1,
+ t2,
+ padding=padding,
+ strides=strides,
+ dilation_rate=dilations,
+ data_format=data_format)
+ if data_format == "NCHW":
+ conv_forward = test_util.NCHWToNHWC(conv_forward)
+ conv_forward_2 = test_util.NCHWToNHWC(conv_forward_2)
+ conv = gradients_impl.gradients(conv_forward, t2)[0]
+        conv_2 = gradients_impl.gradients(conv_forward_2, t2)[0]
+ value = sess.run(conv)
+ value_2 = sess.run(conv_2)
+ self.assertShapeEqual(value, conv)
+ self.assertShapeEqual(value_2, conv_2)
+ print("expected = ", value_2)
+ print("actual = ", value)
+ self.assertArrayNear(value_2.flatten(), value.flatten(), err)
+
+ def testConv2D2x2Depth3ValidBackpropFilterStride1x1Dilation2x1(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilterDilation(
+ input_sizes=[1, 3, 6, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 5, 1],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2D2x2Depth1ValidBackpropFilterDilation1x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilterDilation(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ strides=[1, 1],
+ dilations=[1, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2DEmptyBackpropFilterDilation1x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilterDilation(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 0],
+ output_sizes=[1, 1, 2, 0],
+ strides=[1, 1],
+ dilations=[1, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2D2x2Depth3ValidBackpropFilterDilation2x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilterDilation(
+ input_sizes=[1, 3, 4, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ strides=[1, 1],
+ dilations=[2, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2DKernelSizeMatchesInputSizeBackpropFilterDilation2x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropFilterDilation(
+ input_sizes=[1, 3, 3, 1],
+ filter_sizes=[2, 2, 1, 2],
+ output_sizes=[1, 1, 1, 2],
+ strides=[1, 1],
+ dilations=[2, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2D2x2Depth3ValidBackpropInputStride1x1Dilation2x1(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropInputDilation(
+ input_sizes=[1, 3, 6, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 5, 1],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2D2x2Depth1ValidBackpropInputDilation1x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropInputDilation(
+ input_sizes=[1, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[1, 1, 2, 1],
+ strides=[1, 1],
+ dilations=[1, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2DEmptyBackpropInputDilation1x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropInputDilation(
+ input_sizes=[0, 2, 3, 1],
+ filter_sizes=[2, 2, 1, 1],
+ output_sizes=[0, 1, 2, 1],
+ strides=[1, 1],
+ dilations=[1, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
+ def testConv2D2x2Depth3ValidBackpropInputDilation2x1(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+        # The GPU version of this test is not very stable, so the error
+        # threshold is raised to 1e-4.
+ self._RunAndVerifyBackpropInputDilation(
+ input_sizes=[1, 3, 2, 3],
+ filter_sizes=[2, 2, 3, 3],
+ output_sizes=[1, 1, 2, 3],
+ strides=[1, 1],
+ dilations=[2, 1],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-4)
+
+ def testConv2DKernelSizeMatchesInputSizeBackpropInputDilation2x2(self):
+ if test.is_gpu_available(cuda_only=True):
+ for (data_format, use_gpu) in GetTestConfigs():
+ self._RunAndVerifyBackpropInputDilation(
+ input_sizes=[1, 3, 3, 1],
+ filter_sizes=[2, 2, 1, 2],
+ output_sizes=[1, 1, 1, 2],
+ strides=[1, 1],
+ dilations=[2, 2],
+ padding="VALID",
+ data_format=data_format,
+ use_gpu=use_gpu,
+ err=1e-5)
+
# Gradient checkers
def ConstructAndTestGradient(self, batch, input_rows, input_cols, filter_rows,
filter_cols, in_depth, out_depth, stride_rows,
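
The dilation tests above all lean on one identity: for NHWC input, `conv2d` with `dilations=[1, dh, dw, 1]` should agree with `convolution` with `dilation_rate=[dh, dw]`. A standalone check of that identity (running it with dilation > 1 assumes a GPU, per the comment in `_VerifyDilatedConvValues`):

    import numpy as np
    from tensorflow.python.client import session
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import nn_ops

    x = constant_op.constant(
        np.arange(16.0).reshape([1, 4, 4, 1]), dtype=dtypes.float32)
    w = constant_op.constant(np.ones([2, 2, 1, 1]), dtype=dtypes.float32)
    a = nn_ops.conv2d(x, w, strides=[1, 1, 1, 1], dilations=[1, 2, 2, 1],
                      padding="VALID")
    b = nn_ops.convolution(x, w, strides=[1, 1], dilation_rate=[2, 2],
                           padding="VALID")
    with session.Session() as sess:
        av, bv = sess.run([a, b])
    np.testing.assert_allclose(av, bv, rtol=1e-6)  # both are [1, 2, 2, 1]
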
@@ -1457,6 +1829,22 @@ def GetInceptionFwdTest(input_size, filter_size, stride, padding,
return Test
+def GetInceptionFwdDilatedConvTest(input_size, filter_size, stride, padding):
+
+ def Test(self):
+ if test.is_gpu_available(cuda_only=True) and stride == 1:
+ tf_logging.info("Testing InceptionFwd with dilations %s",
+ (input_size, filter_size, stride, padding))
+ self._VerifyDilatedConvValues(
+ tensor_in_sizes=input_size,
+ filter_in_sizes=filter_size,
+ strides=[stride, stride],
+ dilations=[2, 2],
+ padding=padding)
+
+ return Test
+
+
def GetInceptionBackInputTest(input_size, filter_size, output_size, stride,
padding,
gpu_only=False):
@@ -1497,6 +1885,10 @@ if __name__ == "__main__":
test_util.run_in_graph_and_eager_modes()(
GetInceptionFwdTest(input_size_, filter_size_, stride_,
padding_)))
+ setattr(
+ Conv2DTest, "testInceptionFwdDilatedConv_" + str(index),
+ test_util.run_in_graph_and_eager_modes()(GetInceptionFwdDilatedConvTest(
+ input_size_, filter_size_, stride_, padding_)))
setattr(Conv2DTest, "testInceptionBackInput_" + str(index),
test_util.run_in_graph_and_eager_modes()(
GetInceptionBackInputTest(input_size_, filter_size_,
@@ -1519,6 +1911,9 @@ if __name__ == "__main__":
setattr(Conv2DTest, "testInceptionFwd_No_Winograd_Nonfused",
test_util.run_in_graph_and_eager_modes()(
GetInceptionFwdTest(ishape, fshape, 1, "SAME", gpu_only=True)))
+ setattr(Conv2DTest, "testInceptionFwdDilatedConv_No_Winograd_Nonfused",
+ test_util.run_in_graph_and_eager_modes()(
+ GetInceptionFwdDilatedConvTest(ishape, fshape, 1, "SAME")))
setattr(Conv2DTest, "testInceptionBackInput_No_Winograd_Nonfused",
test_util.run_in_graph_and_eager_modes()(
GetInceptionBackInputTest(ishape, fshape, oshape, 1, "SAME",
diff --git a/tensorflow/python/kernel_tests/decode_bmp_op_test.py b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
index c086f46170..c67c26b7be 100644
--- a/tensorflow/python/kernel_tests/decode_bmp_op_test.py
+++ b/tensorflow/python/kernel_tests/decode_bmp_op_test.py
@@ -68,28 +68,68 @@ class DecodeBmpOpTest(test.TestCase):
def testGrayscale(self):
img_bytes = [[[255], [0]], [[255], [0]]]
encoded_bytes = [
- 0x42, 0x40,
- 0x3d, 0, 0, 0,
- 0, 0,
- 0, 0,
- 0x36, 0, 0, 0,
- 0x28, 0, 0, 0,
- 0x2, 0, 0, 0,
- 0x2, 0, 0, 0,
- 0x1, 0,
- 0x8, 0,
- 0, 0, 0, 0,
- 0x10, 0, 0, 0,
- 0x13, 0xb, 0, 0,
- 0x13, 0xb, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
+ 0x42,
+ 0x40,
+ 0x3d,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0x36,
+ 0,
+ 0,
+ 0,
+ 0x28,
+ 0,
+ 0,
+ 0,
+ 0x2,
+ 0,
+ 0,
+ 0,
+ 0x2,
+ 0,
+ 0,
+ 0,
+ 0x1,
+ 0,
+ 0x8,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0x10,
+ 0,
+ 0,
+ 0,
+ 0x13,
+ 0xb,
+ 0,
+ 0,
+ 0x13,
+ 0xb,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
+ 0,
0xff,
0,
- 0, 0,
+ 0,
+ 0,
0xff,
0,
- 0, 0,
+ 0,
+ 0,
]
byte_string = bytes(bytearray(encoded_bytes))
@@ -100,54 +140,6 @@ class DecodeBmpOpTest(test.TestCase):
decoded = decode.eval()
self.assertAllEqual(decoded, img_bytes)
- def testIncompleteHeader(self):
- # Encoded BMP bytes from Wikipedia
- encoded_bytes = [
- 0x42, 0x40,
- 0x46, 0, 0, 0,
- ]
-
- byte_string = bytes(bytearray(encoded_bytes))
- img_in = constant_op.constant(byte_string, dtype=dtypes.string)
- decode = array_ops.squeeze(image_ops.decode_bmp(img_in))
-
- with self.test_session():
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "requires at least 32 bytes to find the header"):
- decoded = decode.eval()
-
- def testIncompleteBody(self):
- # Encoded BMP bytes from Wikipedia
- encoded_bytes = [
- 0x42, 0x40,
- 0x46, 0, 0, 0,
- 0, 0,
- 0, 0,
- 0x36, 0, 0, 0,
- 0x28, 0, 0, 0,
- 0x2, 0, 0, 0,
- 0x2, 0, 0, 0,
- 0x1, 0,
- 0x18, 0,
- 0, 0, 0, 0,
- 0x10, 0, 0, 0,
- 0x13, 0xb, 0, 0,
- 0x13, 0xb, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0, 0,
- 0, 0, 0xff,
- 0xff, 0xff, 0xff,
- 0, 0,
- ]
-
- byte_string = bytes(bytearray(encoded_bytes))
- img_in = constant_op.constant(byte_string, dtype=dtypes.string)
- decode = array_ops.squeeze(image_ops.decode_bmp(img_in))
-
- with self.test_session():
- with self.assertRaisesRegexp(errors_impl.InvalidArgumentError,
- "requires at least 68 bytes, got 62 bytes"):
- decoded = decode.eval()
if __name__ == "__main__":
test.main()
diff --git a/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py
index edea9c9027..646324cb95 100644
--- a/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py
+++ b/tensorflow/python/kernel_tests/prefetch_dataset_op_test.py
@@ -25,10 +25,11 @@ from tensorflow.python.platform import test
class PrefetchDatasetTest(test.TestCase):
+
def testBufferSize(self):
buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size).make_initializable_iterator()
init_op = iterator.initializer
get_next = iterator.get_next()
@@ -42,7 +43,7 @@ class PrefetchDatasetTest(test.TestCase):
def testInvalidBufferSize(self):
buffer_size = array_ops.placeholder(dtypes.int64, shape=[])
iterator = dataset_ops.Dataset.range(10).prefetch(
- buffer_size=buffer_size).make_initializable_iterator()
+ buffer_size=buffer_size).make_initializable_iterator()
init_op = iterator.initializer
with self.assertRaisesRegexp(errors.InvalidArgumentError, "buffer_size"):
diff --git a/tensorflow/python/kernel_tests/random/multinomial_op_test.py b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
index ca48ba6cad..a9dc7b7de0 100644
--- a/tensorflow/python/kernel_tests/random/multinomial_op_test.py
+++ b/tensorflow/python/kernel_tests/random/multinomial_op_test.py
@@ -57,12 +57,14 @@ class MultinomialTest(test.TestCase):
@test_util.run_in_graph_and_eager_modes()
def testSmallEntropy(self):
random_seed.set_random_seed(1618)
- with test_util.device(use_gpu=True):
- # A logit value of -10 corresponds to a probability of ~5e-5.
- logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]])
- num_samples = 1000
- samples = self.evaluate(random_ops.multinomial(logits, num_samples))
- self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
+ for output_dtype in [np.int32, np.int64]:
+ with test_util.device(use_gpu=True):
+ # A logit value of -10 corresponds to a probability of ~5e-5.
+ logits = constant_op.constant([[-10., 10., -10.], [-10., -10., 10.]])
+ num_samples = 1000
+ samples = self.evaluate(random_ops.multinomial(
+ logits, num_samples, output_dtype=output_dtype))
+ self.assertAllEqual([[1] * num_samples, [2] * num_samples], samples)
def testOneOpMultipleStepsIndependent(self):
with self.test_session(use_gpu=True) as sess:
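
The reworked test exercises the new `output_dtype` argument of `multinomial`; a minimal sketch (int64 is assumed to remain the default):

    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import random_ops

    logits = [[-10., 10., -10.]]  # one row of unnormalized log-probabilities
    samples = random_ops.multinomial(
        logits, num_samples=5, output_dtype=dtypes.int32)
    # samples is a [1, 5] int32 tensor of class indices drawn per row.
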
diff --git a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
index a79d66e988..d7bde04230 100644
--- a/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
+++ b/tensorflow/python/kernel_tests/scatter_nd_ops_test.py
@@ -27,6 +27,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradients_impl
+from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
@@ -157,6 +158,20 @@ class StatefulScatterNdTest(test.TestCase):
result = sess.run(scatter)
self.assertAllClose(result, expected)
+ def testSimpleResource(self):
+ indices = constant_op.constant([[4], [3], [1], [7]], dtype=dtypes.int32)
+ updates = constant_op.constant([9, 10, 11, 12], dtype=dtypes.float32)
+ ref = resource_variable_ops.ResourceVariable(
+ [0, 0, 0, 0, 0, 0, 0, 0], dtype=dtypes.float32)
+ expected = np.array([0, 11, 0, 10, 9, 0, 0, 12])
+ scatter = state_ops.scatter_nd_update(ref, indices, updates)
+ init = variables.global_variables_initializer()
+
+ with self.test_session(use_gpu=True) as sess:
+ sess.run(init)
+ sess.run(scatter)
+ self.assertAllClose(ref.eval(), expected)
+
def testSimple2(self):
indices = constant_op.constant([[1, 0], [1, 1]], dtype=dtypes.int32)
updates = constant_op.constant([11., 12.], dtype=dtypes.float32)
diff --git a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
index 99f9f09690..fd58cdb170 100644
--- a/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
+++ b/tensorflow/python/kernel_tests/segment_reduction_ops_test.py
@@ -266,6 +266,27 @@ class UnsortedSegmentSumTest(SegmentReductionHelper):
self.assertAllClose(np_ans, tf_ans)
self.assertShapeEqual(np_ans, s)
+ def testNumSegmentsTypes(self):
+ dtypes = [dtypes_lib.int32, dtypes_lib.int64]
+ indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
+ num_segments = 12
+ for indices in indices_flat, indices_flat.reshape(5, 2):
+ shape = indices.shape + (2,)
+ for dtype in dtypes:
+ with self.test_session(use_gpu=True):
+ tf_x, np_x = self._input(shape)
+ num_segments_constant = constant_op.constant(
+ num_segments, dtype=dtype)
+ np_ans = self._segmentReduce(
+ indices, np_x, np.add, op2=None, num_out_rows=num_segments)
+ s = math_ops.unsorted_segment_sum(
+ data=tf_x,
+ segment_ids=indices,
+ num_segments=num_segments_constant)
+ tf_ans = s.eval()
+ self.assertAllClose(np_ans, tf_ans)
+ self.assertShapeEqual(np_ans, s)
+
def testGradientSegmentSum(self):
num_cols = 2
indices_flat = np.array([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
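
What the new test pins down: `num_segments` may be an int32 or int64 scalar tensor, not only a Python int. A sketch:

    import numpy as np
    from tensorflow.python.framework import constant_op
    from tensorflow.python.framework import dtypes
    from tensorflow.python.ops import math_ops

    data = constant_op.constant(np.ones([10, 2], dtype=np.float32))
    segment_ids = constant_op.constant([0, 4, 0, 8, 3, 8, 4, 7, 7, 3])
    num_segments = constant_op.constant(12, dtype=dtypes.int64)
    s = math_ops.unsorted_segment_sum(
        data=data, segment_ids=segment_ids, num_segments=num_segments)
    # s has shape [12, 2]; rows 0, 3, 4, 7, 8 receive sums, the rest stay 0.
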
diff --git a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
index 78c113f514..d1a90952c7 100644
--- a/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
+++ b/tensorflow/python/kernel_tests/sparse_serialization_ops_test.py
@@ -254,8 +254,8 @@ class SerializeSparseTest(test.TestCase):
serialized_concat, dtype=dtypes.int32)
with self.assertRaisesOpError(
- r"Inconsistent rank across SparseTensors: rank prior to "
- r"SparseTensor\[1\] was: 3 but rank of SparseTensor\[1\] is: 4"):
+ r"Inconsistent shape across SparseTensors: rank prior to "
+ r"SparseTensor\[1\] was: 2 but rank of SparseTensor\[1\] is: 3"):
sess.run(sp_deserialized,
{sp_input0: input0_val,
sp_input1: input1_val})
diff --git a/tensorflow/python/kernel_tests/template_test.py b/tensorflow/python/kernel_tests/template_test.py
index 40c0ade62a..f0354374ac 100644
--- a/tensorflow/python/kernel_tests/template_test.py
+++ b/tensorflow/python/kernel_tests/template_test.py
@@ -34,9 +34,10 @@ from tensorflow.python.platform import test
from tensorflow.python.training import gradient_descent
-def variable_scoped_function():
+def variable_scoped_function(trainable=True):
return variable_scope.get_variable(
- "dummy", shape=[1], initializer=init_ops.zeros_initializer())
+ "dummy", shape=[1], trainable=trainable,
+ initializer=init_ops.zeros_initializer())
def internally_variable_scoped_function(scope_name):
@@ -413,7 +414,7 @@ class TemplateTest(test.TestCase):
self.assertEqual(custom_getter_count[0], 2)
# Test that custom getter is called when the variable scope is created
- # during construction
+ # during construction
custom_getter_count[0] = 0
tmpl2 = template.make_template(
"s2",
@@ -539,6 +540,36 @@ class TemplateTest(test.TestCase):
# Ensure we can get the scopes before either template is actually called.
self.assertEqual(1, len(ta.trainable_variables))
self.assertEqual(1, len(tb.trainable_variables))
+    # No non-trainable variables have been created.
+ self.assertEqual([], list(ta.non_trainable_variables))
+ self.assertEqual([], list(tb.non_trainable_variables))
+ # Ensure variables returns all the variables.
+ self.assertEqual(1, len(ta.variables))
+ self.assertEqual(1, len(tb.variables))
+
+ @test_util.run_in_graph_and_eager_modes()
+ def test_non_trainable_variables(self):
+ # Make sure non_trainable_variables are created.
+ with variable_scope.variable_scope("foo2"):
+ ta = template.make_template("a", variable_scoped_function,
+ trainable=True)
+ tb = template.make_template("b", variable_scoped_function,
+ trainable=False)
+    # Initially, no variables have been created.
+ self.assertEqual([], list(ta.variables))
+ self.assertEqual([], list(tb.variables))
+    # After calling, the variables have been created.
+ ta()
+ tb()
+ # Check the trainable and non_trainable variables.
+ self.assertEqual(1, len(ta.trainable_variables))
+ self.assertEqual([], list(ta.non_trainable_variables))
+
+ self.assertEqual([], list(tb.trainable_variables))
+ self.assertEqual(1, len(tb.non_trainable_variables))
+ # Ensure variables returns all the variables.
+ self.assertEqual(1, len(ta.variables))
+ self.assertEqual(1, len(tb.variables))
# TODO(apassos) handle local variables in Eager
def test_local_variables(self):
diff --git a/tensorflow/python/layers/base.py b/tensorflow/python/layers/base.py
index 6be2bc3e76..c083f8a5d2 100644
--- a/tensorflow/python/layers/base.py
+++ b/tensorflow/python/layers/base.py
@@ -103,10 +103,16 @@ class Layer(object):
self.built = False
self.input_spec = None
+ if activity_regularizer and context.in_eager_mode():
+ raise ValueError(
+ ('Activity regularization is not supported when executing eagerly. '
+ 'Got activity_regularizer=%s') % (activity_regularizer,))
self._activity_regularizer = activity_regularizer
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
+ # When executing eagerly, _losses is a list of zero-argument lambdas which
+ # return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._reuse = kwargs.get('_reuse')
self._graph = ops.get_default_graph()
@@ -287,9 +293,22 @@ class Layer(object):
@property
def losses(self):
+ """Losses which are associated with this `Layer`.
+
+ Note that when executing eagerly, getting this property evaluates
+ regularizers. When using graph execution, variable regularization ops have
+ already been created and are simply returned here.
+
+ Returns:
+ A list of tensors.
+ """
if context.in_eager_mode():
- raise RuntimeError('Layer.losses not supported in Eager mode.')
- return self._losses
+ # _losses may only contain variable regularization losses when executing
+ # eagerly, and they have been saved as lambdas to be executed when
+ # requested.
+ return [regularizer() for regularizer in self._losses]
+ else:
+ return self._losses
def add_loss(self, losses, inputs=None):
"""Add loss tensor(s), potentially dependent on layer inputs.
@@ -303,6 +322,11 @@ class Layer(object):
The `get_losses_for` method allows to retrieve the losses relevant to a
specific set of inputs.
+ Note that `add_loss` is not supported when executing eagerly. Instead,
+ variable regularizers may be added through `add_variable`. Activity
+ regularization is not supported directly (but such losses may be returned
+ from `Layer.call()`).
+
Arguments:
losses: Loss tensor, or list/tuple of tensors.
inputs: Optional input tensor(s) that the loss(es) depend on. Must
@@ -462,16 +486,8 @@ class Layer(object):
Raises:
RuntimeError: If called in Eager mode with regularizers.
"""
- # Note that we currently don't support variable regularization in Eager
- # mode. An alternative is for users to directly compute these losses before
- # performing a backward pass.
if context.in_graph_mode():
existing_variables = set(tf_variables.global_variables())
- else:
- existing_variables = []
- if regularizer is not None:
- raise RuntimeError('Variable regularization not supported in Eager '
- 'mode.')
if dtype is None:
dtype = self.dtype or dtypes.float32
@@ -486,28 +502,39 @@ class Layer(object):
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner)
- if (context.in_graph_mode() and trainable and self.trainable
- and variable not in tf_variables.trainable_variables()):
- # A custom getter / variable scope overrode the trainable flag.
- trainable = False
- if variable in existing_variables:
- return variable
- if regularizer:
- # To match the behavior of tf.get_variable(), we only
- # apply regularization if the variable is newly created.
- if isinstance(variable, tf_variables.PartitionedVariable):
- for v in variable:
- with ops.colocate_with(v.op):
+ if context.in_graph_mode():
+ if (trainable and self.trainable
+ and variable not in tf_variables.trainable_variables()):
+ # A custom getter / variable scope overrode the trainable flag.
+ trainable = False
+ if variable in existing_variables:
+ return variable
+ if regularizer:
+ # To match the behavior of tf.get_variable(), we only
+ # apply regularization if the variable is newly created.
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ for v in variable:
+ with ops.colocate_with(v.op):
+ with ops.name_scope(name + '/Regularizer'):
+ regularization = regularizer(v)
+ if regularization is not None:
+ self.add_loss(regularization)
+ else:
+ with ops.colocate_with(variable.op):
with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(v)
+ regularization = regularizer(variable)
if regularization is not None:
self.add_loss(regularization)
- else:
- with ops.colocate_with(variable.op):
- with ops.name_scope(name + '/Regularizer'):
- regularization = regularizer(variable)
- if regularization is not None:
- self.add_loss(regularization)
+ elif regularizer:
+ if isinstance(variable, tf_variables.PartitionedVariable):
+ raise RuntimeError(
+ 'Partitioned variable regularization is not yet supported when '
+            'executing eagerly. File a feature request if this is '
+ 'important to you.')
+ # Save a zero-argument lambda which runs the regularizer on the
+ # variable, to be executed when `Layer.losses` is requested. This
+ # makes losses responsive to variable updates when executing eagerly.
+ self._losses.append(lambda: regularizer(variable))
if trainable:
self._trainable_weights.append(variable)
else:
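
The net effect of the base.py changes: when executing eagerly, `add_variable` stores each regularizer as a zero-argument lambda, and `Layer.losses` evaluates those lambdas on access, so the reported loss tracks the variable's current value. A hypothetical layer sketching the contract (assumes eager execution has been enabled and that layers are usable eagerly at this commit):

    import tensorflow as tf

    class TinyLayer(tf.layers.Layer):  # hypothetical layer for illustration

      def build(self, input_shape):
        self.kernel = self.add_variable(
            "kernel", shape=[int(input_shape[-1]), 1],
            regularizer=lambda v: tf.reduce_sum(v ** 2))  # saved as a lambda
        self.built = True

      def call(self, inputs):
        return tf.matmul(inputs, self.kernel)

    layer = TinyLayer()
    y = layer(tf.ones([2, 3]))
    print(layer.losses)  # evaluates the regularizer lambda now, not at build

Note that the `lambda: regularizer(variable)` in the hunk closes over the variable itself, which is what makes the loss responsive to later updates.
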
diff --git a/tensorflow/python/layers/base_test.py b/tensorflow/python/layers/base_test.py
index 1eea20deef..3e5a51eb62 100644
--- a/tensorflow/python/layers/base_test.py
+++ b/tensorflow/python/layers/base_test.py
@@ -88,6 +88,11 @@ class BaseLayerTest(test.TestCase):
regularizer=regularizer)
self.assertEqual(len(layer.losses), 1)
+ def testNoEagerActivityRegularizer(self):
+ with context.eager_mode():
+ with self.assertRaisesRegexp(ValueError, 'activity_regularizer'):
+ core_layers.Dense(1, activity_regularizer=lambda *args, **kwargs: 0.)
+
def testGetVariable(self):
with self.test_session():
diff --git a/tensorflow/python/layers/convolutional.py b/tensorflow/python/layers/convolutional.py
index 7213fa1db8..fbb13bb72c 100644
--- a/tensorflow/python/layers/convolutional.py
+++ b/tensorflow/python/layers/convolutional.py
@@ -1232,7 +1232,8 @@ class Conv2DTranspose(Conv2D):
def build(self, input_shape):
if len(input_shape) != 4:
- raise ValueError('Inputs should have rank 4. Received input shape: ' + str(input_shape))
+ raise ValueError('Inputs should have rank 4. Received input shape: ' +
+ str(input_shape))
if self.data_format == 'channels_first':
channel_axis = 1
else:
diff --git a/tensorflow/python/lib/core/py_func.cc b/tensorflow/python/lib/core/py_func.cc
index 8bf831f8ba..a42282b055 100644
--- a/tensorflow/python/lib/core/py_func.cc
+++ b/tensorflow/python/lib/core/py_func.cc
@@ -22,11 +22,11 @@ limitations under the License.
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/threadpool.h"
-#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/ndarray_tensor_bridge.h"
+#include "tensorflow/python/lib/core/py_util.h"
#include <Python.h>
namespace tensorflow {
@@ -133,48 +133,6 @@ bool IsSingleNone(PyObject* obj) {
return item == Py_None;
}
-// py.__class__.__name__
-const char* ClassName(PyObject* py) {
-/* PyPy doesn't have a separate C API for old-style classes. */
-#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
- if (PyClass_Check(py))
- return PyString_AS_STRING(
- CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
- if (PyInstance_Check(py))
- return PyString_AS_STRING(CHECK_NOTNULL(
- reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
-#endif
- if (Py_TYPE(py) == &PyType_Type) {
- return reinterpret_cast<PyTypeObject*>(py)->tp_name;
- }
- return Py_TYPE(py)->tp_name;
-}
-
-string PyExcFetch() {
- CHECK(PyErr_Occurred()) << "Must only call PyExcFetch after an exception.";
- PyObject* ptype;
- PyObject* pvalue;
- PyObject* ptraceback;
- PyErr_Fetch(&ptype, &pvalue, &ptraceback);
- PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
- string err = ClassName(ptype);
- if (pvalue) {
- PyObject* str = PyObject_Str(pvalue);
- if (str) {
-#if PY_MAJOR_VERSION < 3
- strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
-#else
- strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
-#endif
- Py_DECREF(str);
- }
- Py_DECREF(pvalue);
- }
- Py_DECREF(ptype);
- Py_XDECREF(ptraceback);
- return err;
-}
-
// Calls the registered py function through the trampoline.
Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
*out_log_on_error = true;
@@ -195,18 +153,18 @@ Status DoCallPyFunc(PyCall* call, bool* out_log_on_error) {
if (PyErr_Occurred()) {
if (PyErr_ExceptionMatches(PyExc_ValueError) ||
PyErr_ExceptionMatches(PyExc_TypeError)) {
- return errors::InvalidArgument(PyExcFetch());
+ return errors::InvalidArgument(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_StopIteration)) {
*out_log_on_error = false;
- return errors::OutOfRange(PyExcFetch());
+ return errors::OutOfRange(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_MemoryError)) {
- return errors::ResourceExhausted(PyExcFetch());
+ return errors::ResourceExhausted(PyExceptionFetch());
} else if (PyErr_ExceptionMatches(PyExc_NotImplementedError)) {
- return errors::Unimplemented(PyExcFetch());
+ return errors::Unimplemented(PyExceptionFetch());
} else {
// TODO(ebrevdo): Check if exception is an OpError and use the
// OpError.error_code property to map it back in the Status.
- return errors::Unknown(PyExcFetch());
+ return errors::Unknown(PyExceptionFetch());
}
} else {
return errors::Internal("Failed to run py callback ", call->token,
diff --git a/tensorflow/python/lib/core/py_seq_tensor.cc b/tensorflow/python/lib/core/py_seq_tensor.cc
index 71cb38f8fd..317bdc2e14 100644
--- a/tensorflow/python/lib/core/py_seq_tensor.cc
+++ b/tensorflow/python/lib/core/py_seq_tensor.cc
@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/python/lib/core/numpy.h"
+#include "tensorflow/python/lib/core/py_util.h"
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
@@ -89,12 +90,25 @@ Status InferShapeAndType(PyObject* obj, TensorShape* shape, DataType* dtype) {
*dtype = DT_STRING;
} else if (PySequence_Check(obj)) {
auto length = PySequence_Length(obj);
- shape->AddDim(length);
if (length > 0) {
+ shape->AddDim(length);
obj = PySequence_GetItem(obj, 0);
continue;
- } else {
+ } else if (length == 0) {
+ shape->AddDim(length);
*dtype = DT_INVALID; // Invalid dtype for empty tensors.
+ } else {
+ // The sequence does not have a valid length (PySequence_Length < 0).
+ if (PyErr_Occurred()) {
+ // PySequence_Length failed and set an exception. Fetch the message
+ // and convert it to a failed status.
+ return errors::InvalidArgument(PyExceptionFetch());
+ } else {
+ // This is almost certainly dead code: PySequence_Length failed but
+ // did not set an exception.
+ return errors::InvalidArgument(
+ "Attempted to convert an invalid sequence to a Tensor.");
+ }
}
} else if (IsPyFloat(obj)) {
*dtype = DT_DOUBLE;
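
The Python-visible effect of the hunk above: a sequence whose `__len__` returns a negative value now produces a clean error instead of adding a bogus dimension. CPython itself raises ValueError("__len__() should return >= 0") when the length is fetched, and that message is what `PyExceptionFetch` forwards (mirroring the new cases in constant_op_eager_test.py):

    from tensorflow.python.framework import constant_op

    class BadList(list):

      def __init__(self):
        super(BadList, self).__init__([1, 2, 3])

      def __len__(self):
        return -1  # invalid; PySequence_Length reports failure

    try:
      constant_op.constant(BadList())
    except ValueError as e:
      print(e)  # message includes "should return >= 0"
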
diff --git a/tensorflow/python/lib/core/py_util.cc b/tensorflow/python/lib/core/py_util.cc
new file mode 100644
index 0000000000..2635694e23
--- /dev/null
+++ b/tensorflow/python/lib/core/py_util.cc
@@ -0,0 +1,70 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/python/lib/core/py_util.h"
+
+#include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/lib/strings/strcat.h"
+#include <Python.h>
+
+namespace tensorflow {
+namespace {
+
+// py.__class__.__name__
+const char* ClassName(PyObject* py) {
+/* PyPy doesn't have a separate C API for old-style classes. */
+#if PY_MAJOR_VERSION < 3 && !defined(PYPY_VERSION)
+ if (PyClass_Check(py))
+ return PyString_AS_STRING(
+ CHECK_NOTNULL(reinterpret_cast<PyClassObject*>(py)->cl_name));
+ if (PyInstance_Check(py))
+ return PyString_AS_STRING(CHECK_NOTNULL(
+ reinterpret_cast<PyInstanceObject*>(py)->in_class->cl_name));
+#endif
+ if (Py_TYPE(py) == &PyType_Type) {
+ return reinterpret_cast<PyTypeObject*>(py)->tp_name;
+ }
+ return Py_TYPE(py)->tp_name;
+}
+
+} // end namespace
+
+string PyExceptionFetch() {
+ CHECK(PyErr_Occurred())
+ << "Must only call PyExceptionFetch after an exception.";
+ PyObject* ptype;
+ PyObject* pvalue;
+ PyObject* ptraceback;
+ PyErr_Fetch(&ptype, &pvalue, &ptraceback);
+ PyErr_NormalizeException(&ptype, &pvalue, &ptraceback);
+ string err = ClassName(ptype);
+ if (pvalue) {
+ PyObject* str = PyObject_Str(pvalue);
+ if (str) {
+#if PY_MAJOR_VERSION < 3
+ strings::StrAppend(&err, ": ", PyString_AS_STRING(str));
+#else
+ strings::StrAppend(&err, ": ", PyUnicode_AsUTF8(str));
+#endif
+ Py_DECREF(str);
+ }
+ Py_DECREF(pvalue);
+ }
+ Py_DECREF(ptype);
+ Py_XDECREF(ptraceback);
+ return err;
+}
+
+} // end namespace tensorflow
diff --git a/tensorflow/python/lib/core/py_util.h b/tensorflow/python/lib/core/py_util.h
new file mode 100644
index 0000000000..44dfe7ba21
--- /dev/null
+++ b/tensorflow/python/lib/core/py_util.h
@@ -0,0 +1,27 @@
+/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
+#define TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
+
+#include "tensorflow/core/platform/types.h"
+
+namespace tensorflow {
+// Fetch the exception message as a string. An exception must be set
+// (PyErr_Occurred() must be true).
+string PyExceptionFetch();
+} // end namespace tensorflow
+
+#endif  // TENSORFLOW_PYTHON_LIB_CORE_PY_UTIL_H_
diff --git a/tensorflow/python/lib/core/safe_ptr.cc b/tensorflow/python/lib/core/safe_ptr.cc
index 456ea3348b..ce34b6d004 100644
--- a/tensorflow/python/lib/core/safe_ptr.cc
+++ b/tensorflow/python/lib/core/safe_ptr.cc
@@ -16,25 +16,21 @@ limitations under the License.
#include "tensorflow/python/lib/core/safe_ptr.h"
namespace tensorflow {
-namespace {
-inline void Py_DECREF_wrapper(PyObject* o) { Py_DECREF(o); }
-
-} // namespace
-
-Safe_PyObjectPtr make_safe(PyObject* o) {
- return Safe_PyObjectPtr(o, Py_DECREF_wrapper);
+Safe_PyObjectPtr make_safe(PyObject* object) {
+ return Safe_PyObjectPtr(object);
}
Safe_TF_TensorPtr make_safe(TF_Tensor* tensor) {
- return Safe_TF_TensorPtr(tensor, TF_DeleteTensor);
+ return Safe_TF_TensorPtr(tensor);
}
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle) {
- return Safe_TFE_TensorHandlePtr(handle, TFE_DeleteTensorHandle);
+ return Safe_TFE_TensorHandlePtr(handle);
}
Safe_TF_StatusPtr make_safe(TF_Status* status) {
- return Safe_TF_StatusPtr(status, TF_DeleteStatus);
+ return Safe_TF_StatusPtr(status);
}
+
} // namespace tensorflow
diff --git a/tensorflow/python/lib/core/safe_ptr.h b/tensorflow/python/lib/core/safe_ptr.h
index 70cd2fdf6c..80db840aeb 100644
--- a/tensorflow/python/lib/core/safe_ptr.h
+++ b/tensorflow/python/lib/core/safe_ptr.h
@@ -17,39 +17,51 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_PYTHON_LIB_CORE_SAFE_PTR_H_
#include <memory>
-#include <Python.h>
+#include <Python.h>
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
namespace tensorflow {
+namespace detail {
+
+struct PyDecrefDeleter {
+ void operator()(PyObject* p) const { Py_DECREF(p); }
+};
+
+struct TFTensorDeleter {
+ void operator()(TF_Tensor* p) const { TF_DeleteTensor(p); }
+};
+
+struct TFETensorHandleDeleter {
+ void operator()(TFE_TensorHandle* p) const { TFE_DeleteTensorHandle(p); }
+};
+
+struct TFStatusDeleter {
+ void operator()(TF_Status* p) const { TF_DeleteStatus(p); }
+};
+
+} // namespace detail
// Safe container for an owned PyObject. On destruction, the reference count of
// the contained object will be decremented.
-typedef void (*Py_DECREF_wrapper_type)(PyObject*);
-typedef std::unique_ptr<PyObject, Py_DECREF_wrapper_type> Safe_PyObjectPtr;
+using Safe_PyObjectPtr = std::unique_ptr<PyObject, detail::PyDecrefDeleter>;
Safe_PyObjectPtr make_safe(PyObject* o);
// Safe containers for an owned TF_Tensor. On destruction, the tensor will be
// deleted by TF_DeleteTensor.
-// Note: can't use decltype(&TF_DeleteTensor) due to SWIG
-typedef void (*TF_DeleteTensor_type)(TF_Tensor*);
-typedef std::unique_ptr<TF_Tensor, TF_DeleteTensor_type> Safe_TF_TensorPtr;
+using Safe_TF_TensorPtr = std::unique_ptr<TF_Tensor, detail::TFTensorDeleter>;
Safe_TF_TensorPtr make_safe(TF_Tensor* tensor);
// Safe containers for an owned TFE_TensorHandle. On destruction, the handle
-// will be deleted by TFE_DeleteTensorHandle. Note: can't use
-// decltype(&TFE_DeleteTensorHandle) due to SWIG
-typedef void (*TFE_DeleteTensorHandle_type)(TFE_TensorHandle*);
-typedef std::unique_ptr<TFE_TensorHandle, TFE_DeleteTensorHandle_type>
- Safe_TFE_TensorHandlePtr;
+// will be deleted by TFE_DeleteTensorHandle.
+using Safe_TFE_TensorHandlePtr =
+ std::unique_ptr<TFE_TensorHandle, detail::TFETensorHandleDeleter>;
Safe_TFE_TensorHandlePtr make_safe(TFE_TensorHandle* handle);
// Safe containers for an owned TF_Status. On destruction, the handle
-// will be deleted by TF_DeleteStatus. Note: can't use
-// decltype(&TF_DeleteStatus) due to SWIG
-typedef void (*TF_DeleteStatus_type)(TF_Status*);
-typedef std::unique_ptr<TF_Status, TF_DeleteStatus_type> Safe_TF_StatusPtr;
+// will be deleted by TF_DeleteStatus.
+using Safe_TF_StatusPtr = std::unique_ptr<TF_Status, detail::TFStatusDeleter>;
Safe_TF_StatusPtr make_safe(TF_Status* status);
} // namespace tensorflow
diff --git a/tensorflow/python/ops/nn_grad.py b/tensorflow/python/ops/nn_grad.py
index 4b406ba840..8cd535aa0b 100644
--- a/tensorflow/python/ops/nn_grad.py
+++ b/tensorflow/python/ops/nn_grad.py
@@ -41,33 +41,48 @@ def _Conv2DBackpropInputGrad(op, grad):
Returns:
the gradients w.r.t. the input and the filter
"""
- return [None,
- nn_ops.conv2d_backprop_filter(grad, array_ops.shape(op.inputs[1]),
- op.inputs[2], op.get_attr("strides"),
- op.get_attr("padding"),
- op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format")),
- nn_ops.conv2d(grad, op.inputs[1], op.get_attr("strides"),
- op.get_attr("padding"), op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format"))]
+ return [
+ None,
+ nn_ops.conv2d_backprop_filter(
+ grad,
+ array_ops.shape(op.inputs[1]),
+ op.inputs[2],
+ dilations=op.get_attr("dilations"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
+ data_format=op.get_attr("data_format")),
+ nn_ops.conv2d(
+ grad,
+ op.inputs[1],
+ dilations=op.get_attr("dilations"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
+ data_format=op.get_attr("data_format"))
+ ]
@ops.RegisterGradient("Conv2DBackpropFilter")
def _Conv2DBackpropFilterGrad(op, grad):
return [
nn_ops.conv2d_backprop_input(
- array_ops.shape(op.inputs[0]), grad, op.inputs[2],
- op.get_attr("strides"),
- op.get_attr("padding"),
- op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format")),
- None,
+ array_ops.shape(op.inputs[0]),
+ grad,
+ op.inputs[2],
+ dilations=op.get_attr("dilations"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
+ data_format=op.get_attr("data_format")), None,
nn_ops.conv2d(
- op.inputs[0], grad,
- op.get_attr("strides"),
- op.get_attr("padding"),
- op.get_attr("use_cudnn_on_gpu"),
- op.get_attr("data_format"))
+ op.inputs[0],
+ grad,
+ dilations=op.get_attr("dilations"),
+ strides=op.get_attr("strides"),
+ padding=op.get_attr("padding"),
+ use_cudnn_on_gpu=op.get_attr("use_cudnn_on_gpu"),
+ data_format=op.get_attr("data_format"))
]
@@ -466,25 +481,32 @@ def _SparseSoftmaxCrossEntropyWithLogitsGrad(op, grad_0, _):
@ops.RegisterGradient("Conv2D")
def _Conv2DGrad(op, grad):
+ dilations = op.get_attr("dilations")
strides = op.get_attr("strides")
padding = op.get_attr("padding")
use_cudnn_on_gpu = op.get_attr("use_cudnn_on_gpu")
data_format = op.get_attr("data_format")
shape_0, shape_1 = array_ops.shape_n([op.inputs[0], op.inputs[1]])
- return [nn_ops.conv2d_backprop_input(shape_0,
- op.inputs[1],
- grad,
- strides,
- padding,
- use_cudnn_on_gpu,
- data_format),
- nn_ops.conv2d_backprop_filter(op.inputs[0],
- shape_1,
- grad,
- strides,
- padding,
- use_cudnn_on_gpu,
- data_format)]
+ return [
+ nn_ops.conv2d_backprop_input(
+ shape_0,
+ op.inputs[1],
+ grad,
+ dilations=dilations,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format),
+ nn_ops.conv2d_backprop_filter(
+ op.inputs[0],
+ shape_1,
+ grad,
+ dilations=dilations,
+ strides=strides,
+ padding=padding,
+ use_cudnn_on_gpu=use_cudnn_on_gpu,
+ data_format=data_format)
+ ]
@ops.RegisterGradient("DepthwiseConv2dNative")
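The gradient registrations above now forward the `dilations` attribute to the backprop ops, so gradients of dilated convolutions stay consistent with the forward pass. A minimal sketch of what this enables, assuming a TF 1.x build that includes this change:

```python
import tensorflow as tf

x = tf.random_normal([1, 8, 8, 3])    # NHWC input
w = tf.random_normal([3, 3, 3, 16])   # HWIO filter
# Forward pass with a non-default dilation; the gradient functions above read
# the "dilations" attr off this op and pass it through to the backprop kernels.
y = tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding="SAME",
                 dilations=[1, 2, 2, 1])
dx, dw = tf.gradients(tf.reduce_sum(y), [x, w])

with tf.Session() as sess:
    gx, gw = sess.run([dx, dw])
    print(gx.shape, gw.shape)  # (1, 8, 8, 3) (3, 3, 3, 16)
```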
diff --git a/tensorflow/python/ops/nn_ops.py b/tensorflow/python/ops/nn_ops.py
index ec7b9372ca..b3c0a22efc 100644
--- a/tensorflow/python/ops/nn_ops.py
+++ b/tensorflow/python/ops/nn_ops.py
@@ -1205,13 +1205,14 @@ def conv2d_transpose(value,
raise ValueError("padding must be either VALID or SAME:"
" {}".format(padding))
- return gen_nn_ops.conv2d_backprop_input(input_sizes=output_shape_,
- filter=filter,
- out_backprop=value,
- strides=strides,
- padding=padding,
- data_format=data_format,
- name=name)
+ return gen_nn_ops.conv2d_backprop_input(
+ input_sizes=output_shape_,
+ filter=filter,
+ out_backprop=value,
+ strides=strides,
+ padding=padding,
+ data_format=data_format,
+ name=name)
def atrous_conv2d_transpose(value,
@@ -1343,12 +1344,13 @@ def atrous_conv2d_transpose(value,
(in_width + pad_right_extra) // rate,
output_shape[3]]
- value = gen_nn_ops.conv2d_backprop_input(input_sizes=input_sizes,
- filter=filters,
- out_backprop=value,
- strides=[1, 1, 1, 1],
- padding="VALID",
- data_format="NHWC")
+ value = gen_nn_ops.conv2d_backprop_input(
+ input_sizes=input_sizes,
+ filter=filters,
+ out_backprop=value,
+ strides=[1, 1, 1, 1],
+ padding="VALID",
+ data_format="NHWC")
# The crops argument to batch_to_space includes both padding components.
batch_to_space_crop = [[pad_top, pad_bottom + pad_bottom_extra],
diff --git a/tensorflow/python/ops/random_ops.py b/tensorflow/python/ops/random_ops.py
index 52fb5131cf..afaff8ca41 100644
--- a/tensorflow/python/ops/random_ops.py
+++ b/tensorflow/python/ops/random_ops.py
@@ -316,7 +316,7 @@ def random_crop(value, size, seed=None, name=None):
return array_ops.slice(value, offset, size, name=name)
-def multinomial(logits, num_samples, seed=None, name=None):
+def multinomial(logits, num_samples, seed=None, name=None, output_dtype=None):
"""Draws samples from a multinomial distribution.
Example:
@@ -336,6 +336,7 @@ def multinomial(logits, num_samples, seed=None, name=None):
@{tf.set_random_seed}
for behavior.
name: Optional name for the operation.
+ output_dtype: integer type to use for the output (`int32` or `int64`).
+ Defaults to `int64`.
Returns:
The drawn samples of shape `[batch_size, num_samples]`.
@@ -344,7 +345,7 @@ def multinomial(logits, num_samples, seed=None, name=None):
logits = ops.convert_to_tensor(logits, name="logits")
seed1, seed2 = random_seed.get_seed(seed)
return gen_random_ops.multinomial(
- logits, num_samples, seed=seed1, seed2=seed2)
+ logits, num_samples, seed=seed1, seed2=seed2, output_dtype=output_dtype)
ops.NotDifferentiable("Multinomial")
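A short usage sketch for the new `output_dtype` argument, assuming a build that includes this change:

```python
import tensorflow as tf

logits = tf.log([[10., 10., 10.]])  # uniform distribution over 3 classes
# Request int32 sample indices instead of the default int64.
samples = tf.multinomial(logits, num_samples=5, output_dtype=tf.int32)

with tf.Session() as sess:
    print(sess.run(samples))  # e.g. [[2 0 1 1 0]], now with dtype int32
```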
diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py
index 343e38f960..652bfa1ebc 100644
--- a/tensorflow/python/ops/resource_variable_ops.py
+++ b/tensorflow/python/ops/resource_variable_ops.py
@@ -887,26 +887,19 @@ def _ReadGrad(_, grad):
def _GatherGrad(op, grad):
"""Gradient for gather op."""
# Build appropriately shaped IndexedSlices
- # Walk graph back until the original handle is found.
- # TODO(apassos): more robust way of getting the shape.
- # TODO(apassos): implement this for EAGER mode.
- if context.in_eager_mode():
- dense_shape = gen_resource_variable_ops.variable_shape(op.inputs[0])
- return (ops.IndexedSlices(grad,
- op.inputs[1],
- dense_shape=dense_shape),
- None)
handle = op.inputs[0]
- while handle.op.type != "VarHandleOp":
- handle = handle.op.inputs[0]
- params_shape = ops.convert_to_tensor(
- tensor_shape.TensorShape(handle.op.get_attr("shape")))
indices = op.inputs[1]
+ if context.in_graph_mode():
+ # Walk graph back until the original handle is found.
+ # TODO(apassos): implement this for EAGER mode.
+ while handle.op.type != "VarHandleOp":
+ handle = handle.op.inputs[0]
+ params_shape = gen_resource_variable_ops.variable_shape(handle)
size = array_ops.expand_dims(array_ops.size(indices), 0)
values_shape = array_ops.concat([size, params_shape[1:]], 0)
values = array_ops.reshape(grad, values_shape)
indices = array_ops.reshape(indices, size)
- return [ops.IndexedSlices(values, indices, params_shape), None]
+ return (ops.IndexedSlices(values, indices, params_shape), None)
def _to_proto_fn(v, export_scope=None):
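With this change the gather gradient resolves the variable shape through `variable_shape` in both modes and always returns an `IndexedSlices`, so only the gathered rows carry gradient. A sketch under the TF 1.x resource-variable API:

```python
import tensorflow as tf

v = tf.get_variable("v", shape=[10, 4], use_resource=True)
g = tf.gather(v, [0, 3, 3])          # row 3 gathered twice
(grad,) = tf.gradients(tf.reduce_sum(g), [v])
# The gradient is sparse: an IndexedSlices whose indices are the gathered
# rows (duplicates included) rather than a dense [10, 4] tensor.
print(type(grad).__name__)           # IndexedSlices
print(grad.indices, grad.dense_shape)
```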
diff --git a/tensorflow/python/ops/sparse_ops.py b/tensorflow/python/ops/sparse_ops.py
index cdfe9e1c1e..9bdc124c83 100644
--- a/tensorflow/python/ops/sparse_ops.py
+++ b/tensorflow/python/ops/sparse_ops.py
@@ -1437,10 +1437,47 @@ def serialize_many_sparse(sp_input, name=None):
def deserialize_sparse(serialized_sparse, dtype, rank=None, name=None):
"""Deserialize `SparseTensor` objects.
- The input is expected to have shape [d_1, ..., d_m, 3], where the last
- dimension stores a serialized `SparseTensor`. The method deserializes
- all input `SparseTensor`s, concatenates them into a single tensor, and
- reshapes the sparse tensor to preserve the structure of the input.
+ The input `serialized_sparse` must have the shape `[?, ?, ..., ?, 3]` where
+ the last dimension stores serialized `SparseTensor` objects and the other N
+ dimensions (N >= 0) correspond to a batch. The ranks of the original
+ `SparseTensor` objects must all match. When the final `SparseTensor` is
+ created, its rank is the rank of the incoming `SparseTensor` objects plus N;
+ the sparse tensors have been concatenated along new dimensions, one for each
+ batch.
+
+ The output `SparseTensor` object's shape values for the original dimensions
+ are the max across the input `SparseTensor` objects' shape values for the
+ corresponding dimensions. The new dimensions match the size of the batch.
+
+ The input `SparseTensor` objects' indices are assumed ordered in
+ standard lexicographic order. If this is not the case, after this
+ step run `SparseReorder` to restore index ordering.
+
+ For example, if the serialized input is a `[2 x 3]` matrix representing two
+ original `SparseTensor` objects:
+
+ index = [ 0]
+ [10]
+ [20]
+ values = [1, 2, 3]
+ shape = [50]
+
+ and
+
+ index = [ 2]
+ [10]
+ values = [4, 5]
+ shape = [30]
+
+ then the final deserialized `SparseTensor` will be:
+
+ index = [0 0]
+ [0 10]
+ [0 20]
+ [1 2]
+ [1 10]
+ values = [1, 2, 3, 4, 5]
+ shape = [2 50]
Args:
serialized_sparse: The serialized `SparseTensor` objects.
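A round-trip sketch of the batched behavior the new docstring describes. `deserialize_sparse` is the module-level helper added here; reaching it through `tensorflow.python.ops.sparse_ops` is an assumption about how this build exposes it:

```python
import tensorflow as tf
from tensorflow.python.ops import sparse_ops

sp = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1, 2],
                     dense_shape=[2, 3])
# serialize_many_sparse treats dim 0 as the batch: output shape is [2, 3],
# the last dimension holding the (indices, values, shape) triple.
serialized = tf.serialize_many_sparse(sp)
# deserialize_sparse prepends one new dimension per batch dimension, giving
# back a rank-2 SparseTensor equivalent to the input.
roundtrip = sparse_ops.deserialize_sparse(serialized, dtype=tf.int32)

with tf.Session() as sess:
    print(sess.run(roundtrip))
```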
diff --git a/tensorflow/python/ops/state_ops.py b/tensorflow/python/ops/state_ops.py
index dfc657893c..dee495f78f 100644
--- a/tensorflow/python/ops/state_ops.py
+++ b/tensorflow/python/ops/state_ops.py
@@ -347,5 +347,71 @@ def scatter_update(ref, indices, updates, use_locking=True, name=None):
if ref.dtype._is_ref_dtype:
return gen_state_ops.scatter_update(ref, indices, updates,
use_locking=use_locking, name=name)
- return gen_resource_variable_ops.resource_scatter_update(
- ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype), name=name)
+ with ops.control_dependencies(
+ [gen_resource_variable_ops.resource_scatter_update(
+ ref.handle, indices, ops.convert_to_tensor(updates, ref.dtype),
+ name=name)]):
+ return ref.read_value()
+
+
+def scatter_nd_update(ref, indices, updates, use_locking=True, name=None):
+ r"""Applies sparse `updates` to individual values or slices in a Variable.
+
+ `ref` is a `Tensor` with rank `P` and `indices` is a `Tensor` of rank `Q`.
+
+ `indices` must be an integer tensor containing indices into `ref`.
+ It must be of shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.
+
+ The innermost dimension of `indices` (with length `K`) corresponds to
+ indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
+ dimension of `ref`.
+
+ `updates` is a `Tensor` of rank `Q-1+P-K` with shape:
+
+ ```
+ [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
+ ```
+
+ For example, say we want to update 4 scattered elements of a rank-1 tensor
+ with 8 elements. In Python, that update would look like this:
+
+ ```python
+ ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
+ indices = tf.constant([[4], [3], [1] ,[7]])
+ updates = tf.constant([9, 10, 11, 12])
+ update = tf.scatter_nd_update(ref, indices, updates)
+ with tf.Session() as sess:
+ print(sess.run(update))
+ ```
+
+ The resulting update to ref would look like this:
+
+ [1, 11, 3, 10, 9, 6, 7, 12]
+
+ See @{tf.scatter_nd} for more details about how to make updates to
+ slices.
+
+ Args:
+ ref: A Variable.
+ indices: A `Tensor`. Must be one of the following types: `int32`, `int64`.
+ A tensor of indices into `ref`.
+ updates: A `Tensor`. Must have the same type as `ref`. A tensor of
+ updated values to store in `ref`.
+ use_locking: An optional `bool`. Defaults to `True`. If `True`, the
+ assignment will be protected by a lock; otherwise the behavior is
+ undefined, but may exhibit less contention.
+ name: A name for the operation (optional).
+
+ Returns:
+ The value of the variable after the update.
+ """
+ if ref.dtype._is_ref_dtype:
+ return gen_state_ops.scatter_nd_update(
+ ref, indices, updates, use_locking, name)
+ with ops.control_dependencies([gen_state_ops.resource_scatter_nd_update(
+ ref.handle, indices, ops.convert_to_tensor(updates, dtype=ref.dtype),
+ use_locking, name)]):
+ return ref.read_value()
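Both `scatter_update` and the new `scatter_nd_update` now return the variable's value after the update for resource variables, reading it under a control dependency on the scatter op. A sketch that mirrors the docstring example, assuming the TF 1.x API:

```python
import tensorflow as tf

ref = tf.get_variable("ref", initializer=[1, 2, 3, 4, 5, 6, 7, 8],
                      use_resource=True)
indices = tf.constant([[4], [3], [1], [7]])
updates = tf.constant([9, 10, 11, 12])
# For a resource variable this returns ref.read_value() gated on the update.
updated = tf.scatter_nd_update(ref, indices, updates)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(updated))  # [ 1 11  3 10  9  6  7 12]
```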
diff --git a/tensorflow/python/ops/template.py b/tensorflow/python/ops/template.py
index 98578b799a..07796b28d9 100644
--- a/tensorflow/python/ops/template.py
+++ b/tensorflow/python/ops/template.py
@@ -308,6 +308,12 @@ class Template(object):
return name if name[-1] == "/" else name + "/"
@property
+ def variables(self):
+ """Returns the list of global and local variables created by the Template.
+ """
+ return self.global_variables + self.local_variables
+
+ @property
def trainable_variables(self):
"""Returns the list of trainable variables created by the Template."""
if self._variables_created:
@@ -317,6 +323,14 @@ class Template(object):
return []
@property
+ def non_trainable_variables(self):
+ """Returns the list of non-trainable variables created by the Template."""
+ # TODO(apassos) Make sure it matches Eager when using local variables.
+ global_variables = self.global_variables
+ trainable_variables = set(self.trainable_variables)
+ return [x for x in global_variables if x not in trainable_variables]
+
+ @property
def global_variables(self):
"""Returns the list of global variables created by the Template."""
if self._variables_created:
@@ -335,6 +349,21 @@ class Template(object):
return []
@property
+ def weights(self):
+ """List of weights/variables created by the Template."""
+ return self.variables
+
+ @property
+ def trainable_weights(self):
+ """List of trainable weights/variables created by the Template."""
+ return self.trainable_variables
+
+ @property
+ def non_trainable_weights(self):
+ """List of non-trainable weights/variables created by the Template."""
+ return self.non_trainable_variables
+
+ @property
@deprecated(
"2017-02-21", "The .var_scope property is deprecated. Please change your "
"code to use the .variable_scope property")
@@ -501,7 +530,7 @@ class EagerTemplate(Template):
@property
def variables(self):
- """Returns the list of trainable variables created by the Template."""
+ """Returns the list of variables created by the Template."""
# Currently there is no local variable in Eager mode.
return self._eager_variable_store.variables()
@@ -512,6 +541,12 @@ class EagerTemplate(Template):
return self._eager_variable_store.trainable_variables()
@property
+ def non_trainable_variables(self):
+ """Returns the list of non-trainable variables created by the Template."""
+ # Currently there is no local variable in Eager mode.
+ return self._eager_variable_store.non_trainable_variables()
+
+ @property
def global_variables(self):
"""Returns the list of global variables created by the Template."""
# Currently there is no local variable in Eager mode.
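The new properties give `Template` a Keras-style surface: `variables` unions global and local variables, `non_trainable_variables` is derived by set difference, and the `*_weights` names alias them. A small graph-mode sketch:

```python
import tensorflow as tf

def model(x):
    w = tf.get_variable("w", shape=[2, 2])
    m = tf.get_variable("m", shape=[2, 2], trainable=False)
    return tf.matmul(x, w) + m

tmpl = tf.make_template("net", model)
y = tmpl(tf.ones([1, 2]))

print([v.name for v in tmpl.trainable_variables])      # ['net/w:0']
print([v.name for v in tmpl.non_trainable_variables])  # ['net/m:0']
assert tmpl.weights == tmpl.variables                  # Keras-style alias
```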
diff --git a/tensorflow/python/platform/flags.py b/tensorflow/python/platform/flags.py
index e9a36ae75d..abd6f3d855 100644
--- a/tensorflow/python/platform/flags.py
+++ b/tensorflow/python/platform/flags.py
@@ -18,5 +18,53 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
+import logging as _logging
+
# go/tf-wildcard-import
from absl.flags import * # pylint: disable=wildcard-import
+import six as _six
+
+from tensorflow.python.util import tf_decorator
+
+
+# Since we wrap absl.flags DEFINE functions, we need to declare that this
+# module does not affect key flags.
+disclaim_key_flags() # pylint: disable=undefined-variable
+
+
+_RENAMED_ARGUMENTS = {
+ 'flag_name': 'name',
+ 'default_value': 'default',
+ 'docstring': 'help',
+}
+
+
+def _wrap_define_function(original_function):
+ """Wraps absl.flags's define functions so tf.flags accepts old names."""
+
+ def wrapper(*args, **kwargs):
+ """Wrapper function that turns old keyword names to new ones."""
+ has_old_names = False
+ for old_name, new_name in _six.iteritems(_RENAMED_ARGUMENTS):
+ if old_name in kwargs:
+ has_old_names = True
+ value = kwargs.pop(old_name)
+ kwargs[new_name] = value
+ if has_old_names:
+ _logging.warning(
+ 'Use of the keyword argument names (flag_name, default_value, '
+ 'docstring) is deprecated; please use (name, default, help) instead.')
+ return original_function(*args, **kwargs)
+
+ return tf_decorator.make_decorator(original_function, wrapper)
+
+
+# pylint: disable=invalid-name,used-before-assignment
+# absl.flags APIs use `default` as the name of the default value argument.
+# Allow the following functions to continue to accept `default_value`.
+DEFINE_string = _wrap_define_function(DEFINE_string)
+DEFINE_boolean = _wrap_define_function(DEFINE_boolean)
+DEFINE_bool = DEFINE_boolean
+DEFINE_float = _wrap_define_function(DEFINE_float)
+DEFINE_integer = _wrap_define_function(DEFINE_integer)
+# pylint: enable=invalid-name,used-before-assignment
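The wrapper keeps legacy `tf.flags` call sites working while steering users to the absl spellings. A sketch of both forms; the flag names here are illustrative:

```python
import tensorflow as tf

# Old keyword names still work but log a deprecation warning...
tf.flags.DEFINE_string(flag_name='data_dir', default_value='/tmp',
                       docstring='Input directory.')
# ...while the absl-style names are the supported spelling.
tf.flags.DEFINE_integer(name='batch_size', default=32, help='Batch size.')

FLAGS = tf.flags.FLAGS
FLAGS(['prog'])  # parse an (empty) argv so flag values can be read
print(FLAGS.data_dir, FLAGS.batch_size)  # /tmp 32
```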
diff --git a/tensorflow/python/platform/flags_test.py b/tensorflow/python/platform/flags_test.py
index 23060e17d2..e8200142dd 100644
--- a/tensorflow/python/platform/flags_test.py
+++ b/tensorflow/python/platform/flags_test.py
@@ -24,11 +24,50 @@ from absl import flags as absl_flags
from tensorflow.python.platform import flags
+flags.DEFINE_string(
+ flag_name='old_string', default_value='default', docstring='docstring')
+flags.DEFINE_string(
+ name='new_string', default='default', help='docstring')
+flags.DEFINE_integer(
+ flag_name='old_integer', default_value=1, docstring='docstring')
+flags.DEFINE_integer(
+ name='new_integer', default=1, help='docstring')
+flags.DEFINE_float(
+ flag_name='old_float', default_value=1.5, docstring='docstring')
+flags.DEFINE_float(
+ name='new_float', default=1.5, help='docstring')
+flags.DEFINE_bool(
+ flag_name='old_bool', default_value=True, docstring='docstring')
+flags.DEFINE_bool(
+ name='new_bool', default=True, help='docstring')
+flags.DEFINE_boolean(
+ flag_name='old_boolean', default_value=False, docstring='docstring')
+flags.DEFINE_boolean(
+ name='new_boolean', default=False, help='docstring')
+
+
class FlagsTest(unittest.TestCase):
def test_global_flags_object(self):
self.assertIs(flags.FLAGS, absl_flags.FLAGS)
+ def test_keyword_arguments(self):
+ test_cases = (
+ ('old_string', 'default'),
+ ('new_string', 'default'),
+ ('old_integer', 1),
+ ('new_integer', 1),
+ ('old_float', 1.5),
+ ('new_float', 1.5),
+ ('old_bool', True),
+ ('new_bool', True),
+ ('old_boolean', False),
+ ('new_boolean', False),
+ )
+ for flag_name, default_value in test_cases:
+ self.assertEqual(default_value, absl_flags.FLAGS[flag_name].default)
+ self.assertEqual('docstring', absl_flags.FLAGS[flag_name].help)
+
-if __name__ == "__main__":
+if __name__ == '__main__':
unittest.main()
diff --git a/tensorflow/python/profiler/model_analyzer_test.py b/tensorflow/python/profiler/model_analyzer_test.py
index 26fb99efe6..ccfb9aac53 100644
--- a/tensorflow/python/profiler/model_analyzer_test.py
+++ b/tensorflow/python/profiler/model_analyzer_test.py
@@ -23,12 +23,15 @@ import os
import random
import re
+import numpy as np
+
from tensorflow.core.profiler import profile_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import gfile
@@ -346,8 +349,8 @@ class PrintModelAnalysisTest(test.TestCase):
with gfile.Open(outfile, 'r') as f:
# pylint: disable=line-too-long
self.assertEqual(
- 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes\nConst0B(0',
- f.read().replace('\t', '').replace(' ', '')[0:180])
+ 'nodename|requestedbytes|peakbytes|residualbytes|outputbytes|totalexecutiontime|acceleratorexecutiontime|cpuexecutiontime|#parameters|opoccurrence(run|defined)|inputshapes',
+ f.read().replace('\t', '').replace(' ', '')[0:170])
# pylint: enable=line-too-long
total_children = 0
@@ -694,6 +697,39 @@ class PrintModelAnalysisTest(test.TestCase):
exception_str)
self.assertTrue(mat is None)
+ def testTrackPersistentBytes(self):
+ ops.reset_default_graph()
+ a = array_ops.constant(np.ones((100, 100)))
+ b = array_ops.constant(np.ones((100, 100)))
+ c = a * b
+
+ with session.Session() as sess:
+ run_options = config_pb2.RunOptions(
+ trace_level=config_pb2.RunOptions.FULL_TRACE)
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(c, options=run_options, run_metadata=run_metadata)
+
+ options = option_builder.ProfileOptionBuilder.time_and_memory()
+ options['min_bytes'] = 0
+ options['select'] = ('bytes', 'peak_bytes', 'output_bytes',
+ 'residual_bytes')
+ ret = model_analyzer.profile(
+ sess.graph, run_meta=run_metadata, cmd='scope', options=options)
+
+ run_metadata = config_pb2.RunMetadata()
+ sess.run(c, options=run_options, run_metadata=run_metadata)
+ ret2 = model_analyzer.profile(
+ sess.graph, run_meta=run_metadata, cmd='scope', options=options)
+
+ n = lib.SearchTFProfNode(ret, 'mul')
+ n2 = lib.SearchTFProfNode(ret2, 'mul')
+ self.assertGreater(n.peak_bytes, 0)
+ self.assertGreater(n.output_bytes, 0)
+ self.assertGreater(n.residual_bytes, 0)
+ self.assertEqual(n.peak_bytes, n2.peak_bytes)
+ self.assertEqual(n.output_bytes, n2.output_bytes)
+ self.assertEqual(n.residual_bytes, n2.residual_bytes)
+
if __name__ == '__main__':
test.main()
diff --git a/tensorflow/python/pywrap_tfe.i b/tensorflow/python/pywrap_tfe.i
index 82b154164e..82750e9e49 100644
--- a/tensorflow/python/pywrap_tfe.i
+++ b/tensorflow/python/pywrap_tfe.i
@@ -18,6 +18,7 @@ limitations under the License.
%rename("%s") TFE_NewContext;
%rename("%s") TFE_DeleteContext;
%rename("%s") TFE_ContextListDevices;
+%rename("%s") TFE_ContextAddFunction;
%rename("%s") TFE_ContextAddFunctionDef;
%rename("%s") TFE_OpNameGetAttrType;
%rename("%s") TFE_Py_InitEagerTensor;
@@ -149,7 +150,7 @@ limitations under the License.
}
$1 = &temp;
$1->resize(PyInt_AsLong($input), nullptr);
-}
+}
// Create new Status object.
%typemap(in, numinputs=0) TF_Status *out_status {
diff --git a/tensorflow/python/training/momentum_test.py b/tensorflow/python/training/momentum_test.py
index 7268b3abc9..6865513b0e 100644
--- a/tensorflow/python/training/momentum_test.py
+++ b/tensorflow/python/training/momentum_test.py
@@ -234,23 +234,38 @@ class MomentumOptimizerTest(test.TestCase):
self.assertAllClose(var0_np, var0.eval())
self.assertAllClose(var1_np, var1.eval())
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
def testMinimizeSparseResourceVariable(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
- with self.test_session():
- var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+ var0 = resource_variable_ops.ResourceVariable([[1.0, 2.0]], dtype=dtype)
+
+ # pylint: disable=cell-var-from-loop
+ def loss():
x = constant_op.constant([[4.0], [5.0]], dtype=dtype)
pred = math_ops.matmul(embedding_ops.embedding_lookup([var0], [0]), x)
- loss = pred * pred
- sgd_op = momentum_lib.MomentumOptimizer(
- learning_rate=1.0, momentum=0.0).minimize(loss)
- variables.global_variables_initializer().run()
- # Fetch params to validate initial values
- self.assertAllCloseAccordingToType([[1.0, 2.0]], var0.eval())
- # Run 1 step of sgd
- sgd_op.run()
- # Validate updated params
- self.assertAllCloseAccordingToType(
- [[-111, -138]], var0.eval())
+ return pred * pred
+ # pylint: enable=cell-var-from-loop
+
+ opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+ self.evaluate(variables.global_variables_initializer())
+ # Run 1 step of sgd
+ self.evaluate(sgd_op)
+ # Validate updated params
+ self.assertAllCloseAccordingToType([[-111, -138]], self.evaluate(var0))
+
+ @test_util.run_in_graph_and_eager_modes(reset_test=True)
+ def testMinimizeWith2DIndicesForEmbeddingLookup(self):
+ var0 = resource_variable_ops.ResourceVariable(array_ops.ones([2, 2]))
+
+ def loss():
+ return math_ops.reduce_sum(embedding_ops.embedding_lookup(var0, [[1]]))
+
+ opt = momentum_lib.MomentumOptimizer(learning_rate=1.0, momentum=0.0)
+ sgd_op = opt.minimize(loss if context.in_eager_mode() else loss())
+ self.evaluate(variables.global_variables_initializer())
+ self.evaluate(sgd_op)
+ self.assertAllCloseAccordingToType([[1, 1], [0, 0]], self.evaluate(var0))
def testTensorLearningRateAndMomentum(self):
for dtype in [dtypes.half, dtypes.float32, dtypes.float64]:
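The migrated test exercises the dual-mode optimizer idiom: in eager mode `minimize` takes a loss callable so it can be re-evaluated each step, while in graph mode it takes the loss tensor built once. A distilled sketch using the same internal imports as the test:

```python
import tensorflow as tf
from tensorflow.python.eager import context
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.training import momentum as momentum_lib

var = resource_variable_ops.ResourceVariable([1.0, 2.0])

def loss():
    return tf.reduce_sum(var * var)

opt = momentum_lib.MomentumOptimizer(learning_rate=0.1, momentum=0.0)
# Eager: pass the callable; graph: call it once to get the loss tensor.
step = opt.minimize(loss if context.in_eager_mode() else loss())
```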
diff --git a/tensorflow/python/training/monitored_session.py b/tensorflow/python/training/monitored_session.py
index e931555470..f1cb81981a 100644
--- a/tensorflow/python/training/monitored_session.py
+++ b/tensorflow/python/training/monitored_session.py
@@ -52,7 +52,6 @@ _PREEMPTION_ERRORS = (errors.AbortedError, errors.UnavailableError)
USE_DEFAULT = object()
-# TODO(touts): Share that with the Supervisor.
class Scaffold(object):
"""Structure to create or gather pieces commonly needed to train a model.
diff --git a/tensorflow/python/training/saver_test.py b/tensorflow/python/training/saver_test.py
index b7f1297b8f..74ee1e5fa8 100644
--- a/tensorflow/python/training/saver_test.py
+++ b/tensorflow/python/training/saver_test.py
@@ -774,9 +774,13 @@ class SaveRestoreShardedTest(test.TestCase):
with sess.graph.device("/cpu:0"):
v0 = variables.Variable(111, name="v0")
t0 = saver_test_utils.CheckpointedOp(name="t0")
- save = saver_module.Saver({"v0": v0, "t0": t0.saveable},
- write_version=self._WRITE_VERSION,
- sharded=True)
+ save = saver_module.Saver(
+ {
+ "v0": v0,
+ "t0": t0.saveable
+ },
+ write_version=self._WRITE_VERSION,
+ sharded=True)
variables.global_variables_initializer().run()
t0.insert("k11", 33.0).run()
self.assertEqual(111, v0.eval())
@@ -794,9 +798,13 @@ class SaveRestoreShardedTest(test.TestCase):
with sess.graph.device("/cpu:0"):
v1 = variables.Variable(222)
t1 = saver_test_utils.CheckpointedOp(name="t1")
- save = saver_module.Saver({"v1": v1, "t1": t1.saveable},
- write_version=self._WRITE_VERSION,
- sharded=True)
+ save = saver_module.Saver(
+ {
+ "v1": v1,
+ "t1": t1.saveable
+ },
+ write_version=self._WRITE_VERSION,
+ sharded=True)
variables.global_variables_initializer().run()
t1.insert("k22", 44.0).run()
self.assertEqual(222, v1.eval())
diff --git a/tensorflow/python/training/supervisor.py b/tensorflow/python/training/supervisor.py
index a634a842b6..e4514aaea2 100644
--- a/tensorflow/python/training/supervisor.py
+++ b/tensorflow/python/training/supervisor.py
@@ -36,11 +36,15 @@ from tensorflow.python.training import coordinator
from tensorflow.python.training import saver as saver_mod
from tensorflow.python.training import session_manager as session_manager_mod
from tensorflow.python.training import training_util
+from tensorflow.python.util import deprecation
class Supervisor(object):
"""A training helper that checkpoints models and computes summaries.
+ This class is deprecated. Please use
+ @{tf.train.MonitoredTrainingSession} instead.
+
The Supervisor is a small wrapper around a `Coordinator`, a `Saver`,
and a `SessionManager` that takes care of common needs of TensorFlow
training programs.
@@ -198,6 +202,8 @@ class Supervisor(object):
# the default behavior should be used.
USE_DEFAULT = 0
+ @deprecation.deprecated(None,
+ "Please switch to tf.train.MonitoredTrainingSession")
def __init__(self,
graph=None,
ready_op=USE_DEFAULT,
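Constructing a `Supervisor` now logs a one-time deprecation warning via the `deprecation.deprecated` decorator. A sketch of the decorator on a hypothetical function (`train_loop` is an illustrative name, not part of the change):

```python
from tensorflow.python.util import deprecation

@deprecation.deprecated(
    None, "Please switch to tf.train.MonitoredTrainingSession")
def train_loop():
    """A hypothetical legacy entry point."""
    return "training"

# The first call logs a warning of the form:
#   "train_loop ... is deprecated and will be removed in a future version.
#    Instructions for updating: Please switch to
#    tf.train.MonitoredTrainingSession"
train_loop()
```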
diff --git a/tensorflow/python/util/nest.py b/tensorflow/python/util/nest.py
index f5802d9359..5c066e2bef 100644
--- a/tensorflow/python/util/nest.py
+++ b/tensorflow/python/util/nest.py
@@ -456,9 +456,9 @@ def assert_shallow_structure(shallow_tree, input_tree, check_types=True):
if set(input_tree) != set(shallow_tree):
raise ValueError(
"The two structures don't have the same keys. Input "
- "structure has keys %s, while shallow structure has keys %s."
- % (list(_six.iterkeys(input_tree)),
- list(_six.iterkeys(shallow_tree))))
+ "structure has keys %s, while shallow structure has keys %s." %
+ (list(_six.iterkeys(input_tree)),
+ list(_six.iterkeys(shallow_tree))))
input_tree = list(_six.iteritems(input_tree))
shallow_tree = list(_six.iteritems(shallow_tree))
diff --git a/tensorflow/python/util/nest_test.py b/tensorflow/python/util/nest_test.py
index 26aeaeec19..3d9e9f9684 100644
--- a/tensorflow/python/util/nest_test.py
+++ b/tensorflow/python/util/nest_test.py
@@ -388,8 +388,9 @@ class NestTest(test.TestCase):
inp_ab1 = {"a": (1, 1), "b": {"c": (2, 2)}}
inp_ab2 = {"a": (1, 1), "b": {"d": (2, 2)}}
expected_message = (
- "The two structures don't have the same keys. Input "
- "structure has keys \['c'\], while shallow structure has keys \['d'\].")
+ r"The two structures don't have the same keys. Input "
+ r"structure has keys \['c'\], while shallow structure has "
+ r"keys \['d'\].")
with self.assertRaisesRegexp(ValueError, expected_message):
nest.assert_shallow_structure(inp_ab2, inp_ab1)
@@ -438,8 +439,7 @@ class NestTest(test.TestCase):
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree, [0, 1, 2, 3, 4])
- shallow_tree = collections.OrderedDict([("a", 0),
- ("c", {"d": 3, "e": 1})])
+ shallow_tree = collections.OrderedDict([("a", 0), ("c", {"d": 3, "e": 1})])
input_tree_flattened_as_shallow_tree = nest.flatten_up_to(shallow_tree,
input_tree)
self.assertEqual(input_tree_flattened_as_shallow_tree,
diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl
index 8d392fb36d..76ef59484f 100644
--- a/tensorflow/tensorflow.bzl
+++ b/tensorflow/tensorflow.bzl
@@ -167,7 +167,19 @@ WIN_COPTS = [
]
# LINT.IfChange
-def tf_copts():
+def tf_copts(android_optimization_level_override="-O2"):
+ # For compatibility reasons, android_optimization_level_override
+ # currently only applies to Android builds.
+ # To clear this value and fall back to the CROSSTOOL default,
+ # pass android_optimization_level_override=None.
+ android_copts = [
+ "-std=c++11",
+ "-DTF_LEAN_BINARY",
+ "-Wno-narrowing",
+ "-fomit-frame-pointer",
+ ]
+ if android_optimization_level_override:
+ android_copts.append(android_optimization_level_override)
return (
if_not_windows([
"-DEIGEN_AVOID_STL_ARRAY",
@@ -180,13 +192,7 @@ def tf_copts():
+ if_android_arm(["-mfpu=neon"])
+ if_linux_x86_64(["-msse3"])
+ select({
- clean_dep("//tensorflow:android"): [
- "-std=c++11",
- "-DTF_LEAN_BINARY",
- "-O2",
- "-Wno-narrowing",
- "-fomit-frame-pointer",
- ],
+ clean_dep("//tensorflow:android"): android_copts,
clean_dep("//tensorflow:darwin"): [],
clean_dep("//tensorflow:windows"): WIN_COPTS,
clean_dep("//tensorflow:windows_msvc"): WIN_COPTS,
diff --git a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
index ebd9c079b5..d920fef770 100644
--- a/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.nn.pbtxt
@@ -54,15 +54,15 @@ tf_module {
}
member_method {
name: "conv2d"
- argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "conv2d_backprop_filter"
- argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "conv2d_backprop_input"
- argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'None\'], "
+ argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'use_cudnn_on_gpu\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "conv2d_transpose"
@@ -70,11 +70,11 @@ tf_module {
}
member_method {
name: "conv3d"
- argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "conv3d_backprop_filter_v2"
- argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NDHWC\', \'[1, 1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "conv3d_transpose"
@@ -106,15 +106,15 @@ tf_module {
}
member_method {
name: "depthwise_conv2d_native"
- argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "depthwise_conv2d_native_backprop_filter"
- argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
+ argspec: "args=[\'input\', \'filter_sizes\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "depthwise_conv2d_native_backprop_input"
- argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'None\'], "
+ argspec: "args=[\'input_sizes\', \'filter\', \'out_backprop\', \'strides\', \'padding\', \'data_format\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "dilation2d"
@@ -234,7 +234,7 @@ tf_module {
}
member_method {
name: "quantized_conv2d"
- argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'qint32\'>\", \'None\'], "
+ argspec: "args=[\'input\', \'filter\', \'min_input\', \'max_input\', \'min_filter\', \'max_filter\', \'strides\', \'padding\', \'out_type\', \'dilations\', \'name\'], varargs=None, keywords=None, defaults=[\"<dtype: \'qint32\'>\", \'[1, 1, 1, 1]\', \'None\'], "
}
member_method {
name: "quantized_max_pool"
diff --git a/tensorflow/tools/api/golden/tensorflow.pbtxt b/tensorflow/tools/api/golden/tensorflow.pbtxt
index 0edd4153d7..57573d5024 100644
--- a/tensorflow/tools/api/golden/tensorflow.pbtxt
+++ b/tensorflow/tools/api/golden/tensorflow.pbtxt
@@ -1394,7 +1394,7 @@ tf_module {
}
member_method {
name: "multinomial"
- argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], "
+ argspec: "args=[\'logits\', \'num_samples\', \'seed\', \'name\', \'output_dtype\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], "
}
member_method {
name: "multiply"
diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh
index 404a9a6b62..4021d794b6 100755
--- a/tensorflow/tools/ci_build/ci_sanity.sh
+++ b/tensorflow/tools/ci_build/ci_sanity.sh
@@ -99,7 +99,8 @@ do_pylint() {
"^tensorflow/contrib/eager/python/metrics_impl\.py.*\[E0202.*method-hidden "\
"^tensorflow/python/platform/gfile\.py.*\[E0301.*non-iterator "\
"^tensorflow/python/keras/_impl/keras/callbacks\.py.*\[E1133.*not-an-iterable "\
-"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition"
+"^tensorflow/python/keras/_impl/keras/layers/recurrent\.py.*\[E0203.*access-member-before-definition "\
+"^tensorflow/python/kernel_tests/constant_op_eager_test.py.*\[E0303.*invalid-length-returned"
echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\""
diff --git a/tensorflow/tools/dist_test/python/census_widendeep.py b/tensorflow/tools/dist_test/python/census_widendeep.py
index 6f578d6f67..8feb5386e9 100644
--- a/tensorflow/tools/dist_test/python/census_widendeep.py
+++ b/tensorflow/tools/dist_test/python/census_widendeep.py
@@ -263,8 +263,7 @@ if __name__ == "__main__":
"--data_dir",
type=str,
default="/tmp/census-data",
- help="Directory for storing the census data"
- )
+ help="Directory for storing the census data")
parser.add_argument(
"--model_dir",
type=str,
diff --git a/tensorflow/tools/pip_package/setup.py b/tensorflow/tools/pip_package/setup.py
index c18f20910a..3852b251d9 100644
--- a/tensorflow/tools/pip_package/setup.py
+++ b/tensorflow/tools/pip_package/setup.py
@@ -33,7 +33,12 @@ _VERSION = '1.4.0'
REQUIRED_PACKAGES = [
'absl-py',
- 'enum34 >= 1.1.6',
+ # weakref.finalize introduced in Python 3.4
+ 'backports.weakref >= 1.0rc1; python_version < "3.4"',
+ # enum module introduced in Python 3.4
+ 'enum34 >= 1.1.6; python_version < "3.4"',
+ # Needed for unittest.mock in Python 2
+ 'mock >= 2.0.0; python_version < "3.0"',
'numpy >= 1.12.1',
'six >= 1.10.0',
'protobuf >= 3.4.0',
@@ -52,8 +57,6 @@ if sys.version_info.major == 3:
REQUIRED_PACKAGES.append('wheel >= 0.26')
else:
REQUIRED_PACKAGES.append('wheel')
- # mock comes with unittest.mock for python3, need to install for python2
- REQUIRED_PACKAGES.append('mock >= 2.0.0')
# tf-nightly should depend on tb-nightly
if 'tf_nightly' in project_name:
@@ -62,10 +65,6 @@ if 'tf_nightly' in project_name:
REQUIRED_PACKAGES[i] = 'tb-nightly >= 1.5.0a0, < 1.6.0a0'
break
-# weakref.finalize was introduced in Python 3.4
-if sys.version_info < (3, 4):
- REQUIRED_PACKAGES.append('backports.weakref >= 1.0rc1')
-
# pylint: disable=line-too-long
CONSOLE_SCRIPTS = [
'freeze_graph = tensorflow.python.tools.freeze_graph:main',
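The version-conditional dependencies move from imperative `sys.version_info` checks into PEP 508 environment markers, which pip evaluates against the target interpreter at install time. A sketch of how such a marker evaluates, using the third-party `packaging` library purely for illustration:

```python
from packaging.markers import Marker

marker = Marker('python_version < "3.4"')
# On Python 3.6 this is False, so pip skips enum34, backports.weakref
# and mock automatically; on Python 2.7 it is True and they are installed.
print(marker.evaluate())
```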
diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 6b13271002..c2256b6313 100644
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -57,33 +57,6 @@ def check_version(bazel_version):
fail("\nCurrent Bazel version is {}, expected at least {}\n".format(
native.bazel_version, bazel_version))
-def _repos_are_siblings():
- return Label("@foo//bar").workspace_root.startswith("../")
-
-# Temporary workaround to support including TensorFlow as a submodule until this
-# use-case is supported in the next Bazel release.
-def _temp_workaround_http_archive_impl(repo_ctx):
- repo_ctx.template("BUILD", repo_ctx.attr.build_file, {
- "%prefix%": ".." if _repos_are_siblings() else "external",
- "%ws%": repo_ctx.attr.repository
- }, False)
- repo_ctx.download_and_extract(repo_ctx.attr.urls, "", repo_ctx.attr.sha256,
- "", repo_ctx.attr.strip_prefix)
- if repo_ctx.attr.patch_file != None:
- _apply_patch(repo_ctx, repo_ctx.attr.patch_file)
-
-temp_workaround_http_archive = repository_rule(
- attrs = {
- "build_file": attr.label(),
- "repository": attr.string(),
- "patch_file": attr.label(default = None),
- "urls": attr.string_list(default = []),
- "sha256": attr.string(default = ""),
- "strip_prefix": attr.string(default = ""),
- },
- implementation = _temp_workaround_http_archive_impl,
-)
-
# Executes specified command with arguments and calls 'fail' if it exited with
# non-zero code
def _execute_and_check_ret_code(repo_ctx, cmd_and_args):
@@ -121,8 +94,6 @@ def _patched_http_archive_impl(repo_ctx):
patched_http_archive = repository_rule(
attrs = {
"patch_file": attr.label(),
- "build_file": attr.label(),
- "repository": attr.string(),
"urls": attr.string_list(default = []),
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
@@ -157,7 +128,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "57ba56c4c243f403ff78f417ff854ef50b9eddf4a610a917b7c95e7fa8553a4b",
strip_prefix = "mklml_lnx_2018.0.20170720",
build_file = str(Label("//third_party/mkl:mkl.BUILD")),
- repository = tf_repo_name,
)
if path_prefix:
@@ -292,7 +262,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
build_file = str(Label("//third_party:nasm.BUILD")),
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "jpeg",
urls = [
"https://mirror.bazel.build/github.com/libjpeg-turbo/libjpeg-turbo/archive/1.5.1.tar.gz",
@@ -301,7 +271,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "c15a9607892113946379ccea3ca8b85018301b200754f209453ab21674268e77",
strip_prefix = "libjpeg-turbo-1.5.1",
build_file = str(Label("//third_party/jpeg:jpeg.BUILD")),
- repository = tf_repo_name,
)
native.new_http_archive(
@@ -447,11 +416,11 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
native.http_archive(
name = "nsync",
urls = [
- "https://mirror.bazel.build/github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
- "https://github.com/google/nsync/archive/93815892dddafe9146a5f7e7042281d59d0f4323.tar.gz",
+ "https://mirror.bazel.build/github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz",
+ "https://github.com/google/nsync/archive/8502189abfa44c249c01c2cad64e6ed660a9a668.tar.gz",
],
- sha256 = "e3bd4555415ace511338fc27e595351738eea4e9006f1612b76c82914770716b",
- strip_prefix = "nsync-93815892dddafe9146a5f7e7042281d59d0f4323",
+ sha256 = "51f81ff4202bbb820cdbedc061bd2eb6765f2b5c06489e7a8694bedac329e8f8",
+ strip_prefix = "nsync-8502189abfa44c249c01c2cad64e6ed660a9a668",
)
native.http_archive(
@@ -502,7 +471,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
build_file = str(Label("//third_party:swig.BUILD")),
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "curl",
sha256 = "ff3e80c1ca6a068428726cd7dd19037a47cc538ce58ef61c59587191039b2ca6",
urls = [
@@ -511,7 +480,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
],
strip_prefix = "curl-7.49.1",
build_file = str(Label("//third_party:curl.BUILD")),
- repository = tf_repo_name
)
# grpc expects //external:protobuf_clib and //external:protobuf_compiler
@@ -575,16 +543,15 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
# TODO(phawkins): currently, this rule uses an unofficial LLVM mirror.
# Switch to an official source of snapshots if/when possible.
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "llvm",
urls = [
- "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/8d26b8bee4d8e7230870a600bc968c7ee8cf6f67.tar.gz",
- "https://github.com/llvm-mirror/llvm/archive/8d26b8bee4d8e7230870a600bc968c7ee8cf6f67.tar.gz",
+ "https://mirror.bazel.build/github.com/llvm-mirror/llvm/archive/9ab4c272cb604a7f947865428c4ef2169fee2100.tar.gz",
+ "https://github.com/llvm-mirror/llvm/archive/9ab4c272cb604a7f947865428c4ef2169fee2100.tar.gz",
],
- sha256 = "ff5ddbe5af5e264426c8d489e7fddfc5ad7e0975f19cefe9db8c0a5d0faeb23e",
- strip_prefix = "llvm-8d26b8bee4d8e7230870a600bc968c7ee8cf6f67",
+ sha256 = "1b1b7d3800a94ca2302e3dd670dbe84238749583027883784b55297059d83da8",
+ strip_prefix = "llvm-9ab4c272cb604a7f947865428c4ef2169fee2100",
build_file = str(Label("//third_party/llvm:llvm.BUILD")),
- repository = tf_repo_name,
)
native.new_http_archive(
@@ -650,7 +617,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
build_file = str(Label("//third_party/fft2d:fft2d.BUILD")),
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "snappy",
urls = [
"https://mirror.bazel.build/github.com/google/snappy/archive/1.1.4.tar.gz",
@@ -659,10 +626,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "2f7504c73d85bac842e893340333be8cb8561710642fc9562fccdd9d2c3fcc94",
strip_prefix = "snappy-1.1.4",
build_file = str(Label("//third_party:snappy.BUILD")),
- repository = tf_repo_name,
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "nccl_archive",
urls = [
"https://mirror.bazel.build/github.com/nvidia/nccl/archive/03d856977ecbaac87e598c0c4bafca96761b9ac7.tar.gz",
@@ -671,10 +637,9 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "2ca86fb6179ecbff789cc67c836139c1bbc0324ed8c04643405a30bf26325176",
strip_prefix = "nccl-03d856977ecbaac87e598c0c4bafca96761b9ac7",
build_file = str(Label("//third_party:nccl.BUILD")),
- repository = tf_repo_name,
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "aws",
urls = [
"https://mirror.bazel.build/github.com/aws/aws-sdk-cpp/archive/1.0.90.tar.gz",
@@ -683,7 +648,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "f599b57aec4f03ad696044dd430b2d201864113937353adc346f53ad47991319",
strip_prefix = "aws-sdk-cpp-1.0.90",
build_file = str(Label("//third_party:aws.BUILD")),
- repository = tf_repo_name
)
java_import_external(
@@ -711,7 +675,7 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
testonly_ = True,
)
- temp_workaround_http_archive(
+ native.new_http_archive(
name = "jemalloc",
urls = [
"https://mirror.bazel.build/github.com/jemalloc/jemalloc/archive/4.4.0.tar.gz",
@@ -720,7 +684,6 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
sha256 = "3c8f25c02e806c3ce0ab5fb7da1817f89fc9732709024e2a81b6b82f7cc792a8",
strip_prefix = "jemalloc-4.4.0",
build_file = str(Label("//third_party:jemalloc.BUILD")),
- repository = tf_repo_name,
)
java_import_external(
diff --git a/third_party/aws.BUILD b/third_party/aws.BUILD
index bc9e37ffb3..bf5310aa16 100644
--- a/third_party/aws.BUILD
+++ b/third_party/aws.BUILD
@@ -7,21 +7,21 @@ licenses(["notice"]) # Apache 2.0
exports_files(["LICENSE"])
-load("@%ws%//third_party:common.bzl", "template_rule")
+load("@org_tensorflow//third_party:common.bzl", "template_rule")
cc_library(
name = "aws",
srcs = select({
- "@%ws%//tensorflow:linux_x86_64": glob([
+ "@org_tensorflow//tensorflow:linux_x86_64": glob([
"aws-cpp-sdk-core/source/platform/linux-shared/*.cpp",
]),
- "@%ws%//tensorflow:darwin": glob([
+ "@org_tensorflow//tensorflow:darwin": glob([
"aws-cpp-sdk-core/source/platform/linux-shared/*.cpp",
]),
- "@%ws%//tensorflow:linux_ppc64le": glob([
+ "@org_tensorflow//tensorflow:linux_ppc64le": glob([
"aws-cpp-sdk-core/source/platform/linux-shared/*.cpp",
]),
- "@%ws%//tensorflow:raspberry_pi_armeabi": glob([
+ "@org_tensorflow//tensorflow:raspberry_pi_armeabi": glob([
"aws-cpp-sdk-core/source/platform/linux-shared/*.cpp",
]),
"//conditions:default": [],
@@ -53,17 +53,17 @@ cc_library(
"aws-cpp-sdk-core/include/aws/core/SDKConfig.h",
],
defines = select({
- "@%ws%//tensorflow:linux_x86_64": [
+ "@org_tensorflow//tensorflow:linux_x86_64": [
"PLATFORM_LINUX",
"ENABLE_CURL_CLIENT",
"ENABLE_NO_ENCRYPTION",
],
- "@%ws%//tensorflow:darwin": [
+ "@org_tensorflow//tensorflow:darwin": [
"PLATFORM_APPLE",
"ENABLE_CURL_CLIENT",
"ENABLE_NO_ENCRYPTION",
],
- "@%ws%//tensorflow:linux_ppc64le": [
+ "@org_tensorflow//tensorflow:linux_ppc64le": [
"PLATFORM_LINUX",
"ENABLE_CURL_CLIENT",
"ENABLE_NO_ENCRYPTION",
diff --git a/third_party/curl.BUILD b/third_party/curl.BUILD
index 805a30d262..e311c7e758 100644
--- a/third_party/curl.BUILD
+++ b/third_party/curl.BUILD
@@ -6,7 +6,7 @@ licenses(["notice"]) # MIT/X derivative license
exports_files(["COPYING"])
CURL_WIN_COPTS = [
- "/I%prefix%/curl/lib",
+ "/Iexternal/curl/lib",
"/DHAVE_CONFIG_H",
"/DCURL_DISABLE_FTP",
"/DCURL_DISABLE_NTLM",
@@ -224,14 +224,14 @@ cc_library(
"lib/wildcard.h",
"lib/x509asn1.h",
] + select({
- "@%ws%//tensorflow:darwin": [
+ "@org_tensorflow//tensorflow:darwin": [
"lib/vtls/darwinssl.c",
],
- "@%ws%//tensorflow:ios": [
+ "@org_tensorflow//tensorflow:ios": [
"lib/vtls/darwinssl.c",
],
- "@%ws%//tensorflow:windows": CURL_WIN_SRCS,
- "@%ws%//tensorflow:windows_msvc": CURL_WIN_SRCS,
+ "@org_tensorflow//tensorflow:windows": CURL_WIN_SRCS,
+ "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_SRCS,
"//conditions:default": [
"lib/vtls/openssl.c",
],
@@ -248,10 +248,10 @@ cc_library(
"include/curl/typecheck-gcc.h",
],
copts = select({
- "@%ws%//tensorflow:windows": CURL_WIN_COPTS,
- "@%ws%//tensorflow:windows_msvc": CURL_WIN_COPTS,
+ "@org_tensorflow//tensorflow:windows": CURL_WIN_COPTS,
+ "@org_tensorflow//tensorflow:windows_msvc": CURL_WIN_COPTS,
"//conditions:default": [
- "-I%prefix%/curl/lib",
+ "-Iexternal/curl/lib",
"-D_GNU_SOURCE",
"-DHAVE_CONFIG_H",
"-DCURL_DISABLE_FTP",
@@ -261,14 +261,14 @@ cc_library(
"-Wno-string-plus-int",
],
}) + select({
- "@%ws%//tensorflow:darwin": [
+ "@org_tensorflow//tensorflow:darwin": [
"-fno-constant-cfstrings",
],
- "@%ws%//tensorflow:windows": [
+ "@org_tensorflow//tensorflow:windows": [
# See curl.h for discussion of write size and Windows
"/DCURL_MAX_WRITE_SIZE=16384",
],
- "@%ws%//tensorflow:windows_msvc": [
+ "@org_tensorflow//tensorflow:windows_msvc": [
# See curl.h for discussion of write size and Windows
"/DCURL_MAX_WRITE_SIZE=16384",
],
@@ -278,20 +278,20 @@ cc_library(
}),
includes = ["include"],
linkopts = select({
- "@%ws%//tensorflow:android": [
+ "@org_tensorflow//tensorflow:android": [
"-pie",
],
- "@%ws%//tensorflow:darwin": [
+ "@org_tensorflow//tensorflow:darwin": [
"-Wl,-framework",
"-Wl,CoreFoundation",
"-Wl,-framework",
"-Wl,Security",
],
- "@%ws%//tensorflow:ios": [],
- "@%ws%//tensorflow:windows": [
+ "@org_tensorflow//tensorflow:ios": [],
+ "@org_tensorflow//tensorflow:windows": [
"-Wl,ws2_32.lib",
],
- "@%ws%//tensorflow:windows_msvc": [
+ "@org_tensorflow//tensorflow:windows_msvc": [
"-Wl,ws2_32.lib",
],
"//conditions:default": [
@@ -302,9 +302,9 @@ cc_library(
deps = [
"@zlib_archive//:zlib",
] + select({
- "@%ws%//tensorflow:ios": [],
- "@%ws%//tensorflow:windows": [],
- "@%ws%//tensorflow:windows_msvc": [],
+ "@org_tensorflow//tensorflow:ios": [],
+ "@org_tensorflow//tensorflow:windows": [],
+ "@org_tensorflow//tensorflow:windows_msvc": [],
"//conditions:default": [
"@boringssl//:ssl",
],
@@ -312,7 +312,7 @@ cc_library(
)
CURL_BIN_WIN_COPTS = [
- "/I%prefix%/curl/lib",
+ "/Iexternal/curl/lib",
"/DHAVE_CONFIG_H",
"/DCURL_DISABLE_LIBCURL_OPTION",
]
@@ -406,10 +406,10 @@ cc_binary(
"src/tool_xattr.h",
],
copts = select({
- "@%ws%//tensorflow:windows": CURL_BIN_WIN_COPTS,
- "@%ws%//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS,
+ "@org_tensorflow//tensorflow:windows": CURL_BIN_WIN_COPTS,
+ "@org_tensorflow//tensorflow:windows_msvc": CURL_BIN_WIN_COPTS,
"//conditions:default": [
- "-I%prefix%/curl/lib",
+ "-Iexternal/curl/lib",
"-D_GNU_SOURCE",
"-DHAVE_CONFIG_H",
"-DCURL_DISABLE_LIBCURL_OPTION",
diff --git a/third_party/gif.BUILD b/third_party/gif.BUILD
index 27808a9d64..78fbd6c0e0 100644
--- a/third_party/gif.BUILD
+++ b/third_party/gif.BUILD
@@ -21,7 +21,7 @@ cc_library(
],
hdrs = ["lib/gif_lib.h"],
defines = select({
- #"@%ws%//tensorflow:android": [
+ #"@org_tensorflow//tensorflow:android": [
":android": [
"S_IREAD=S_IRUSR",
"S_IWRITE=S_IWUSR",
diff --git a/third_party/jemalloc.BUILD b/third_party/jemalloc.BUILD
index a2addf2c66..1b0829b8fe 100644
--- a/third_party/jemalloc.BUILD
+++ b/third_party/jemalloc.BUILD
@@ -5,7 +5,7 @@ licenses(["notice"]) # BSD
exports_files(["COPYING"])
-load("@%ws%//third_party:common.bzl", "template_rule")
+load("@org_tensorflow//third_party:common.bzl", "template_rule")
cc_library(
name = "jemalloc_headers",
@@ -97,10 +97,10 @@ cc_library(
includes = ["include"],
# pthread_atfork() is called for PPC.
linkopts = select({
- "@%ws%//tensorflow:linux_ppc64le": [
+ "@org_tensorflow//tensorflow:linux_ppc64le": [
"-lpthread",
],
- "@%ws%//tensorflow:linux_x86_64": [
+ "@org_tensorflow//tensorflow:linux_x86_64": [
"-lpthread",
],
"//conditions:default": [
@@ -208,8 +208,8 @@ genrule(
name = "size_classes_h",
outs = ["include/jemalloc/internal/size_classes.h"],
cmd = select({
- "@%ws%//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@",
- "@%ws%//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
+ "@org_tensorflow//tensorflow:linux_ppc64le": "$(location :size_classes_sh) \"3 4\" 3 16 2 >$@",
+ "@org_tensorflow//tensorflow:linux_x86_64": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
"//conditions:default": "$(location :size_classes_sh) \"3 4\" 3 12 2 >$@",
}),
tools = [":size_classes_sh"],
diff --git a/third_party/jpeg/jpeg.BUILD b/third_party/jpeg/jpeg.BUILD
index f6078052ec..e431f19382 100644
--- a/third_party/jpeg/jpeg.BUILD
+++ b/third_party/jpeg/jpeg.BUILD
@@ -5,7 +5,7 @@ licenses(["notice"]) # custom notice-style license, see LICENSE.md
exports_files(["LICENSE.md"])
-load("@%ws%//third_party:common.bzl", "template_rule")
+load("@org_tensorflow//third_party:common.bzl", "template_rule")
libjpegturbo_nocopts = "-[W]error"
diff --git a/third_party/mkl/build_defs.bzl b/third_party/mkl/build_defs.bzl
index 6574f25092..8b73ddabdd 100644
--- a/third_party/mkl/build_defs.bzl
+++ b/third_party/mkl/build_defs.bzl
@@ -60,7 +60,6 @@ mkl_repository = repository_rule(
],
attrs = {
"build_file": attr.label(),
- "repository": attr.string(),
"urls": attr.string_list(default = []),
"sha256": attr.string(default = ""),
"strip_prefix": attr.string(default = ""),
diff --git a/third_party/nccl.BUILD b/third_party/nccl.BUILD
index 8c7b9bdbe9..b2b8e18824 100644
--- a/third_party/nccl.BUILD
+++ b/third_party/nccl.BUILD
@@ -44,17 +44,17 @@ cc_library(
"-O3",
] + cuda_default_copts(),
linkopts = select({
- "@%ws%//tensorflow:android": [
+ "@org_tensorflow//tensorflow:android": [
"-pie",
],
- "@%ws%//tensorflow:darwin": [
+ "@org_tensorflow//tensorflow:darwin": [
"-Wl,-framework",
"-Wl,CoreFoundation",
"-Wl,-framework",
"-Wl,Security",
],
- "@%ws%//tensorflow:ios": [],
- "@%ws%//tensorflow:windows": [
+ "@org_tensorflow//tensorflow:ios": [],
+ "@org_tensorflow//tensorflow:windows": [
"-DEFAULTLIB:ws2_32.lib",
],
"//conditions:default": [
diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD
index 9c00b7068a..fd48ed8941 100644
--- a/third_party/snappy.BUILD
+++ b/third_party/snappy.BUILD
@@ -50,8 +50,8 @@ genrule(
"-e 's/@ac_cv_have_stddef_h@/1/g' " +
"-e 's/@ac_cv_have_stdint_h@/1/g' " +
select({
- "@%ws%//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
- "@%ws%//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
+ "@org_tensorflow//tensorflow:windows": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
+ "@org_tensorflow//tensorflow:windows_msvc": "-e 's/@ac_cv_have_sys_uio_h@/0/g' ",
"//conditions:default": "-e 's/@ac_cv_have_sys_uio_h@/1/g' ",
}) +
"-e 's/@SNAPPY_MAJOR@/1/g' " +